From 5c6aa088a40d5fe7123a1f79d8166029935b1ee2 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Wed, 5 Jun 2024 17:02:54 -0500 Subject: [PATCH 1/5] ml inference ingest processor support for local models Signed-off-by: Bhavana Ramaram --- .../ml/plugin/MachineLearningPlugin.java | 5 +- .../InferenceProcessorAttributes.java | 2 +- .../processor/MLInferenceIngestProcessor.java | 195 +++++++++++-- .../ml/processor/ModelExecutor.java | 88 +++++- ...LInferenceIngestProcessorFactoryTests.java | 6 +- .../MLInferenceIngestProcessorTests.java | 275 +++++++++++------- 6 files changed, 420 insertions(+), 151 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index 89b812b613..16209b1353 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1000,7 +1000,10 @@ public void loadExtensions(ExtensionLoader loader) { public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { Map processors = new HashMap<>(); processors - .put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client)); + .put( + MLInferenceIngestProcessor.TYPE, + new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry) + ); return Collections.unmodifiableMap(processors); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java index 9a72d04577..f51322737b 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java @@ -80,4 +80,4 @@ public InferenceProcessorAttributes( this.maxPredictionTask = 
maxPredictionTask; } -} +} \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index c06f32803c..c782eee4f5 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -6,9 +6,11 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; @@ -19,11 +21,14 @@ import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.ingest.ValueSource; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; @@ -45,7 +50,11 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod public static final String DOT_SYMBOL = "."; private final InferenceProcessorAttributes inferenceProcessorAttributes; private final boolean ignoreMissing; + private final String functionName; + private final boolean fullResponsePath; private final boolean ignoreFailure; + private final boolean override; + private final String modelInput; private final ScriptService scriptService; private static Client client; public static 
final String TYPE = "ml_inference"; @@ -53,9 +62,14 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the // prediction outcomes, return the whole prediction outcome by skipping filtering public static final String IGNORE_MISSING = "ignore_missing"; + public static final String OVERRIDE = "override"; + public static final String FUNCTION_NAME = "function_name"; + public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; // At default, ml inference processor allows maximum 10 prediction tasks running in parallel // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + private final NamedXContentRegistry xContentRegistry; private Configuration suppressExceptionConfiguration = Configuration .builder() @@ -71,9 +85,14 @@ protected MLInferenceIngestProcessor( String tag, String description, boolean ignoreMissing, + String functionName, + boolean fullResponsePath, boolean ignoreFailure, + boolean override, + String modelInput, ScriptService scriptService, - Client client + Client client, + NamedXContentRegistry xContentRegistry ) { super(tag, description); this.inferenceProcessorAttributes = new InferenceProcessorAttributes( @@ -84,9 +103,14 @@ protected MLInferenceIngestProcessor( maxPredictionTask ); this.ignoreMissing = ignoreMissing; + this.functionName = functionName; + this.fullResponsePath = fullResponsePath; this.ignoreFailure = ignoreFailure; + this.override = override; + this.modelInput = modelInput; this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -162,10 +186,44 @@ private void processPredictions( List> processOutputMap, int inputMapIndex, int inputMapSize - ) { + ) throws IOException { Map modelParameters = 
new HashMap<>(); + Map modelConfigs = new HashMap<>(); + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + Map outputMapping = processOutputMap.get(inputMapIndex); + + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + + Map> newOutputMapping = new HashMap<>(); + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + newOutputMapping.put(newDocumentFieldName, dotPathsInArray); + } + + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPaths = newOutputMapping.get(newDocumentFieldName); + + int existingFields = 0; + for (String path : dotPaths) { + if (ingestDocument.hasField(path)) { + existingFields++; + } + } + if (!override && existingFields == dotPaths.size()) { + newOutputMapping.remove(newDocumentFieldName); + } + } + if (newOutputMapping.size() == 0) { + batchPredictionListener.onResponse(null); + return; } // when no input mapping is provided, default to read all fields from documents as model input if (inputMapSize == 0) { @@ -184,15 +242,30 @@ private void processPredictions( } } - ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId()); + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + ActionRequest request = getRemoteModelInferenceRequest( + 
xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { @Override public void onResponse(MLTaskResponse mlTaskResponse) { - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + MLOutput mlOutput = mlTaskResponse.getOutput(); if (processOutputMap == null || processOutputMap.isEmpty()) { - appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); + appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); } else { // outMapping serves as a filter to modelTensorOutput, the fields that are not specified // in the outputMapping will not write to document @@ -202,14 +275,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) { // document field as key, model field as value String newDocumentFieldName = entry.getKey(); String modelOutputFieldName = entry.getValue(); - if (ingestDocument.hasField(newDocumentFieldName)) { - throw new IllegalArgumentException( - "document already has field name " - + newDocumentFieldName - + ". Not allow to overwrite the same field name, please check output_map." 
- ); + if (!newOutputMapping.containsKey(newDocumentFieldName)) { + continue; } - appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); + appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); } } batchPredictionListener.onResponse(null); @@ -322,16 +391,16 @@ private void appendFieldValue( modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing); - Map ingestDocumentSourceAndMetaData = new HashMap<>(); - ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); - ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); - List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName); if (dotPathsInArray.size() == 1) { - ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + if (!ingestDocument.hasField(dotPathsInArray.get(0)) || override) { + ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } } else { if (!(modelOutputValue instanceof List)) { throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); @@ -353,11 +422,13 @@ private void appendFieldValue( // Iterate over dotPathInArray for (int i = 0; i < dotPathsInArray.size(); i++) { String dotPathInArray = dotPathsInArray.get(i); - 
Object modelOutputValueInArray = modelOutputValueArray.get(i); - ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + if (!ingestDocument.hasField(dotPathInArray) || override) { + Object modelOutputValueInArray = modelOutputValueArray.get(i); + ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } } } } else { @@ -365,6 +436,59 @@ private void appendFieldValue( } } + private void appendFieldValue( + MLOutput mlOutput, + String modelOutputFieldName, + String newDocumentFieldName, + IngestDocument ingestDocument + ) { + + if (mlOutput == null) { + throw new RuntimeException("model inference output is null"); + } + + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); + + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + + if (dotPathsInArray.size() == 1) { + ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } else { + if (!(modelOutputValue instanceof List)) { + throw new 
IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); + } + List modelOutputValueArray = (List) modelOutputValue; + // check length of the prediction array to be the same of the document array + if (dotPathsInArray.size() != modelOutputValueArray.size()) { + throw new RuntimeException( + "the prediction field: " + + modelOutputFieldName + + " is an array in size of " + + modelOutputValueArray.size() + + " but the document field array from field " + + newDocumentFieldName + + " is in size of " + + dotPathsInArray.size() + ); + } + // Iterate over dotPathInArray + for (int i = 0; i < dotPathsInArray.size(); i++) { + String dotPathInArray = dotPathsInArray.get(i); + Object modelOutputValueInArray = modelOutputValueArray.get(i); + ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } + } + } + @Override public String getType() { return TYPE; @@ -374,6 +498,7 @@ public static class Factory implements Processor.Factory { private final ScriptService scriptService; private final Client client; + private final NamedXContentRegistry xContentRegistry; /** * Constructs a new instance of the Factory class. 
@@ -381,9 +506,10 @@ public static class Factory implements Processor.Factory { * @param scriptService the ScriptService instance to be used by the Factory * @param client the Client instance to be used by the Factory */ - public Factory(ScriptService scriptService, Client client) { + public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) { this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -410,6 +536,14 @@ public MLInferenceIngestProcessor create( int maxPredictionTask = ConfigurationUtils .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + String modelInput = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }"); + boolean defaultValue = !functionName.equalsIgnoreCase(FunctionName.REMOTE.name()); + boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue); + boolean ignoreFailure = ConfigurationUtils .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); // convert model config user input data structure to Map @@ -440,11 +574,16 @@ public MLInferenceIngestProcessor create( processorTag, description, ignoreMissing, + functionName, + fullResponsePath, ignoreFailure, + override, + modelInput, scriptService, - client + client, + xContentRegistry ); } } -} +} \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java 
b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 1abc770d07..c9f9729af1 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -5,17 +5,29 @@ package org.opensearch.ml.processor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -45,17 +57,57 @@ public interface ModelExecutor { * @return an ActionRequest instance for remote model inference * @throws IllegalArgumentException if the input parameters are null */ - default ActionRequest getRemoteModelInferenceRequest(Map parameters, String modelId) { + default ActionRequest getRemoteModelInferenceRequest( + NamedXContentRegistry xContentRegistry, + Map parameters, + Map modelConfigs, + Map inputMappings, + String modelId, + String 
functionNameStr, + String modelInput + ) throws IOException { if (parameters == null) { throw new IllegalArgumentException("wrong input. The model input cannot be empty."); } - RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); + FunctionName functionName = FunctionName.REMOTE; + if (functionNameStr != null) { + functionName = FunctionName.from(functionNameStr); + } + // RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); - MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + Map inputParams = new HashMap<>(); + if (FunctionName.REMOTE == functionName) { + inputParams.put("parameters", StringUtils.toJson(parameters)); + } else { + inputParams.putAll(parameters); + } - ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null); + String payload = modelInput; + // payload = fillNullParameters(parameters, payload); + StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}"); + payload = modelConfigSubstitutor.replace(payload); + StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}"); + payload = inputMapSubstitutor.replace(payload); + StringSubstitutor parametersSubstitutor = new StringSubstitutor(inputParams, "${ml_inference.", "}"); + payload = parametersSubstitutor.replace(payload); - return request; + if (!isJson(payload)) { + throw new IllegalArgumentException("Invalid payload: " + payload); + } + + // String jsonStr; + // try { + // jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(inputParams)); + // } catch (PrivilegedActionException e) { + // throw new IllegalArgumentException("wrong connector"); + // } + XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload); + + 
ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); + MLInput mlInput = MLInput.parse(parser, functionName.name()); + // MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); + + return new MLPredictionTaskRequest(modelId, mlInput); } @@ -135,6 +187,28 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m return modelOutputValue; } + default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldName, boolean ignoreMissing, boolean fullResponsePath) { + try (XContentBuilder builder = XContentFactory.jsonBuilder()) { + String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString(); + Map modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class); + try { + if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) { + return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); + } else { + return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName); + } + } catch (Exception e) { + if (ignoreMissing) { + return modelTensorOutputMap; + } else { + throw new IllegalArgumentException("model inference output cannot find such json path: " + modelOutputFieldName, e); + } + } + } catch (Exception e) { + throw new RuntimeException("An unexpected error occurred: " + e.getMessage(), e); + } + } + /** * Parses the data from the given ModelTensor and returns it as an Object. 
* The method handles different data types (integer, floating-point, string, and boolean) @@ -249,4 +323,4 @@ default String convertToDotPath(String path) { return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); } -} +} \ No newline at end of file diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java index 577e8b8693..30c9a31ada 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java @@ -15,6 +15,7 @@ import org.mockito.Mock; import org.opensearch.OpenSearchParseException; import org.opensearch.client.Client; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.Processor; import org.opensearch.script.ScriptService; import org.opensearch.test.OpenSearchTestCase; @@ -26,9 +27,12 @@ public class MLInferenceIngestProcessorFactoryTests extends OpenSearchTestCase { @Mock private ScriptService scriptService; + @Mock + private NamedXContentRegistry xContentRegistry; + @Before public void init() { - factory = new MLInferenceIngestProcessor.Factory(scriptService, client); + factory = new MLInferenceIngestProcessor.Factory(scriptService, client, xContentRegistry); } public void testCreateRequiredFields() throws Exception { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index d11cc213de..5f96f23368 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -20,12 +20,14 @@ import org.junit.Assert; import 
org.junit.Before; +import org.junit.Ignore; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; import org.mockito.MockitoAnnotations; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.IngestDocument; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.MLResultDataType; @@ -52,6 +54,9 @@ public class MLInferenceIngestProcessorTests extends OpenSearchTestCase { private ScriptService scriptService; @Mock private BiConsumer handler; + + @Mock + NamedXContentRegistry xContentRegistry; private static final String PROCESSOR_TAG = "inference"; private static final String DESCRIPTION = "inference_test"; private IngestDocument ingestDocument; @@ -81,7 +86,7 @@ private MLInferenceIngestProcessor createMLInferenceProcessor( boolean ignoreMissing, boolean ignoreFailure ) { - return new MLInferenceIngestProcessor( + return createMLInferenceProcessor2( model_id, input_map, output_map, @@ -90,9 +95,52 @@ private MLInferenceIngestProcessor createMLInferenceProcessor( PROCESSOR_TAG, DESCRIPTION, ignoreMissing, + "remote", + false, ignoreFailure, + false, + null, scriptService, - client + client, + xContentRegistry + ); + } + + private MLInferenceIngestProcessor createMLInferenceProcessor2( + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, + int maxPredictionTask, + String tag, + String description, + boolean ignoreMissing, + String functionName, + boolean fullResponsePath, + boolean ignoreFailure, + boolean override, + String modelInput, + ScriptService scriptService, + Client client, + NamedXContentRegistry xContentRegistry + ) { + return new MLInferenceIngestProcessor( + modelId, + inputMaps, + outputMaps, + modelConfigMaps, + maxPredictionTask, + tag, + description, + ignoreMissing, + functionName, + 
fullResponsePath, + ignoreFailure, + override, + modelInput, + scriptService, + client, + xContentRegistry ); } @@ -137,68 +185,69 @@ public void testExecute_nestedObjectStringDocumentSuccess() { * test nested object document with array of Map, * the value Object is a Map */ - public void testExecute_nestedObjectMapDocumentSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); - - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); - ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); - ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); - ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); - actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); - return null; - }).when(client).execute(any(), any(), any()); - - ArrayList childDocuments = new ArrayList<>(); - Map childDocument1Text = new HashMap<>(); - childDocument1Text.put("text", "this is first"); - Map childDocument1 = new HashMap<>(); - childDocument1.put("chunk", childDocument1Text); - - Map childDocument2 = new HashMap<>(); - Map childDocument2Text = new HashMap<>(); - childDocument2Text.put("text", "this is second"); - childDocument2.put("chunk", childDocument2Text); - - childDocuments.add(childDocument1); - childDocuments.add(childDocument2); - - Map sourceAndMetadata = new HashMap<>(); - sourceAndMetadata.put("chunks", childDocuments); - - IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - processor.execute(nestedObjectIngestDocument, handler); - - // match input dataset - - ArgumentCaptor argumentCaptor = 
ArgumentCaptor.forClass(MLPredictionTaskRequest.class); - verify(client).execute(any(), argumentCaptor.capture(), any()); - - Map inputParameters = new HashMap<>(); - ArrayList embedding_text = new ArrayList<>(); - embedding_text.add("this is first"); - embedding_text.add("this is second"); - inputParameters.put("inputs", modelExecutor.toString(embedding_text)); - - MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); - MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); - - RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest - .getMlInput() - .getInputDataset(); - RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); - - assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); - - // match document - sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3)))); - IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); - verify(handler).accept(eq(ingestDocument1), isNull()); - assertEquals(nestedObjectIngestDocument, ingestDocument1); - } +// @Ignore +// public void testExecute_nestedObjectMapDocumentSuccess() { +// List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); +// +// MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); +// ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); +// ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); +// ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); +// +// doAnswer(invocation 
-> { +// ActionListener actionListener = invocation.getArgument(2); +// actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); +// return null; +// }).when(client).execute(any(), any(), any()); +// +// ArrayList childDocuments = new ArrayList<>(); +// Map childDocument1Text = new HashMap<>(); +// childDocument1Text.put("text", "this is first"); +// Map childDocument1 = new HashMap<>(); +// childDocument1.put("chunk", childDocument1Text); +// +// Map childDocument2 = new HashMap<>(); +// Map childDocument2Text = new HashMap<>(); +// childDocument2Text.put("text", "this is second"); +// childDocument2.put("chunk", childDocument2Text); +// +// childDocuments.add(childDocument1); +// childDocuments.add(childDocument2); +// +// Map sourceAndMetadata = new HashMap<>(); +// sourceAndMetadata.put("chunks", childDocuments); +// +// IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); +// processor.execute(nestedObjectIngestDocument, handler); +// +// // match input dataset +// +// ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); +// verify(client).execute(any(), argumentCaptor.capture(), any()); +// +// Map inputParameters = new HashMap<>(); +// ArrayList embedding_text = new ArrayList<>(); +// embedding_text.add("this is first"); +// embedding_text.add("this is second"); +// inputParameters.put("inputs", modelExecutor.toString(embedding_text)); +// +// MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor +// .getRemoteModelInferenceRequest(inputParameters, "model1"); +// MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); +// +// RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest +// .getMlInput() +// .getInputDataset(); +// RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); +// +// 
assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); +// +// // match document +// sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3)))); +// IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); +// verify(handler).accept(eq(ingestDocument1), isNull()); +// assertEquals(nestedObjectIngestDocument, ingestDocument1); +// } public void testExecute_jsonPathWithMissingLeaves() { @@ -224,55 +273,55 @@ public void testExecute_jsonPathWithMissingLeaves() { * test nested object document with array of Map, * the value Object is a also a nested object, */ - public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); - - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); - ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); - ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); - ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); - - doAnswer(invocation -> { - ActionListener actionListener = invocation.getArgument(2); - actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); - return null; - }).when(client).execute(any(), any(), any()); - - Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); - - IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); - processor.execute(nestedObjectIngestDocument, handler); - - // match input dataset - ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); - verify(client).execute(any(), 
argumentCaptor.capture(), any()); - - Map inputParameters = new HashMap<>(); - ArrayList embedding_text = new ArrayList<>(); - embedding_text.add("this is first"); - embedding_text.add("this is second"); - embedding_text.add("this is third"); - embedding_text.add("this is fourth"); - inputParameters.put("inputs", modelExecutor.toString(embedding_text)); - - MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); - MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); - - RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest - .getMlInput() - .getInputDataset(); - RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); - - assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); - - // match document - sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4)))); - IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); - verify(handler).accept(eq(ingestDocument1), isNull()); - assertEquals(nestedObjectIngestDocument, ingestDocument1); - - } +// public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { +// List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); +// +// MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); +// ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); +// ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); +// ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); +// 
+// doAnswer(invocation -> { +// ActionListener actionListener = invocation.getArgument(2); +// actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); +// return null; +// }).when(client).execute(any(), any(), any()); +// +// Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); +// +// IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); +// processor.execute(nestedObjectIngestDocument, handler); +// +// // match input dataset +// ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); +// verify(client).execute(any(), argumentCaptor.capture(), any()); +// +// Map inputParameters = new HashMap<>(); +// ArrayList embedding_text = new ArrayList<>(); +// embedding_text.add("this is first"); +// embedding_text.add("this is second"); +// embedding_text.add("this is third"); +// embedding_text.add("this is fourth"); +// inputParameters.put("inputs", modelExecutor.toString(embedding_text)); +// +// MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor +// .getRemoteModelInferenceRequest(inputParameters, "model1"); +// MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); +// +// RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest +// .getMlInput() +// .getInputDataset(); +// RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); +// +// assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); +// +// // match document +// sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4)))); +// IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); +// verify(handler).accept(eq(ingestDocument1), isNull()); +// 
assertEquals(nestedObjectIngestDocument, ingestDocument1); +// +// } public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); From a78f03ce28f9fdd743237039771050e3afff2ac1 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Mon, 10 Jun 2024 14:25:27 -0500 Subject: [PATCH 2/5] add unit tests and ITs Signed-off-by: Bhavana Ramaram --- .../InferenceProcessorAttributes.java | 2 +- .../processor/MLInferenceIngestProcessor.java | 103 +-- .../ml/processor/ModelExecutor.java | 37 +- ...LInferenceIngestProcessorFactoryTests.java | 32 +- .../MLInferenceIngestProcessorTests.java | 819 ++++++++++++++---- .../RestMLInferenceIngestProcessorIT.java | 134 ++- 6 files changed, 833 insertions(+), 294 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java index f51322737b..9a72d04577 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/InferenceProcessorAttributes.java @@ -80,4 +80,4 @@ public InferenceProcessorAttributes( this.maxPredictionTask = maxPredictionTask; } -} \ No newline at end of file +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index c782eee4f5..3362d6794e 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -29,7 +29,6 @@ import org.opensearch.ingest.ValueSource; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.output.MLOutput; -import org.opensearch.ml.common.output.model.ModelTensorOutput; import 
org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.utils.StringUtils; @@ -194,37 +193,40 @@ private void processPredictions( modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); } - Map outputMapping = processOutputMap.get(inputMapIndex); Map ingestDocumentSourceAndMetaData = new HashMap<>(); ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); Map> newOutputMapping = new HashMap<>(); - for (Map.Entry entry : outputMapping.entrySet()) { - String newDocumentFieldName = entry.getKey(); - List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); - newOutputMapping.put(newDocumentFieldName, dotPathsInArray); - } + if (processOutputMap != null) { + + Map outputMapping = processOutputMap.get(inputMapIndex); + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + newOutputMapping.put(newDocumentFieldName, dotPathsInArray); + } - for (Map.Entry entry : outputMapping.entrySet()) { - String newDocumentFieldName = entry.getKey(); - List dotPaths = newOutputMapping.get(newDocumentFieldName); + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPaths = newOutputMapping.get(newDocumentFieldName); - int existingFields = 0; - for (String path : dotPaths) { - if (ingestDocument.hasField(path)) { - existingFields++; + int existingFields = 0; + for (String path : dotPaths) { + if (ingestDocument.hasField(path)) { + existingFields++; + } + } + if (!override && existingFields == dotPaths.size()) { + 
newOutputMapping.remove(newDocumentFieldName); } } - if (!override && existingFields == dotPaths.size()) { - newOutputMapping.remove(newDocumentFieldName); + if (newOutputMapping.size() == 0) { + batchPredictionListener.onResponse(null); + return; } } - if (newOutputMapping.size() == 0) { - batchPredictionListener.onResponse(null); - return; - } // when no input mapping is provided, default to read all fields from documents as model input if (inputMapSize == 0) { Set documentFields = ingestDocument.getSourceAndMetadata().keySet(); @@ -374,68 +376,11 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN /** * Appends the model output value to the specified field in the IngestDocument without modifying the source. * - * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param mlOutput the MLOutput containing the model output * @param modelOutputFieldName the name of the field in the model output * @param newDocumentFieldName the name of the field in the IngestDocument to append the value to * @param ingestDocument the IngestDocument to append the value to */ - private void appendFieldValue( - ModelTensorOutput modelTensorOutput, - String modelOutputFieldName, - String newDocumentFieldName, - IngestDocument ingestDocument - ) { - Object modelOutputValue = null; - - if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) { - - modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing); - - List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocument.getSourceAndMetadata(), newDocumentFieldName); - - if (dotPathsInArray.size() == 1) { - if (!ingestDocument.hasField(dotPathsInArray.get(0)) || override) { - ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), 
scriptService); - - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } - } else { - if (!(modelOutputValue instanceof List)) { - throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); - } - List modelOutputValueArray = (List) modelOutputValue; - // check length of the prediction array to be the same of the document array - if (dotPathsInArray.size() != modelOutputValueArray.size()) { - throw new RuntimeException( - "the prediction field: " - + modelOutputFieldName - + " is an array in size of " - + modelOutputValueArray.size() - + " but the document field array from field " - + newDocumentFieldName - + " is in size of " - + dotPathsInArray.size() - ); - } - // Iterate over dotPathInArray - for (int i = 0; i < dotPathsInArray.size(); i++) { - String dotPathInArray = dotPathsInArray.get(i); - if (!ingestDocument.hasField(dotPathInArray) || override) { - Object modelOutputValueInArray = modelOutputValueArray.get(i); - ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } - } - } - } else { - throw new RuntimeException("model inference output cannot be null"); - } - } - private void appendFieldValue( MLOutput mlOutput, String modelOutputFieldName, @@ -586,4 +531,4 @@ public MLInferenceIngestProcessor create( } } -} \ No newline at end of file +} diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index c9f9729af1..57704896af 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -94,13 +94,6 @@ default ActionRequest getRemoteModelInferenceRequest( if 
(!isJson(payload)) { throw new IllegalArgumentException("Invalid payload: " + payload); } - - // String jsonStr; - // try { - // jsonStr = AccessController.doPrivileged((PrivilegedExceptionAction) () -> gson.toJson(inputParams)); - // } catch (PrivilegedActionException e) { - // throw new IllegalArgumentException("wrong connector"); - // } XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload); ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser); @@ -126,7 +119,9 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m try { // getMlModelOutputs() returns a list or collection. // Adding null check for modelTensorOutput - if (modelTensorOutput != null && !modelTensorOutput.getMlModelOutputs().isEmpty()) { + if (modelTensorOutput != null + && modelTensorOutput.getMlModelOutputs() != null + && !modelTensorOutput.getMlModelOutputs().isEmpty()) { // getMlModelOutputs() returns a list of ModelTensors // accessing the first element. 
// TODO currently remote model only return single tensor, might need to processor multiple tensors later @@ -182,7 +177,7 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m throw new RuntimeException("Model outputs are null or empty."); } } catch (Exception e) { - throw new RuntimeException("An unexpected error occurred: " + e.getMessage()); + throw new RuntimeException(e.getMessage()); } return modelOutputValue; } @@ -191,17 +186,19 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam try (XContentBuilder builder = XContentFactory.jsonBuilder()) { String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString(); Map modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class); - try { - if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) { - return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); - } else { + if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) { + return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); + } else if (modelOutputFieldName == null || modelTensorOutputMap == null) { + return modelTensorOutputMap; + } else { + try { return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName); - } - } catch (Exception e) { - if (ignoreMissing) { - return modelTensorOutputMap; - } else { - throw new IllegalArgumentException("model inference output cannot find such json path: " + modelOutputFieldName, e); + } catch (Exception e) { + if (ignoreMissing) { + return modelTensorOutputMap; + } else { + throw new IllegalArgumentException("model inference output cannot find such json path: " + modelOutputFieldName, e); + } } } } catch (Exception e) { @@ -323,4 +320,4 @@ default String convertToDotPath(String path) { return path.replaceAll("\\[(\\d+)\\]", "$1\\.").replaceAll("\\['(.*?)']", "$1\\.").replaceAll("^\\$", "").replaceAll("\\.$", ""); } -} \ No 
newline at end of file +} diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java index 30c9a31ada..7ca077a82f 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java @@ -5,6 +5,9 @@ package org.opensearch.ml.processor; import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FULL_RESPONSE_PATH; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FUNCTION_NAME; +import static org.opensearch.ml.processor.MLInferenceIngestProcessor.MODEL_INPUT; import java.util.ArrayList; import java.util.HashMap; @@ -26,7 +29,6 @@ public class MLInferenceIngestProcessorFactoryTests extends OpenSearchTestCase { private Client client; @Mock private ScriptService scriptService; - @Mock private NamedXContentRegistry xContentRegistry; @@ -46,6 +48,34 @@ public void testCreateRequiredFields() throws Exception { assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); } + public void testCreateLocalModelProcessor() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", 
"$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config); + assertNotNull(mLInferenceIngestProcessor); + assertEquals(mLInferenceIngestProcessor.getTag(), processorTag); + assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); + } + public void testCreateNoFieldPresent() throws Exception { Map config = new HashMap<>(); try { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index 5f96f23368..47c2e4d9d9 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -9,6 +9,7 @@ import static org.mockito.Mockito.*; import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; +import java.io.IOException; import java.nio.ByteBuffer; import java.time.ZonedDateTime; import java.util.ArrayList; @@ -20,7 +21,6 @@ import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.mockito.ArgumentCaptor; import org.mockito.Mock; import org.mockito.Mockito; @@ -79,51 +79,20 @@ public void setup() { } private MLInferenceIngestProcessor createMLInferenceProcessor( - String model_id, - Map model_config, - List> input_map, - List> output_map, - boolean ignoreMissing, - boolean ignoreFailure - ) { - return createMLInferenceProcessor2( - model_id, - input_map, - output_map, - model_config, - RANDOM_MULTIPLIER, - PROCESSOR_TAG, - DESCRIPTION, - ignoreMissing, - "remote", - false, - ignoreFailure, - false, - null, - scriptService, - client, - xContentRegistry - 
); - } - - private MLInferenceIngestProcessor createMLInferenceProcessor2( String modelId, List> inputMaps, List> outputMaps, Map modelConfigMaps, - int maxPredictionTask, - String tag, - String description, boolean ignoreMissing, String functionName, boolean fullResponsePath, boolean ignoreFailure, boolean override, - String modelInput, - ScriptService scriptService, - Client client, - NamedXContentRegistry xContentRegistry + String modelInput ) { + functionName = functionName != null ? functionName : "remote"; + modelInput = modelInput != null ? modelInput : "{ \"parameters\": ${ml_inference.parameters} }"; + return new MLInferenceIngestProcessor( modelId, inputMaps, @@ -145,7 +114,18 @@ private MLInferenceIngestProcessor createMLInferenceProcessor2( } public void testExecute_Exception() throws Exception { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); try { IngestDocument document = processor.execute(ingestDocument); } catch (UnsupportedOperationException e) { @@ -159,9 +139,20 @@ public void testExecute_Exception() throws Exception { */ public void testExecute_nestedObjectStringDocumentSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.chunk"); + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); 
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -185,69 +176,80 @@ public void testExecute_nestedObjectStringDocumentSuccess() { * test nested object document with array of Map, * the value Object is a Map */ -// @Ignore -// public void testExecute_nestedObjectMapDocumentSuccess() { -// List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); -// -// MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); -// ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); -// ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); -// ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); -// -// doAnswer(invocation -> { -// ActionListener actionListener = invocation.getArgument(2); -// actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); -// return null; -// }).when(client).execute(any(), any(), any()); -// -// ArrayList childDocuments = new ArrayList<>(); -// Map childDocument1Text = new HashMap<>(); -// childDocument1Text.put("text", "this is first"); -// Map childDocument1 = new HashMap<>(); -// childDocument1.put("chunk", childDocument1Text); -// -// Map childDocument2 = new HashMap<>(); -// Map childDocument2Text = new HashMap<>(); -// childDocument2Text.put("text", "this is second"); -// childDocument2.put("chunk", childDocument2Text); -// -// childDocuments.add(childDocument1); -// childDocuments.add(childDocument2); -// -// Map sourceAndMetadata = new HashMap<>(); -// sourceAndMetadata.put("chunks", childDocuments); -// -// IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); -// processor.execute(nestedObjectIngestDocument, handler); -// -// // match input 
dataset -// -// ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); -// verify(client).execute(any(), argumentCaptor.capture(), any()); -// -// Map inputParameters = new HashMap<>(); -// ArrayList embedding_text = new ArrayList<>(); -// embedding_text.add("this is first"); -// embedding_text.add("this is second"); -// inputParameters.put("inputs", modelExecutor.toString(embedding_text)); -// -// MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor -// .getRemoteModelInferenceRequest(inputParameters, "model1"); -// MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); -// -// RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest -// .getMlInput() -// .getInputDataset(); -// RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); -// -// assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); -// -// // match document -// sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3)))); -// IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); -// verify(handler).accept(eq(ingestDocument1), isNull()); -// assertEquals(nestedObjectIngestDocument, ingestDocument1); -// } + public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + 
ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ArrayList childDocuments = new ArrayList<>(); + Map childDocument1Text = new HashMap<>(); + childDocument1Text.put("text", "this is first"); + Map childDocument1 = new HashMap<>(); + childDocument1.put("chunk", childDocument1Text); + + Map childDocument2 = new HashMap<>(); + Map childDocument2Text = new HashMap<>(); + childDocument2Text.put("text", "this is second"); + childDocument2.put("chunk", childDocument2Text); + + childDocuments.add(childDocument1); + childDocuments.add(childDocument2); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("chunks", childDocuments); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match input dataset + + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + verify(client).execute(any(), argumentCaptor.capture(), any()); + + Map inputParameters = new HashMap<>(); + ArrayList embedding_text = new ArrayList<>(); + embedding_text.add("this is first"); + embedding_text.add("this is second"); + inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; + + MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor + .getRemoteModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); + MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); + + RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) 
expectedRequest + .getMlInput() + .getInputDataset(); + RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); + + assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); + + // match document + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3)))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } public void testExecute_jsonPathWithMissingLeaves() { @@ -273,61 +275,84 @@ public void testExecute_jsonPathWithMissingLeaves() { * test nested object document with array of Map, * the value Object is a also a nested object, */ -// public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { -// List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); -// -// MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); -// ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); -// ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); -// ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); -// -// doAnswer(invocation -> { -// ActionListener actionListener = invocation.getArgument(2); -// actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); -// return null; -// }).when(client).execute(any(), any(), any()); -// -// Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); -// -// IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); 
-// processor.execute(nestedObjectIngestDocument, handler); -// -// // match input dataset -// ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); -// verify(client).execute(any(), argumentCaptor.capture(), any()); -// -// Map inputParameters = new HashMap<>(); -// ArrayList embedding_text = new ArrayList<>(); -// embedding_text.add("this is first"); -// embedding_text.add("this is second"); -// embedding_text.add("this is third"); -// embedding_text.add("this is fourth"); -// inputParameters.put("inputs", modelExecutor.toString(embedding_text)); -// -// MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor -// .getRemoteModelInferenceRequest(inputParameters, "model1"); -// MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); -// -// RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest -// .getMlInput() -// .getInputDataset(); -// RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); -// -// assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); -// -// // match document -// sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4)))); -// IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); -// verify(handler).accept(eq(ingestDocument1), isNull()); -// assertEquals(nestedObjectIngestDocument, ingestDocument1); -// -// } + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() throws IOException { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); + ModelTensor modelTensor 
= ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // match input dataset + ArgumentCaptor argumentCaptor = ArgumentCaptor.forClass(MLPredictionTaskRequest.class); + verify(client).execute(any(), argumentCaptor.capture(), any()); + + Map inputParameters = new HashMap<>(); + ArrayList embedding_text = new ArrayList<>(); + embedding_text.add("this is first"); + embedding_text.add("this is second"); + embedding_text.add("this is third"); + embedding_text.add("this is fourth"); + inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; + + MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor + .getRemoteModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); + MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); + + RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest + .getMlInput() + .getInputDataset(); + RemoteInferenceInputDataSet actualRemoteInputDataset = (RemoteInferenceInputDataSet) actualRequest.getMlInput().getInputDataset(); + + 
assertEquals(expectedRemoteInputDataset.getParameters().get("inputs"), actualRemoteInputDataset.getParameters().get("inputs")); + + // match document + sourceAndMetadata.put(DEFAULT_OUTPUT_FIELD_NAME, List.of(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4)))); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + + } public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -360,7 +385,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -394,7 +430,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL } public void testExecute_InferenceException() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); 
+ MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); when(client.execute(any(), any())).thenThrow(new RuntimeException("Executing Model failed with exception")); try { processor.execute(ingestDocument, handler); @@ -404,7 +451,18 @@ public void testExecute_InferenceException() { } public void testExecute_InferenceOnFailure() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); RuntimeException inferenceFailure = new RuntimeException("Executing Model failed with exception"); doAnswer(invocation -> { @@ -424,7 +482,18 @@ public void testExecute_AppendFieldValueExceptionOnResponse() throws Exception { String originalOutPutFieldName = "response1"; output.put("text_embedding", originalOutPutFieldName); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -458,7 +527,18 @@ public void testExecute_whenInputFieldNotFound_ExceptionWithIgnoreMissingFalse() Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + 
outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -478,7 +558,18 @@ public void testExecute_whenInputFieldNotFound_SuccessWithIgnoreMissingTrue() { Map output = new HashMap<>(); output.put("text_embedding", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); } @@ -496,7 +587,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingFalse() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -518,7 +620,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingTrue() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); @@ -540,7 +653,18 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin ObjectMapper mapper = mock(ObjectMapper.class); when(mapper.readValue(Mockito.anyString(), 
eq(Object.class))).thenThrow(JsonProcessingException.class); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -550,8 +674,30 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin } public void testExecute_NoModelInput_Exception() { - MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor("model1", null, null, null, true, false); - MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); + MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); Map sourceAndMetadata = new HashMap<>(); IngestDocument emptyIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -569,7 +715,18 @@ public void testExecute_NoModelInput_Exception() { } public void testExecute_AppendModelOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -592,7 +749,18 @@ public void testExecute_AppendModelOutputSuccess() { } public void testExecute_SingleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); Float[] value = new Float[] { 1.0f, 2.0f, 3.0f }; List outputs = new ArrayList<>(); @@ -627,7 +795,18 @@ public void testExecute_SingleTensorInDataOutputSuccess() { } public void testExecute_MultipleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); List outputs = new ArrayList<>(); Float[] value = new Float[] { 1.0f }; @@ -689,7 +868,18 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -720,7 +910,18 @@ public void testExecute_getModelOutputFieldWithDotPathSuccess() { output.put("language_identification", "response.language"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = 
createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", List.of("en", "en"), "score", "0.9876"))) @@ -752,7 +953,18 @@ public void testExecute_getModelOutputFieldWithInvalidDotPathSuccess() { output.put("language_identification", "response.lan"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -782,7 +994,18 @@ public void testExecute_getModelOutputFieldWithInvalidDotPathException() { output.put("response.lan", "language_identification"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -817,7 +1040,18 @@ public void testExecute_getModelOutputFieldInNestedWithInvalidDotPathException() output.put("chunks.*.chunk.text.*.context_embedding", "response.language1"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor 
modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -852,7 +1086,18 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { output.put("key1", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + true, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -867,17 +1112,13 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { }).when(client).execute(any(), any(), any()); processor.execute(ingestDocument, handler); - verify(handler) - .accept( - eq(null), - argThat( - exception -> exception - .getMessage() - .equals( - "document already has field name key1. Not allow to overwrite the same field name, please check output_map." 
- ) - ) - ); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", ImmutableMap.of("language", "en", "score", "0.9876")); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); } public void testExecute_documentNotExistedFieldNameException() { @@ -891,7 +1132,18 @@ public void testExecute_documentNotExistedFieldNameException() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -901,7 +1153,18 @@ public void testExecute_documentNotExistedFieldNameException() { public void testExecute_nestedDocumentNotExistedFieldNameException() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context1"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -920,7 +1183,18 @@ public void testExecute_getModelOutputFieldDifferentLengthException() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> 
modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -956,7 +1230,18 @@ public void testExecute_getModelOutputFieldDifferentLengthIgnoreFailureSuccess() List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -986,7 +1271,18 @@ public void testExecute_getMlModelTensorsIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -1013,7 +1309,18 @@ public void testExecute_getMlModelTensorsIsNullIgnoreFailure() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = 
ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -1034,7 +1341,18 @@ public void testExecute_modelTensorOutputIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1046,7 +1364,11 @@ public void testExecute_modelTensorOutputIsNull() { IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); - verify(handler).accept(eq(null), argThat(exception -> exception.getMessage().equals("model inference output cannot be null"))); + verify(handler) + .accept( + eq(null), + argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Model outputs are null or empty.")) + ); } @@ -1055,7 +1377,18 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1070,6 +1403,108 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { 
verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); } + /** + * Test processor configuration with nested object document + * and array of Map, where the value Object is a List + */ + public void testExecute_localModelSuccess() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "_ingest._value.title"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("_ingest._value.title_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output + List modelPredictionOutput = Arrays.asList(1, 2, 3, 4); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "inference_results", + Arrays + .asList( + ImmutableMap + .of( + "output", + Arrays.asList(ImmutableMap.of("name", "sentence_embedding", "data", modelPredictionOutput)) + ) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + // Setting up the ingest document + Map sourceAndMetadata = new HashMap<>(); + List> books = new ArrayList<>(); + Map book1 = new HashMap<>(); + book1.put("title", Arrays.asList("first book")); + book1.put("description", "This is first book"); + Map book2 = new HashMap<>(); + book2.put("title", Arrays.asList("second 
book")); + book2.put("description", "This is second book"); + books.add(book1); + books.add(book2); + sourceAndMetadata.put("books", books); + + Map ingestMetadata = new HashMap<>(); + ingestMetadata.put("pipeline", "test_pipeline"); + ingestMetadata.put("timestamp", ZonedDateTime.now()); + Map ingestValue = new HashMap<>(); + ingestValue.put("title", Arrays.asList("first book")); + ingestValue.put("description", "This is first book"); + ingestMetadata.put("_value", ingestValue); + sourceAndMetadata.put("_ingest", ingestMetadata); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // Validate the document + List> updatedBooks = new ArrayList<>(); + Map updatedBook1 = new HashMap<>(); + updatedBook1.put("title", Arrays.asList("first book")); + updatedBook1.put("description", "This is first book"); + updatedBook1.put("title_embedding", modelPredictionOutput); + Map updatedBook2 = new HashMap<>(); + updatedBook2.put("title", Arrays.asList("second book")); + updatedBook2.put("description", "This is second book"); + updatedBook2.put("title_embedding", modelPredictionOutput); + updatedBooks.add(updatedBook1); + updatedBooks.add(updatedBook2); + sourceAndMetadata.put("books", updatedBooks); + + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } + public void testParseGetDataInTensor_IntegerDataType() { ModelTensor mockTensor = mock(ModelTensor.class); when(mockTensor.getDataType()).thenReturn(MLResultDataType.INT8); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java index 1937b8c496..3303a2ac75 100644 --- 
a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -5,6 +5,8 @@ package org.opensearch.ml.rest; +import static org.opensearch.ml.common.MLTask.MODEL_ID_FIELD; +import static org.opensearch.ml.utils.TestData.SENTENCE_TRANSFORMER_MODEL_URL; import static org.opensearch.ml.utils.TestHelper.makeRequest; import java.io.IOException; @@ -15,8 +17,15 @@ import org.apache.hc.core5.http.message.BasicHeader; import org.junit.Assert; import org.junit.Before; +import org.junit.Ignore; import org.opensearch.client.Request; import org.opensearch.client.Response; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.MLTaskState; +import org.opensearch.ml.common.model.MLModelConfig; +import org.opensearch.ml.common.model.MLModelFormat; +import org.opensearch.ml.common.model.TextEmbeddingModelConfig; +import org.opensearch.ml.common.transport.register.MLRegisterModelInput; import org.opensearch.ml.utils.TestHelper; import com.google.common.collect.ImmutableList; @@ -26,6 +35,8 @@ public class RestMLInferenceIngestProcessorIT extends MLCommonsRestTestCase { private final String OPENAI_KEY = System.getenv("OPENAI_KEY"); private String openAIChatModelId; private String bedrockEmbeddingModelId; + + private String localModelId; private final String completionModelConnectorEntity = "{\n" + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" @@ -350,6 +361,108 @@ public void testMLInferenceProcessorWithForEachProcessor() throws Exception { Assert.assertEquals(1536, embedding2.size()); } + @Ignore + public void testMLInferenceProcessorWithLocalModel() throws Exception { + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + 
assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String indexName = "my_books"; + String pipelineName = "my_books_text_embedding_pipeline"; + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"" + + pipelineName + + "\"\n" + + " }\n" + + " },\n" + + " \"mappings\": {\n" + + " \"properties\": {\n" + + " \"books\": {\n" + + " \"type\": \"nested\",\n" + + " \"properties\": {\n" + + " \"title_embedding\": {\n" + + " \"type\": \"float\"\n" + + " },\n" + + " \"title\": {\n" + + " \"type\": \"text\"\n" + + " },\n" + + " \"description\": {\n" + + " \"type\": \"text\"\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + "}"; + createIndex(indexName, createIndexRequestBody); + + String createPipelineRequestBody = "{\n" + + " \"description\": \"test embeddings\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"foreach\": {\n" + + " \"field\": \"books\",\n" + + " \"processor\": {\n" + + " \"ml_inference\": {\n" + + " \"model_id\": \"" + + localModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"_ingest._value.title\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"_ingest._value.title_embedding\": \"$.embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + createPipelineProcessor(createPipelineRequestBody, pipelineName); + + String uploadDocumentRequestBody = "{\n" + + " \"books\": [{\n" + + " \"title\": \"first book\",\n" + + " \"description\": \"This is first 
book\"\n" + + " },\n" + + " {\n" + + " \"title\": \"second book\",\n" + + " \"description\": \"This is second book\"\n" + + " }\n" + + " ]\n" + + "}"; + uploadDocument(indexName, "1", uploadDocumentRequestBody); + Map document = getDocument(indexName, "1"); + + List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding"); + Assert.assertEquals(2, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding"); + Assert.assertEquals(1536, embedding1.size()); + List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding"); + Assert.assertEquals(1536, embedding2.size()); + } + protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception { Response pipelineCreateResponse = TestHelper .makeRequest( @@ -378,7 +491,6 @@ protected void createIndex(String indexName, String requestBody) throws Exceptio protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); - request.setJsonEntity(jsonBody); client().performRequest(request); } @@ -390,4 +502,24 @@ protected Map getDocument(final String index, final String docId) throws Excepti return parseResponseToMap(docResponse); } + protected MLRegisterModelInput registerModelInput() { + MLModelConfig modelConfig = TextEmbeddingModelConfig + .builder() + .modelType("bert") + .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS) + .embeddingDimension(768) + .build(); + return MLRegisterModelInput + .builder() + .modelName("test_model_name") + .version("1.0.0") + .functionName(FunctionName.TEXT_EMBEDDING) + .modelFormat(MLModelFormat.TORCH_SCRIPT) + .modelConfig(modelConfig) + .url(SENTENCE_TRANSFORMER_MODEL_URL) + .deployModel(false) + .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021") + .build(); + } + } From 
ec85479699a4eb3fca12e1c65c8261d08f52b707 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jun 2024 10:57:56 -0500 Subject: [PATCH 3/5] address comments and add more UT/IT Signed-off-by: Bhavana Ramaram --- .../processor/MLInferenceIngestProcessor.java | 7 +- .../ml/processor/ModelExecutor.java | 6 +- .../MLInferenceIngestProcessorTests.java | 234 +++++++++++++++++- .../RestMLInferenceIngestProcessorIT.java | 218 +++++++++++----- 4 files changed, 391 insertions(+), 74 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 3362d6794e..491128636d 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -16,6 +16,8 @@ import java.util.Set; import java.util.function.BiConsumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.GroupedActionListener; import org.opensearch.client.Client; @@ -46,6 +48,8 @@ */ public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor { + private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class); + public static final String DOT_SYMBOL = "."; private final InferenceProcessorAttributes inferenceProcessorAttributes; private final boolean ignoreMissing; @@ -219,6 +223,7 @@ private void processPredictions( } } if (!override && existingFields == dotPaths.size()) { + logger.debug("{} already exists in the ingest document. 
Removing it from output mapping", newDocumentFieldName); newOutputMapping.remove(newDocumentFieldName); } } @@ -251,7 +256,7 @@ private void processPredictions( for (String k : inputMapKeys) { inputMappings.put(k, modelParameters.get(k)); } - ActionRequest request = getRemoteModelInferenceRequest( + ActionRequest request = getMLModelInferenceRequest( xContentRegistry, modelParameters, modelConfigs, diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 57704896af..792d6bfa8c 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -57,7 +57,7 @@ public interface ModelExecutor { * @return an ActionRequest instance for remote model inference * @throws IllegalArgumentException if the input parameters are null */ - default ActionRequest getRemoteModelInferenceRequest( + default ActionRequest getMLModelInferenceRequest( NamedXContentRegistry xContentRegistry, Map parameters, Map modelConfigs, @@ -73,7 +73,6 @@ default ActionRequest getRemoteModelInferenceRequest( if (functionNameStr != null) { functionName = FunctionName.from(functionNameStr); } - // RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build(); Map inputParams = new HashMap<>(); if (FunctionName.REMOTE == functionName) { @@ -83,7 +82,6 @@ default ActionRequest getRemoteModelInferenceRequest( } String payload = modelInput; - // payload = fillNullParameters(parameters, payload); StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}"); payload = modelConfigSubstitutor.replace(payload); StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}"); @@ -98,7 +96,6 @@ default ActionRequest getRemoteModelInferenceRequest( ensureExpectedToken(XContentParser.Token.START_OBJECT, 
parser.nextToken(), parser); MLInput mlInput = MLInput.parse(parser, functionName.name()); - // MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build(); return new MLPredictionTaskRequest(modelId, mlInput); @@ -186,6 +183,7 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam try (XContentBuilder builder = XContentFactory.jsonBuilder()) { String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString(); Map modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class); + System.out.println("output value" + modelOutputJsonStr); if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) { return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); } else if (modelOutputFieldName == null || modelTensorOutputMap == null) { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index 47c2e4d9d9..b5b8a19731 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -14,6 +14,7 @@ import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -234,7 +235,7 @@ public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); 
MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -319,7 +320,7 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSucce String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -574,6 +575,30 @@ public void testExecute_whenInputFieldNotFound_SuccessWithIgnoreMissingTrue() { processor.execute(ingestDocument, handler); } + public void testExecute_localModelInputFieldNotFound_SuccessWithIgnoreMissingTrue() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + Map model_config = new HashMap<>(); + model_config.put("return_number", "true"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "text_embedding", + true, + false, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + processor.execute(ingestDocument, handler); + } + public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingFalse() { List> inputMap = new ArrayList<>(); Map input = new HashMap<>(); @@ -699,6 +724,32 @@ public void testExecute_NoModelInput_Exception() { null ); + MLInferenceIngestProcessor localModelProcessorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "text_embedding", + false, + false, + 
false, + null + ); + + MLInferenceIngestProcessor localModelProcessorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "text_embedding", + false, + false, + false, + null + ); + Map sourceAndMetadata = new HashMap<>(); IngestDocument emptyIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); try { @@ -712,6 +763,17 @@ public void testExecute_NoModelInput_Exception() { assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); } + try { + localModelProcessorIgnoreMissingTrue.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); + } + try { + localModelProcessorIgnoreMissingFalse.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); + } + } public void testExecute_AppendModelOutputSuccess() { @@ -1304,6 +1366,44 @@ public void testExecute_getMlModelTensorsIsNull() { } + public void testExecute_localMLModelTensorsIsNull() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "text_embedding", + true, + false, + false, + null + ); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + 
IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + + verify(handler) + .accept( + eq(null), + argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Output tensors are null or empty.")) + ); + + } + public void testExecute_getMlModelTensorsIsNullIgnoreFailure() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); @@ -1505,6 +1605,136 @@ public void testExecute_localModelSuccess() { assertEquals(nestedObjectIngestDocument, ingestDocument1); } + public void testExecute_localSparseEncodingModelMultipleModelTensors() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results.*.output.*.dataAsMap.response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "sparse_encoding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output with simple values + List> modelEmbeddings = new ArrayList<>(); + Map embedding = ImmutableMap.of("response", Arrays.asList(1.0, 2.0, 3.0, 4.0)); + for (int i = 1; i <= 4; i++) { + modelEmbeddings.add(embedding); + } + + List modelTensors = new ArrayList<>(); + for (Map embeddings : modelEmbeddings) { + modelTensors.add(ModelTensor.builder().dataAsMap(embeddings).build()); + } + + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(Collections.singletonList(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + + doAnswer(invocation -> { + ActionListener actionListener = 
invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + IngestDocument ingestDocument = new IngestDocument(getNestedObjectWithAnotherNestedObjectSource(), new HashMap<>()); + processor.execute(ingestDocument, handler); + verify(handler).accept(eq(ingestDocument), isNull()); + + List> chunks = (List>) ingestDocument.getFieldValue("chunks", List.class); + + List> firstChunkTexts = (List>) ((Map) chunks.get(0).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(0).get("response"), firstChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(1).get("response"), firstChunkTexts.get(1).get("context_embedding")); + + List> secondChunkTexts = (List>) ((Map) chunks.get(1).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(2).get("response"), secondChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(3).get("response"), secondChunkTexts.get(1).get("context_embedding")); + + } + + public void testExecute_localModelOutputIsNullIgnoreFailureSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument 
nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + + public void testExecute_localModelTensorsIsNullIgnoreFailure() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); + } + public void testParseGetDataInTensor_IntegerDataType() { ModelTensor mockTensor = mock(ModelTensor.class); when(mockTensor.getDataType()).thenReturn(MLResultDataType.INT8); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java index 3303a2ac75..f8d623fc74 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceIngestProcessorIT.java @@ -17,7 +17,6 @@ import 
org.apache.hc.core5.http.message.BasicHeader; import org.junit.Assert; import org.junit.Before; -import org.junit.Ignore; import org.opensearch.client.Request; import org.opensearch.client.Response; import org.opensearch.ml.common.FunctionName; @@ -361,8 +360,8 @@ public void testMLInferenceProcessorWithForEachProcessor() throws Exception { Assert.assertEquals(1536, embedding2.size()); } - @Ignore - public void testMLInferenceProcessorWithLocalModel() throws Exception { + public void testMLInferenceProcessorLocalModelObjectField() throws Exception { + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); waitForTask(taskId, MLTaskState.COMPLETED); getTask(client(), taskId, response -> { @@ -378,89 +377,173 @@ public void testMLInferenceProcessorWithLocalModel() throws Exception { } }); - String indexName = "my_books"; - String pipelineName = "my_books_text_embedding_pipeline"; + String createPipelineRequestBody = "{\n" + + " \"description\": \"test ml model ingest processor\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"diary\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"diary_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; String createIndexRequestBody = "{\n" + " \"settings\": {\n" + " \"index\": {\n" - + " \"default_pipeline\": \"" - + pipelineName - + "\"\n" - + " }\n" - + " },\n" - + " \"mappings\": {\n" - + " \"properties\": {\n" - + " \"books\": {\n" - + " \"type\": \"nested\",\n" - + " \"properties\": {\n" - + " \"title_embedding\": {\n" - + " \"type\": 
\"float\"\n" - + " },\n" - + " \"title\": {\n" - + " \"type\": \"text\"\n" - + " },\n" - + " \"description\": {\n" - + " \"type\": \"text\"\n" - + " }\n" - + " }\n" - + " }\n" + + " \"default_pipeline\": \"diary_embedding_pipeline\"\n" + " }\n" + " }\n" - + "}"; - createIndex(indexName, createIndexRequestBody); + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"weather\": \"rainy\"\n" + + " }"; + String index_name = "daily_index"; + createPipelineProcessor(createPipelineRequestBody, "diary_embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + List embeddingList = JsonPath.parse(document).read("_source.diary_embedding"); + Assert.assertEquals(2, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.diary_embedding[0]"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.42101282, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.diary_embedding[1]"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49191704, (Double) embedding2.get(0), 0.005); + } + + // TODO: add tests for other local model types such as sparse/cross encoders + public void testMLInferenceProcessorLocalModelNestedField() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | 
InterruptedException e) { + throw new RuntimeException(e); + } + }); String createPipelineRequestBody = "{\n" - + " \"description\": \"test embeddings\",\n" + + " \"description\": \"ingest reviews and generate embedding\",\n" + " \"processors\": [\n" + " {\n" - + " \"foreach\": {\n" - + " \"field\": \"books\",\n" - + " \"processor\": {\n" - + " \"ml_inference\": {\n" - + " \"model_id\": \"" - + localModelId + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + "\",\n" - + " \"input_map\": [\n" - + " {\n" - + " \"input\": \"_ingest._value.title\"\n" - + " }\n" - + " ],\n" - + " \"output_map\": [\n" - + " {\n" - + " \"_ingest._value.title_embedding\": \"$.embedding\"\n" - + " }\n" - + " ],\n" - + " \"ignore_missing\": false,\n" - + " \"ignore_failure\": false\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"book.*.chunk.text.*.context\"\n" + " }\n" - + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"book.*.chunk.text.*.context_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": true,\n" + + " \"ignore_failure\": true\n" + " }\n" + " }\n" + " ]\n" + "}"; - createPipelineProcessor(createPipelineRequestBody, pipelineName); + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; String uploadDocumentRequestBody = "{\n" - + " \"books\": [{\n" - + " \"title\": \"first book\",\n" - + " \"description\": \"This is first book\"\n" - + " },\n" - + " {\n" - + " \"title\": \"second book\",\n" - + " \"description\": \"This is second book\"\n" - + " }\n" - + " ]\n" + + " \"book\": [\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " 
\"context\": \"this is the first part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the second part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the third part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the fourth part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + "}"; - uploadDocument(indexName, "1", uploadDocumentRequestBody); - Map document = getDocument(indexName, "1"); + String index_name = "book_index"; + createPipelineProcessor(createPipelineRequestBody, "embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); - List embeddingList = JsonPath.parse(document).read("_source.books[*].title_embedding"); - Assert.assertEquals(2, embeddingList.size()); + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); - List embedding1 = JsonPath.parse(document).read("_source.books[0].title_embedding"); - Assert.assertEquals(1536, embedding1.size()); - List embedding2 = JsonPath.parse(document).read("_source.books[1].title_embedding"); - Assert.assertEquals(1536, embedding2.size()); + List embeddingList = JsonPath.parse(document).read("_source.book[*].chunk.text[*].context_embedding"); + Assert.assertEquals(4, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.book[0].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.48988956, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.book[0].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49552172, (Double) embedding2.get(0), 0.005); + + List embedding3 = 
JsonPath.parse(document).read("_source.book[1].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding3.size()); + Assert.assertEquals(0.5004309, (Double) embedding3.get(0), 0.005); + + List embedding4 = JsonPath.parse(document).read("_source.book[1].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding4.size()); + Assert.assertEquals(0.47907734, (Double) embedding4.get(0), 0.005); } protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception { @@ -502,7 +585,8 @@ protected Map getDocument(final String index, final String docId) throws Excepti return parseResponseToMap(docResponse); } - protected MLRegisterModelInput registerModelInput() { + protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException { + MLModelConfig modelConfig = TextEmbeddingModelConfig .builder() .modelType("bert") From 300c54817a6b6b915bb23a859988c94d309cc039 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jun 2024 13:16:55 -0500 Subject: [PATCH 4/5] remove logs Signed-off-by: Bhavana Ramaram --- .../ml/processor/ModelExecutor.java | 1 - .../MLInferenceIngestProcessorTests.java | 21 +++++++++++++++---- 2 files changed, 17 insertions(+), 5 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 792d6bfa8c..ff46c13f62 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -183,7 +183,6 @@ default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldNam try (XContentBuilder builder = XContentFactory.jsonBuilder()) { String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString(); Map modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class); - System.out.println("output value" + modelOutputJsonStr); if 
(!fullResponsePath && mlOutput instanceof ModelTensorOutput) { return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing); } else if (modelOutputFieldName == null || modelTensorOutputMap == null) { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index b5b8a19731..203392eb75 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -1367,9 +1367,15 @@ public void testExecute_getMlModelTensorsIsNull() { } public void testExecute_localMLModelTensorsIsNull() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); - List> outputMap = getOutputMapsForNestedObjectChunks(); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); MLInferenceIngestProcessor processor = createMLInferenceProcessor( "model1", @@ -1381,7 +1387,7 @@ public void testExecute_localMLModelTensorsIsNull() { true, false, false, - null + "{ \"text_docs\": ${ml_inference.text_docs} }" ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -1399,7 +1405,14 @@ public void testExecute_localMLModelTensorsIsNull() { verify(handler) .accept( eq(null), - argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Output tensors are null or empty.")) + argThat( + exception -> exception + .getMessage() + .equals( + "An 
unexpected error occurred: model inference output " + + "cannot find such json path: $.inference_results[0].output[0].data" + ) + ) ); } From a4f711b58aa370d74b5a0065925c83cb2ba60006 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jun 2024 14:34:33 -0500 Subject: [PATCH 5/5] fix failing ITs Signed-off-by: Bhavana Ramaram --- .../org/opensearch/ml/processor/MLInferenceIngestProcessor.java | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index 491128636d..b19853e02c 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -491,7 +491,7 @@ public MLInferenceIngestProcessor create( .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); String modelInput = ConfigurationUtils .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }"); - boolean defaultValue = !functionName.equals("remote"); + boolean defaultValue = !functionName.equalsIgnoreCase("remote"); boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue); boolean ignoreFailure = ConfigurationUtils