Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Backport to 2.17] add rate limiting for offline batch jobs, set default bulk size to 50… #3122

Merged
merged 1 commit into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ public interface Ingestable {
* @param mlBatchIngestionInput batch ingestion input data
* @return successRate (0 - 100)
*/
default double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
default double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
throw new IllegalStateException("Ingest is not implemented");
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ public OpenAIDataIngestion(Client client) {
}

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
List<String> sources = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
if (Objects.isNull(sources) || sources.isEmpty()) {
return 100;
Expand All @@ -48,13 +48,19 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
boolean isSoleSource = sources.size() == 1;
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
for (int sourceIndex = 0; sourceIndex < sources.size(); sourceIndex++) {
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
successRates.add(ingestSingleSource(sources.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
}

return calculateSuccessRate(successRates);
}

private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIngestionInput, int sourceIndex, boolean isSoleSource) {
private double ingestSingleSource(
String fileId,
MLBatchIngestionInput mlBatchIngestionInput,
int sourceIndex,
boolean isSoleSource,
int bulkSize
) {
double successRate = 0;
try {
String apiKey = mlBatchIngestionInput.getCredential().get(API_KEY);
Expand Down Expand Up @@ -82,8 +88,8 @@ private double ingestSingleSource(String fileId, MLBatchIngestionInput mlBatchIn
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount % 100 == 0) {
// Process every bulkSize lines
if (lineCount % bulkSize == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ public S3DataIngestion(Client client) {
}

@Override
public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
public double ingest(MLBatchIngestionInput mlBatchIngestionInput, int bulkSize) {
S3Client s3 = initS3Client(mlBatchIngestionInput);

List<String> s3Uris = (List<String>) mlBatchIngestionInput.getDataSources().get(SOURCE);
Expand All @@ -63,7 +63,7 @@ public double ingest(MLBatchIngestionInput mlBatchIngestionInput) {
boolean isSoleSource = s3Uris.size() == 1;
List<Double> successRates = Collections.synchronizedList(new ArrayList<>());
for (int sourceIndex = 0; sourceIndex < s3Uris.size(); sourceIndex++) {
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource));
successRates.add(ingestSingleSource(s3, s3Uris.get(sourceIndex), mlBatchIngestionInput, sourceIndex, isSoleSource, bulkSize));
}

return calculateSuccessRate(successRates);
Expand All @@ -74,7 +74,8 @@ public double ingestSingleSource(
String s3Uri,
MLBatchIngestionInput mlBatchIngestionInput,
int sourceIndex,
boolean isSoleSource
boolean isSoleSource,
int bulkSize
) {
String bucketName = getS3BucketName(s3Uri);
String keyName = getS3KeyName(s3Uri);
Expand All @@ -99,8 +100,8 @@ public double ingestSingleSource(
linesBuffer.add(line);
lineCount++;

// Process every 100 lines
if (lineCount % 100 == 0) {
// Process every bulkSize lines
if (lineCount % bulkSize == 0) {
// Create a CompletableFuture that will be completed by the bulkResponseListener
CompletableFuture<Void> future = new CompletableFuture<>();
batchIngest(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.common.MLTaskState.COMPLETED;
import static org.opensearch.ml.common.MLTaskState.FAILED;
import static org.opensearch.ml.plugin.MachineLearningPlugin.INGEST_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE;
import static org.opensearch.ml.task.MLTaskManager.TASK_SEMAPHORE_TIMEOUT;
import static org.opensearch.ml.utils.MLExceptionUtils.OFFLINE_BATCH_INGESTION_DISABLED_ERR_MSG;

Expand All @@ -24,7 +25,9 @@
import org.opensearch.action.support.ActionFilters;
import org.opensearch.action.support.HandledTransportAction;
import org.opensearch.client.Client;
import org.opensearch.cluster.service.ClusterService;
import org.opensearch.common.inject.Inject;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.rest.RestStatus;
import org.opensearch.ml.common.MLTask;
Expand Down Expand Up @@ -60,16 +63,19 @@ public class TransportBatchIngestionAction extends HandledTransportAction<Action
private final Client client;
private ThreadPool threadPool;
private MLFeatureEnabledSetting mlFeatureEnabledSetting;
private volatile Integer batchIngestionBulkSize;

@Inject
public TransportBatchIngestionAction(
ClusterService clusterService,
TransportService transportService,
ActionFilters actionFilters,
Client client,
MLTaskManager mlTaskManager,
ThreadPool threadPool,
MLModelManager mlModelManager,
MLFeatureEnabledSetting mlFeatureEnabledSetting
MLFeatureEnabledSetting mlFeatureEnabledSetting,
Settings settings
) {
super(MLBatchIngestionAction.NAME, transportService, actionFilters, MLBatchIngestionRequest::new);
this.transportService = transportService;
Expand All @@ -78,6 +84,12 @@ public TransportBatchIngestionAction(
this.threadPool = threadPool;
this.mlModelManager = mlModelManager;
this.mlFeatureEnabledSetting = mlFeatureEnabledSetting;

batchIngestionBulkSize = ML_COMMONS_BATCH_INGESTION_BULK_SIZE.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_BATCH_INGESTION_BULK_SIZE, it -> batchIngestionBulkSize = it);

}

@Override
Expand Down Expand Up @@ -131,33 +143,45 @@ protected void createMLTaskandExecute(MLBatchIngestionInput mlBatchIngestionInpu
.state(MLTaskState.CREATED)
.build();

mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error =
"Exceeded maximum limit for BATCH_INGEST tasks. To increase the limit, update the plugins.ml_commons.max_batch_ingestion_tasks setting.";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
} else {
mlTaskManager.createMLTask(mlTask, ActionListener.wrap(response -> {
String taskId = response.getId();
try {
mlTask.setTaskId(taskId);
mlTaskManager.add(mlTask);
listener.onResponse(new MLBatchIngestionResponse(taskId, MLTaskType.BATCH_INGEST, MLTaskState.CREATED.name()));
String ingestType = (String) mlBatchIngestionInput.getDataSources().get(TYPE);
Ingestable ingestable = MLEngineClassLoader.initInstance(ingestType.toLowerCase(), client, Client.class);
threadPool.executor(INGEST_THREAD_POOL).execute(() -> {
executeWithErrorHandling(() -> {
double successRate = ingestable.ingest(mlBatchIngestionInput, batchIngestionBulkSize);
handleSuccessRate(successRate, taskId);
}, taskId);
});
} catch (Exception ex) {
log.error("Failed in batch ingestion", ex);
mlTaskManager
.updateMLTask(
taskId,
Map.of(STATE_FIELD, FAILED, ERROR_FIELD, MLExceptionUtils.getRootCauseMessage(ex)),
TASK_SEMAPHORE_TIMEOUT,
true
);
listener.onFailure(ex);
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
listener.onFailure(exception);
}));
}
}, exception -> {
log.error("Failed to create batch ingestion task", exception);
log.error("Failed to check the maximum BATCH_INGEST Task limits", exception);
listener.onFailure(exception);
}));
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@
import static org.opensearch.ml.engine.utils.FileUtils.deleteFileQuietly;
import static org.opensearch.ml.plugin.MachineLearningPlugin.DEPLOY_THREAD_POOL;
import static org.opensearch.ml.plugin.MachineLearningPlugin.REGISTER_THREAD_POOL;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_MODELS_PER_NODE;
import static org.opensearch.ml.settings.MLCommonsSettings.ML_COMMONS_MAX_REGISTER_MODEL_TASKS_PER_NODE;
Expand Down Expand Up @@ -107,6 +109,7 @@
import org.opensearch.ml.common.MLModelGroup;
import org.opensearch.ml.common.MLTask;
import org.opensearch.ml.common.MLTaskState;
import org.opensearch.ml.common.MLTaskType;
import org.opensearch.ml.common.connector.Connector;
import org.opensearch.ml.common.controller.MLController;
import org.opensearch.ml.common.controller.MLRateLimiter;
Expand Down Expand Up @@ -177,6 +180,8 @@ public class MLModelManager {
private volatile Integer maxModelPerNode;
private volatile Integer maxRegisterTasksPerNode;
private volatile Integer maxDeployTasksPerNode;
private volatile Integer maxBatchInferenceTasks;
private volatile Integer maxBatchIngestionTasks;

public static final ImmutableSet MODEL_DONE_STATES = ImmutableSet
.of(
Expand Down Expand Up @@ -232,6 +237,16 @@ public MLModelManager(
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE, it -> maxDeployTasksPerNode = it);

maxBatchInferenceTasks = ML_COMMONS_MAX_BATCH_INFERENCE_TASKS.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INFERENCE_TASKS, it -> maxBatchInferenceTasks = it);

maxBatchIngestionTasks = ML_COMMONS_MAX_BATCH_INGESTION_TASKS.get(settings);
clusterService
.getClusterSettings()
.addSettingsUpdateConsumer(ML_COMMONS_MAX_BATCH_INGESTION_TASKS, it -> maxBatchIngestionTasks = it);
}

public void registerModelMeta(MLRegisterModelMetaInput mlRegisterModelMetaInput, ActionListener<String> listener) {
Expand Down Expand Up @@ -863,6 +878,18 @@ public void checkAndAddRunningTask(MLTask mlTask, Integer runningTaskLimit) {
mlTaskManager.checkLimitAndAddRunningTask(mlTask, runningTaskLimit);
}

/**
* Check if exceed batch job task limit
*
* @param mlTask ML task
* @param listener ActionListener if the limit is exceeded
*/
public void checkMaxBatchJobTask(MLTask mlTask, ActionListener<Boolean> listener) {
MLTaskType taskType = mlTask.getTaskType();
int maxLimit = taskType.equals(MLTaskType.BATCH_PREDICTION) ? maxBatchInferenceTasks : maxBatchIngestionTasks;
mlTaskManager.checkMaxBatchJobTask(taskType, maxLimit, listener);
}

private void updateModelRegisterStateAsDone(
MLRegisterModelInput registerModelInput,
String taskId,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -973,7 +973,10 @@ public List<Setting<?>> getSettings() {
MLCommonsSettings.ML_COMMONS_REMOTE_JOB_STATUS_EXPIRED_REGEX,
MLCommonsSettings.ML_COMMONS_CONTROLLER_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INGESTION_ENABLED,
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED
MLCommonsSettings.ML_COMMONS_OFFLINE_BATCH_INFERENCE_ENABLED,
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INFERENCE_TASKS,
MLCommonsSettings.ML_COMMONS_MAX_BATCH_INGESTION_TASKS,
MLCommonsSettings.ML_COMMONS_BATCH_INGESTION_BULK_SIZE
);
return settings;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,15 @@ private MLCommonsSettings() {}
Setting.Property.NodeScope,
Setting.Property.Dynamic
);

public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INFERENCE_TASKS = Setting
.intSetting("plugins.ml_commons.max_batch_inference_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_MAX_BATCH_INGESTION_TASKS = Setting
.intSetting("plugins.ml_commons.max_batch_ingestion_tasks", 10, 0, 500, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final Setting<Integer> ML_COMMONS_BATCH_INGESTION_BULK_SIZE = Setting
.intSetting("plugins.ml_commons.batch_ingestion_bulk_size", 500, 100, 100000, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_DEPLOY_MODEL_TASKS_PER_NODE = Setting
.intSetting("plugins.ml_commons.max_deploy_model_tasks_per_node", 10, 0, 10, Setting.Property.NodeScope, Setting.Property.Dynamic);
public static final Setting<Integer> ML_COMMONS_MAX_ML_TASK_PER_NODE = Setting
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -253,6 +253,33 @@ protected void executeTask(MLPredictionTaskRequest request, ActionListener<MLTas
.lastUpdateTime(now)
.async(false)
.build();
if (actionType.equals(ActionType.BATCH_PREDICT)) {
mlModelManager.checkMaxBatchJobTask(mlTask, ActionListener.wrap(exceedLimits -> {
if (exceedLimits) {
String error =
"Exceeded maximum limit for BATCH_PREDICTION tasks. To increase the limit, update the plugins.ml_commons.max_batch_inference_tasks setting.";
log.warn(error + " in task " + mlTask.getTaskId());
listener.onFailure(new OpenSearchStatusException(error, RestStatus.TOO_MANY_REQUESTS));
} else {
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
}
}, exception -> {
log.error("Failed to check the maximum BATCH_PREDICTION Task limits", exception);
listener.onFailure(exception);
}));
return;
}
executePredictionByInputDataType(inputDataType, modelId, mlInput, mlTask, functionName, listener);
}

private void executePredictionByInputDataType(
MLInputDataType inputDataType,
String modelId,
MLInput mlInput,
MLTask mlTask,
FunctionName functionName,
ActionListener<MLTaskResponse> listener
) {
switch (inputDataType) {
case SEARCH_QUERY:
ActionListener<MLInputDataset> dataFrameActionListener = ActionListener.wrap(dataSet -> {
Expand Down
Loading
Loading