Skip to content

Commit

Permalink
ml inference ingest processor support for local models (#2508)
Browse files Browse the repository at this point in the history
* ml inference ingest processor support for local models

Signed-off-by: Bhavana Ramaram <[email protected]>
  • Loading branch information
rbhavna authored Jun 11, 2024
1 parent 22b558d commit 7cd5291
Show file tree
Hide file tree
Showing 6 changed files with 1,265 additions and 128 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -1006,7 +1006,10 @@ public void loadExtensions(ExtensionLoader loader) {
public Map<String, org.opensearch.ingest.Processor.Factory> getProcessors(org.opensearch.ingest.Processor.Parameters parameters) {
Map<String, org.opensearch.ingest.Processor.Factory> processors = new HashMap<>();
processors
.put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client));
.put(
MLInferenceIngestProcessor.TYPE,
new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry)
);
return Collections.unmodifiableMap(processors);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -6,25 +6,31 @@

import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;

import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.function.BiConsumer;

import org.apache.logging.log4j.LogManager;
import org.apache.logging.log4j.Logger;
import org.opensearch.action.ActionRequest;
import org.opensearch.action.support.GroupedActionListener;
import org.opensearch.client.Client;
import org.opensearch.core.action.ActionListener;
import org.opensearch.core.common.Strings;
import org.opensearch.core.xcontent.NamedXContentRegistry;
import org.opensearch.ingest.AbstractProcessor;
import org.opensearch.ingest.ConfigurationUtils;
import org.opensearch.ingest.IngestDocument;
import org.opensearch.ingest.Processor;
import org.opensearch.ingest.ValueSource;
import org.opensearch.ml.common.output.model.ModelTensorOutput;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.transport.MLTaskResponse;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
import org.opensearch.ml.common.utils.StringUtils;
Expand All @@ -42,20 +48,31 @@
*/
public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor {

private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class);

public static final String DOT_SYMBOL = ".";
private final InferenceProcessorAttributes inferenceProcessorAttributes;
private final boolean ignoreMissing;
private final String functionName;
private final boolean fullResponsePath;
private final boolean ignoreFailure;
private final boolean override;
private final String modelInput;
private final ScriptService scriptService;
private static Client client;
public static final String TYPE = "ml_inference";
public static final String DEFAULT_OUTPUT_FIELD_NAME = "inference_results";
// allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the
// prediction outcomes, return the whole prediction outcome by skipping filtering
public static final String IGNORE_MISSING = "ignore_missing";
public static final String OVERRIDE = "override";
public static final String FUNCTION_NAME = "function_name";
public static final String FULL_RESPONSE_PATH = "full_response_path";
public static final String MODEL_INPUT = "model_input";
// At default, ml inference processor allows maximum 10 prediction tasks running in parallel
// it can be overwritten using max_prediction_tasks when creating processor
public static final int DEFAULT_MAX_PREDICTION_TASKS = 10;
private final NamedXContentRegistry xContentRegistry;

private Configuration suppressExceptionConfiguration = Configuration
.builder()
Expand All @@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor(
String tag,
String description,
boolean ignoreMissing,
String functionName,
boolean fullResponsePath,
boolean ignoreFailure,
boolean override,
String modelInput,
ScriptService scriptService,
Client client
Client client,
NamedXContentRegistry xContentRegistry
) {
super(tag, description);
this.inferenceProcessorAttributes = new InferenceProcessorAttributes(
Expand All @@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor(
maxPredictionTask
);
this.ignoreMissing = ignoreMissing;
this.functionName = functionName;
this.fullResponsePath = fullResponsePath;
this.ignoreFailure = ignoreFailure;
this.override = override;
this.modelInput = modelInput;
this.scriptService = scriptService;
this.client = client;
this.xContentRegistry = xContentRegistry;
}

/**
Expand Down Expand Up @@ -162,10 +189,48 @@ private void processPredictions(
List<Map<String, String>> processOutputMap,
int inputMapIndex,
int inputMapSize
) {
) throws IOException {
Map<String, String> modelParameters = new HashMap<>();
Map<String, String> modelConfigs = new HashMap<>();

if (inferenceProcessorAttributes.getModelConfigMaps() != null) {
modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps());
modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps());
}

Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());

Map<String, List<String>> newOutputMapping = new HashMap<>();
if (processOutputMap != null) {

Map<String, String> outputMapping = processOutputMap.get(inputMapIndex);
for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
newOutputMapping.put(newDocumentFieldName, dotPathsInArray);
}

for (Map.Entry<String, String> entry : outputMapping.entrySet()) {
String newDocumentFieldName = entry.getKey();
List<String> dotPaths = newOutputMapping.get(newDocumentFieldName);

int existingFields = 0;
for (String path : dotPaths) {
if (ingestDocument.hasField(path)) {
existingFields++;
}
}
if (!override && existingFields == dotPaths.size()) {
logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName);
newOutputMapping.remove(newDocumentFieldName);
}
}
if (newOutputMapping.size() == 0) {
batchPredictionListener.onResponse(null);
return;
}
}
// when no input mapping is provided, default to read all fields from documents as model input
if (inputMapSize == 0) {
Expand All @@ -184,15 +249,30 @@ private void processPredictions(
}
}

ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId());
Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
inputMapKeys.removeAll(modelConfigs.keySet());

Map<String, String> inputMappings = new HashMap<>();
for (String k : inputMapKeys) {
inputMappings.put(k, modelParameters.get(k));
}
ActionRequest request = getMLModelInferenceRequest(
xContentRegistry,
modelParameters,
modelConfigs,
inputMappings,
inferenceProcessorAttributes.getModelId(),
functionName,
modelInput
);

client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() {

@Override
public void onResponse(MLTaskResponse mlTaskResponse) {
ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput();
MLOutput mlOutput = mlTaskResponse.getOutput();
if (processOutputMap == null || processOutputMap.isEmpty()) {
appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument);
} else {
// outMapping serves as a filter to modelTensorOutput, the fields that are not specified
// in the outputMapping will not write to document
Expand All @@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) {
// document field as key, model field as value
String newDocumentFieldName = entry.getKey();
String modelOutputFieldName = entry.getValue();
if (ingestDocument.hasField(newDocumentFieldName)) {
throw new IllegalArgumentException(
"document already has field name "
+ newDocumentFieldName
+ ". Not allow to overwrite the same field name, please check output_map."
);
if (!newOutputMapping.containsKey(newDocumentFieldName)) {
continue;
}
appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument);
}
}
batchPredictionListener.onResponse(null);
Expand Down Expand Up @@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN
/**
* Appends the model output value to the specified field in the IngestDocument without modifying the source.
*
* @param modelTensorOutput the ModelTensorOutput containing the model output
* @param mlOutput the MLOutput containing the model output
* @param modelOutputFieldName the name of the field in the model output
* @param newDocumentFieldName the name of the field in the IngestDocument to append the value to
* @param ingestDocument the IngestDocument to append the value to
*/
private void appendFieldValue(
ModelTensorOutput modelTensorOutput,
MLOutput mlOutput,
String modelOutputFieldName,
String newDocumentFieldName,
IngestDocument ingestDocument
) {
Object modelOutputValue = null;

if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) {
if (mlOutput == null) {
throw new RuntimeException("model inference output is null");
}

modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing);
Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath);

Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);
Map<String, Object> ingestDocumentSourceAndMetaData = new HashMap<>();
ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata());
ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata());
List<String> dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName);

if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
if (dotPathsInArray.size() == 1) {
ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
} else {
if (!(modelOutputValue instanceof List)) {
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
}
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
// check length of the prediction array to be the same of the document array
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
throw new RuntimeException(
"the prediction field: "
+ modelOutputFieldName
+ " is an array in size of "
+ modelOutputValueArray.size()
+ " but the document field array from field "
+ newDocumentFieldName
+ " is in size of "
+ dotPathsInArray.size()
);
}
// Iterate over dotPathInArray
for (int i = 0; i < dotPathsInArray.size(); i++) {
String dotPathInArray = dotPathsInArray.get(i);
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService);
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
} else {
if (!(modelOutputValue instanceof List)) {
throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents.");
}
List<?> modelOutputValueArray = (List<?>) modelOutputValue;
// check length of the prediction array to be the same of the document array
if (dotPathsInArray.size() != modelOutputValueArray.size()) {
throw new RuntimeException(
"the prediction field: "
+ modelOutputFieldName
+ " is an array in size of "
+ modelOutputValueArray.size()
+ " but the document field array from field "
+ newDocumentFieldName
+ " is in size of "
+ dotPathsInArray.size()
);
}
// Iterate over dotPathInArray
for (int i = 0; i < dotPathsInArray.size(); i++) {
String dotPathInArray = dotPathsInArray.get(i);
Object modelOutputValueInArray = modelOutputValueArray.get(i);
ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService);
TemplateScript.Factory ingestField = ConfigurationUtils
.compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService);
ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing);
}
}
} else {
throw new RuntimeException("model inference output cannot be null");
}
}

Expand All @@ -374,16 +448,18 @@ public static class Factory implements Processor.Factory {

private final ScriptService scriptService;
private final Client client;
private final NamedXContentRegistry xContentRegistry;

/**
* Constructs a new instance of the Factory class.
*
* @param scriptService the ScriptService instance to be used by the Factory
* @param client the Client instance to be used by the Factory
*/
public Factory(ScriptService scriptService, Client client) {
public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) {
this.scriptService = scriptService;
this.client = client;
this.xContentRegistry = xContentRegistry;
}

/**
Expand All @@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create(
int maxPredictionTask = ConfigurationUtils
.readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS);
boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false);
boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false);
String functionName = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name());
String modelInput = ConfigurationUtils
.readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }");
boolean defaultValue = !functionName.equalsIgnoreCase("remote");
boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue);

boolean ignoreFailure = ConfigurationUtils
.readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false);
// convert model config user input data structure to Map<String, String>
Expand Down Expand Up @@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create(
processorTag,
description,
ignoreMissing,
functionName,
fullResponsePath,
ignoreFailure,
override,
modelInput,
scriptService,
client
client,
xContentRegistry
);
}
}
Expand Down
Loading

0 comments on commit 7cd5291

Please sign in to comment.