From 7cd52915d04d8ac7ddb6e37a74a256603587ce69 Mon Sep 17 00:00:00 2001 From: Bhavana Ramaram Date: Tue, 11 Jun 2024 15:11:48 -0500 Subject: [PATCH] ml inference ingest processor support for local models (#2508) * ml inference ingest processor support for local models Signed-off-by: Bhavana Ramaram --- .../ml/plugin/MachineLearningPlugin.java | 5 +- .../processor/MLInferenceIngestProcessor.java | 201 +++-- .../ml/processor/ModelExecutor.java | 84 +- ...LInferenceIngestProcessorFactoryTests.java | 36 +- .../MLInferenceIngestProcessorTests.java | 849 ++++++++++++++++-- .../RestMLInferenceIngestProcessorIT.java | 218 ++++- 6 files changed, 1265 insertions(+), 128 deletions(-) diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index e9a79236b1..6d808c64bb 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -1006,7 +1006,10 @@ public void loadExtensions(ExtensionLoader loader) { public Map getProcessors(org.opensearch.ingest.Processor.Parameters parameters) { Map processors = new HashMap<>(); processors - .put(MLInferenceIngestProcessor.TYPE, new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client)); + .put( + MLInferenceIngestProcessor.TYPE, + new MLInferenceIngestProcessor.Factory(parameters.scriptService, parameters.client, xContentRegistry) + ); return Collections.unmodifiableMap(processors); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java index c06f32803c..b19853e02c 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceIngestProcessor.java @@ -6,25 +6,31 @@ import static org.opensearch.ml.processor.InferenceProcessorAttributes.*; +import java.io.IOException; import java.util.ArrayList; import java.util.Collection; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Set; import java.util.function.BiConsumer; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.GroupedActionListener; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; import org.opensearch.core.common.Strings; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.AbstractProcessor; import org.opensearch.ingest.ConfigurationUtils; import org.opensearch.ingest.IngestDocument; import org.opensearch.ingest.Processor; import org.opensearch.ingest.ValueSource; -import org.opensearch.ml.common.output.model.ModelTensorOutput; +import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.transport.MLTaskResponse; import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction; import org.opensearch.ml.common.utils.StringUtils; @@ -42,10 +48,16 @@ */ public class MLInferenceIngestProcessor extends AbstractProcessor implements ModelExecutor { + private static final Logger logger = LogManager.getLogger(MLInferenceIngestProcessor.class); + public static final String DOT_SYMBOL = "."; private final InferenceProcessorAttributes 
inferenceProcessorAttributes; private final boolean ignoreMissing; + private final String functionName; + private final boolean fullResponsePath; private final boolean ignoreFailure; + private final boolean override; + private final String modelInput; private final ScriptService scriptService; private static Client client; public static final String TYPE = "ml_inference"; @@ -53,9 +65,14 @@ public class MLInferenceIngestProcessor extends AbstractProcessor implements Mod // allow to ignore a field from mapping is not present in the document, and when the outfield is not found in the // prediction outcomes, return the whole prediction outcome by skipping filtering public static final String IGNORE_MISSING = "ignore_missing"; + public static final String OVERRIDE = "override"; + public static final String FUNCTION_NAME = "function_name"; + public static final String FULL_RESPONSE_PATH = "full_response_path"; + public static final String MODEL_INPUT = "model_input"; // At default, ml inference processor allows maximum 10 prediction tasks running in parallel // it can be overwritten using max_prediction_tasks when creating processor public static final int DEFAULT_MAX_PREDICTION_TASKS = 10; + private final NamedXContentRegistry xContentRegistry; private Configuration suppressExceptionConfiguration = Configuration .builder() @@ -71,9 +88,14 @@ protected MLInferenceIngestProcessor( String tag, String description, boolean ignoreMissing, + String functionName, + boolean fullResponsePath, boolean ignoreFailure, + boolean override, + String modelInput, ScriptService scriptService, - Client client + Client client, + NamedXContentRegistry xContentRegistry ) { super(tag, description); this.inferenceProcessorAttributes = new InferenceProcessorAttributes( @@ -84,9 +106,14 @@ protected MLInferenceIngestProcessor( maxPredictionTask ); this.ignoreMissing = ignoreMissing; + this.functionName = functionName; + this.fullResponsePath = fullResponsePath; this.ignoreFailure = ignoreFailure; + this.override = override; + this.modelInput = modelInput; this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -162,10 +189,48 @@ private void processPredictions( List> processOutputMap, int inputMapIndex, int inputMapSize - ) { + ) throws IOException { Map modelParameters = new HashMap<>(); + Map modelConfigs = new HashMap<>(); + if (inferenceProcessorAttributes.getModelConfigMaps() != null) { modelParameters.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + modelConfigs.putAll(inferenceProcessorAttributes.getModelConfigMaps()); + } + + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + + Map> newOutputMapping = new HashMap<>(); + if (processOutputMap != null) { + + Map outputMapping = processOutputMap.get(inputMapIndex); + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + newOutputMapping.put(newDocumentFieldName, dotPathsInArray); + } + + for (Map.Entry entry : outputMapping.entrySet()) { + String newDocumentFieldName = entry.getKey(); + List dotPaths = newOutputMapping.get(newDocumentFieldName); + + int existingFields = 0; + for (String path : dotPaths) { + if (ingestDocument.hasField(path)) { + 
existingFields++; + } + } + if (!override && existingFields == dotPaths.size()) { + logger.debug("{} already exists in the ingest document. Removing it from output mapping", newDocumentFieldName); + newOutputMapping.remove(newDocumentFieldName); + } + } + if (newOutputMapping.size() == 0) { + batchPredictionListener.onResponse(null); + return; + } } // when no input mapping is provided, default to read all fields from documents as model input if (inputMapSize == 0) { @@ -184,15 +249,30 @@ private void processPredictions( } } - ActionRequest request = getRemoteModelInferenceRequest(modelParameters, inferenceProcessorAttributes.getModelId()); + Set inputMapKeys = new HashSet<>(modelParameters.keySet()); + inputMapKeys.removeAll(modelConfigs.keySet()); + + Map inputMappings = new HashMap<>(); + for (String k : inputMapKeys) { + inputMappings.put(k, modelParameters.get(k)); + } + ActionRequest request = getMLModelInferenceRequest( + xContentRegistry, + modelParameters, + modelConfigs, + inputMappings, + inferenceProcessorAttributes.getModelId(), + functionName, + modelInput + ); client.execute(MLPredictionTaskAction.INSTANCE, request, new ActionListener<>() { @Override public void onResponse(MLTaskResponse mlTaskResponse) { - ModelTensorOutput modelTensorOutput = (ModelTensorOutput) mlTaskResponse.getOutput(); + MLOutput mlOutput = mlTaskResponse.getOutput(); if (processOutputMap == null || processOutputMap.isEmpty()) { - appendFieldValue(modelTensorOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); + appendFieldValue(mlOutput, null, DEFAULT_OUTPUT_FIELD_NAME, ingestDocument); } else { // outMapping serves as a filter to modelTensorOutput, the fields that are not specified // in the outputMapping will not write to document @@ -202,14 +282,10 @@ public void onResponse(MLTaskResponse mlTaskResponse) { // document field as key, model field as value String newDocumentFieldName = entry.getKey(); String modelOutputFieldName = entry.getValue(); - if (ingestDocument.hasField(newDocumentFieldName)) { - throw new IllegalArgumentException( - "document already has field name " - + newDocumentFieldName - + ". Not allow to overwrite the same field name, please check output_map." - ); + if (!newOutputMapping.containsKey(newDocumentFieldName)) { + continue; } - appendFieldValue(modelTensorOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); + appendFieldValue(mlOutput, modelOutputFieldName, newDocumentFieldName, ingestDocument); } } batchPredictionListener.onResponse(null); @@ -305,63 +381,61 @@ private String getFieldPath(IngestDocument ingestDocument, String documentFieldN /** * Appends the model output value to the specified field in the IngestDocument without modifying the source. 
* - * @param modelTensorOutput the ModelTensorOutput containing the model output + * @param mlOutput the MLOutput containing the model output * @param modelOutputFieldName the name of the field in the model output * @param newDocumentFieldName the name of the field in the IngestDocument to append the value to * @param ingestDocument the IngestDocument to append the value to */ private void appendFieldValue( - ModelTensorOutput modelTensorOutput, + MLOutput mlOutput, String modelOutputFieldName, String newDocumentFieldName, IngestDocument ingestDocument ) { - Object modelOutputValue = null; - if (modelTensorOutput.getMlModelOutputs() != null && modelTensorOutput.getMlModelOutputs().size() > 0) { + if (mlOutput == null) { + throw new RuntimeException("model inference output is null"); + } - modelOutputValue = getModelOutputValue(modelTensorOutput, modelOutputFieldName, ignoreMissing); + Object modelOutputValue = getModelOutputValue(mlOutput, modelOutputFieldName, ignoreMissing, fullResponsePath); - Map ingestDocumentSourceAndMetaData = new HashMap<>(); - ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); - ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); - List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); + Map ingestDocumentSourceAndMetaData = new HashMap<>(); + ingestDocumentSourceAndMetaData.putAll(ingestDocument.getSourceAndMetadata()); + ingestDocumentSourceAndMetaData.put(IngestDocument.INGEST_KEY, ingestDocument.getIngestMetadata()); + List dotPathsInArray = writeNewDotPathForNestedObject(ingestDocumentSourceAndMetaData, newDocumentFieldName); - if (dotPathsInArray.size() == 1) { - ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + if (dotPathsInArray.size() == 1) { + ValueSource ingestValue = ValueSource.wrap(modelOutputValue, scriptService); + TemplateScript.Factory ingestField = ConfigurationUtils + .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); + } else { + if (!(modelOutputValue instanceof List)) { + throw new IllegalArgumentException("Model output is not an array, cannot assign to array in documents."); + } + List modelOutputValueArray = (List) modelOutputValue; + // check length of the prediction array to be the same of the document array + if (dotPathsInArray.size() != modelOutputValueArray.size()) { + throw new RuntimeException( + "the prediction field: " + + modelOutputFieldName + + " is an array in size of " + + modelOutputValueArray.size() + + " but the document field array from field " + + newDocumentFieldName + + " is in size of " + + dotPathsInArray.size() + ); + } + // Iterate over dotPathInArray + for (int i = 0; i < dotPathsInArray.size(); i++) { + String dotPathInArray = dotPathsInArray.get(i); + Object modelOutputValueInArray = modelOutputValueArray.get(i); + ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathsInArray.get(0), dotPathsInArray.get(0), scriptService); + .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } else { - if (!(modelOutputValue instanceof List)) { - throw new IllegalArgumentException("Model output is not an array, cannot assign to array in 
documents."); - } - List modelOutputValueArray = (List) modelOutputValue; - // check length of the prediction array to be the same of the document array - if (dotPathsInArray.size() != modelOutputValueArray.size()) { - throw new RuntimeException( - "the prediction field: " - + modelOutputFieldName - + " is an array in size of " - + modelOutputValueArray.size() - + " but the document field array from field " - + newDocumentFieldName - + " is in size of " - + dotPathsInArray.size() - ); - } - // Iterate over dotPathInArray - for (int i = 0; i < dotPathsInArray.size(); i++) { - String dotPathInArray = dotPathsInArray.get(i); - Object modelOutputValueInArray = modelOutputValueArray.get(i); - ValueSource ingestValue = ValueSource.wrap(modelOutputValueInArray, scriptService); - TemplateScript.Factory ingestField = ConfigurationUtils - .compileTemplate(TYPE, tag, dotPathInArray, dotPathInArray, scriptService); - ingestDocument.setFieldValue(ingestField, ingestValue, ignoreMissing); - } } - } else { - throw new RuntimeException("model inference output cannot be null"); } } @@ -374,6 +448,7 @@ public static class Factory implements Processor.Factory { private final ScriptService scriptService; private final Client client; + private final NamedXContentRegistry xContentRegistry; /** * Constructs a new instance of the Factory class. @@ -381,9 +456,10 @@ public static class Factory implements Processor.Factory { * @param scriptService the ScriptService instance to be used by the Factory * @param client the Client instance to be used by the Factory */ - public Factory(ScriptService scriptService, Client client) { + public Factory(ScriptService scriptService, Client client, NamedXContentRegistry xContentRegistry) { this.scriptService = scriptService; this.client = client; + this.xContentRegistry = xContentRegistry; } /** @@ -410,6 +486,14 @@ public MLInferenceIngestProcessor create( int maxPredictionTask = ConfigurationUtils .readIntProperty(TYPE, processorTag, config, MAX_PREDICTION_TASKS, DEFAULT_MAX_PREDICTION_TASKS); boolean ignoreMissing = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, IGNORE_MISSING, false); + boolean override = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, OVERRIDE, false); + String functionName = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, FUNCTION_NAME, FunctionName.REMOTE.name()); + String modelInput = ConfigurationUtils + .readStringProperty(TYPE, processorTag, config, MODEL_INPUT, "{ \"parameters\": ${ml_inference.parameters} }"); + boolean defaultValue = !functionName.equalsIgnoreCase("remote"); + boolean fullResponsePath = ConfigurationUtils.readBooleanProperty(TYPE, processorTag, config, FULL_RESPONSE_PATH, defaultValue); + boolean ignoreFailure = ConfigurationUtils .readBooleanProperty(TYPE, processorTag, config, ConfigurationUtils.IGNORE_FAILURE_KEY, false); // convert model config user input data structure to Map @@ -440,9 +524,14 @@ public MLInferenceIngestProcessor create( processorTag, description, ignoreMissing, + functionName, + fullResponsePath, ignoreFailure, + override, + modelInput, scriptService, - client + client, + xContentRegistry ); } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java index 1abc770d07..ff46c13f62 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/ModelExecutor.java @@ -5,17 +5,29 @@ package 
org.opensearch.ml.processor; +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; +import static org.opensearch.ml.common.utils.StringUtils.gson; +import static org.opensearch.ml.common.utils.StringUtils.isJson; + import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; +import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.stream.Collectors; +import org.apache.commons.text.StringSubstitutor; import org.opensearch.action.ActionRequest; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.common.xcontent.XContentType; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.core.xcontent.XContentParser; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.input.MLInput; +import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.model.ModelTensor; import org.opensearch.ml.common.output.model.ModelTensorOutput; import org.opensearch.ml.common.output.model.ModelTensors; @@ -45,17 +57,47 @@ public interface ModelExecutor { * @return an ActionRequest instance for remote model inference * @throws IllegalArgumentException if the input parameters are null */ - default ActionRequest getRemoteModelInferenceRequest(Map parameters, String modelId) { + default ActionRequest getMLModelInferenceRequest( + NamedXContentRegistry xContentRegistry, + Map parameters, + Map modelConfigs, + Map inputMappings, + String modelId, + String functionNameStr, + String modelInput + ) throws IOException { if (parameters == null) { throw new IllegalArgumentException("wrong input. 
The model input cannot be empty.");
         }
-        RemoteInferenceInputDataSet inputDataSet = RemoteInferenceInputDataSet.builder().parameters(parameters).build();
+        FunctionName functionName = FunctionName.REMOTE;
+        if (functionNameStr != null) {
+            functionName = FunctionName.from(functionNameStr);
+        }
+
+        Map<String, String> inputParams = new HashMap<>();
+        if (FunctionName.REMOTE == functionName) {
+            inputParams.put("parameters", StringUtils.toJson(parameters));
+        } else {
+            inputParams.putAll(parameters);
+        }
 
-        MLInput mlInput = MLInput.builder().algorithm(FunctionName.REMOTE).inputDataset(inputDataSet).build();
+        String payload = modelInput;
+        StringSubstitutor modelConfigSubstitutor = new StringSubstitutor(modelConfigs, "${model_config.", "}");
+        payload = modelConfigSubstitutor.replace(payload);
+        StringSubstitutor inputMapSubstitutor = new StringSubstitutor(inputMappings, "${input_map.", "}");
+        payload = inputMapSubstitutor.replace(payload);
+        StringSubstitutor parametersSubstitutor = new StringSubstitutor(inputParams, "${ml_inference.", "}");
+        payload = parametersSubstitutor.replace(payload);
 
-        ActionRequest request = new MLPredictionTaskRequest(modelId, mlInput, null);
+        if (!isJson(payload)) {
+            throw new IllegalArgumentException("Invalid payload: " + payload);
+        }
+        XContentParser parser = XContentType.JSON.xContent().createParser(xContentRegistry, null, payload);
+
+        ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.nextToken(), parser);
+        MLInput mlInput = MLInput.parse(parser, functionName.name());
 
-        return request;
+        return new MLPredictionTaskRequest(modelId, mlInput);
     }
 
@@ -74,7 +116,9 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m
         try {
             // getMlModelOutputs() returns a list or collection.
             // Adding null check for modelTensorOutput
-            if (modelTensorOutput != null && !modelTensorOutput.getMlModelOutputs().isEmpty()) {
+            if (modelTensorOutput != null
+                && modelTensorOutput.getMlModelOutputs() != null
+                && !modelTensorOutput.getMlModelOutputs().isEmpty()) {
                 // getMlModelOutputs() returns a list of ModelTensors
                 // accessing the first element.
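                // For example (illustrative): a ModelTensorOutput serialized to JSON has the shape
                // { "inference_results": [ { "output": [ { "name": "...", "data": [...] } ] } ] },
                // so the first entry of getMlModelOutputs() corresponds to inference_results[0],
                // which is what JsonPath expressions such as "$.inference_results[0].output[0].data" address.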
// TODO currently remote model only returns a single tensor; might need to process multiple tensors later
@@ -130,11 +174,35 @@ default Object getModelOutputValue(ModelTensorOutput modelTensorOutput, String m
                 throw new RuntimeException("Model outputs are null or empty.");
             }
         } catch (Exception e) {
-            throw new RuntimeException("An unexpected error occurred: " + e.getMessage());
+            throw new RuntimeException(e.getMessage());
         }
         return modelOutputValue;
     }
 
+    default Object getModelOutputValue(MLOutput mlOutput, String modelOutputFieldName, boolean ignoreMissing, boolean fullResponsePath) {
+        try (XContentBuilder builder = XContentFactory.jsonBuilder()) {
+            String modelOutputJsonStr = mlOutput.toXContent(builder, ToXContent.EMPTY_PARAMS).toString();
+            Map<String, Object> modelTensorOutputMap = gson.fromJson(modelOutputJsonStr, Map.class);
+            if (!fullResponsePath && mlOutput instanceof ModelTensorOutput) {
+                return getModelOutputValue((ModelTensorOutput) mlOutput, modelOutputFieldName, ignoreMissing);
+            } else if (modelOutputFieldName == null || modelTensorOutputMap == null) {
+                return modelTensorOutputMap;
+            } else {
+                try {
+                    return JsonPath.parse(modelTensorOutputMap).read(modelOutputFieldName);
+                } catch (Exception e) {
+                    if (ignoreMissing) {
+                        return modelTensorOutputMap;
+                    } else {
+                        throw new IllegalArgumentException("model inference output cannot find such json path: " + modelOutputFieldName, e);
+                    }
+                }
+            }
+        } catch (Exception e) {
+            throw new RuntimeException("An unexpected error occurred: " + e.getMessage());
+        }
+    }
+
     /**
      * Parses the data from the given ModelTensor and returns it as an Object.
      * The method handles different data types (integer, floating-point, string, and boolean)
diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java
index 577e8b8693..7ca077a82f 100644
--- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java
+++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorFactoryTests.java
@@ -5,6 +5,9 @@ package org.opensearch.ml.processor;
 import static org.opensearch.ml.processor.InferenceProcessorAttributes.*;
+import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FULL_RESPONSE_PATH;
+import static org.opensearch.ml.processor.MLInferenceIngestProcessor.FUNCTION_NAME;
+import static org.opensearch.ml.processor.MLInferenceIngestProcessor.MODEL_INPUT;
 
 import java.util.ArrayList;
 import java.util.HashMap;
@@ -15,6 +18,7 @@ import org.mockito.Mock;
 import org.opensearch.OpenSearchParseException;
 import org.opensearch.client.Client;
+import org.opensearch.core.xcontent.NamedXContentRegistry;
 import org.opensearch.ingest.Processor;
 import org.opensearch.script.ScriptService;
 import org.opensearch.test.OpenSearchTestCase;
@@ -25,10 +29,12 @@ public class MLInferenceIngestProcessorFactoryTests extends OpenSearchTestCase {
     private Client client;
     @Mock
     private ScriptService scriptService;
+    @Mock
+    private NamedXContentRegistry xContentRegistry;
 
     @Before
     public void init() {
-        factory = new MLInferenceIngestProcessor.Factory(scriptService, client);
+        factory = new MLInferenceIngestProcessor.Factory(scriptService, client, xContentRegistry);
     }
 
     public void testCreateRequiredFields() throws Exception {
@@ -42,6 +48,34 @@ public void testCreateRequiredFields() throws Exception {
         assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE);
     }
 
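To make the JsonPath lookup in the new getModelOutputValue overload concrete, here is a minimal standalone sketch (illustrative values; gson and JsonPath come from the same dependencies the patch already imports):

    import static org.opensearch.ml.common.utils.StringUtils.gson;

    import java.util.Map;

    import com.jayway.jsonpath.JsonPath;

    class JsonPathLookupSketch {
        public static void main(String[] args) {
            // Shape mirrors a serialized ModelTensorOutput (values illustrative).
            Map<String, Object> response = gson
                .fromJson("{\"inference_results\":[{\"output\":[{\"name\":\"sentence_embedding\",\"data\":[1.0,2.0,3.0]}]}]}", Map.class);
            // Same path style as the output_map entries used in the tests below.
            Object data = JsonPath.parse(response).read("$.inference_results[0].output[0].data");
            System.out.println(data); // [1.0, 2.0, 3.0]
        }
    }

+    public void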
testCreateLocalModelProcessor() throws Exception { + Map registry = new HashMap<>(); + Map config = new HashMap<>(); + config.put(MODEL_ID, "model1"); + config.put(FUNCTION_NAME, "text_embedding"); + config.put(FULL_RESPONSE_PATH, true); + config.put(MODEL_INPUT, "{ \"text_docs\": ${ml_inference.text_docs} }"); + Map model_config = new HashMap<>(); + model_config.put("return_number", true); + config.put(MODEL_CONFIG, model_config); + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "text"); + inputMap.add(input); + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("text_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + config.put(INPUT_MAP, inputMap); + config.put(OUTPUT_MAP, outputMap); + config.put(MAX_PREDICTION_TASKS, 5); + String processorTag = randomAlphaOfLength(10); + MLInferenceIngestProcessor mLInferenceIngestProcessor = factory.create(registry, processorTag, null, config); + assertNotNull(mLInferenceIngestProcessor); + assertEquals(mLInferenceIngestProcessor.getTag(), processorTag); + assertEquals(mLInferenceIngestProcessor.getType(), MLInferenceIngestProcessor.TYPE); + } + public void testCreateNoFieldPresent() throws Exception { Map config = new HashMap<>(); try { diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java index d11cc213de..203392eb75 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceIngestProcessorTests.java @@ -9,10 +9,12 @@ import static org.mockito.Mockito.*; import static org.opensearch.ml.processor.MLInferenceIngestProcessor.DEFAULT_OUTPUT_FIELD_NAME; +import java.io.IOException; import java.nio.ByteBuffer; import java.time.ZonedDateTime; import java.util.ArrayList; import java.util.Arrays; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; @@ -26,6 +28,7 @@ import org.mockito.MockitoAnnotations; import org.opensearch.client.Client; import org.opensearch.core.action.ActionListener; +import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.ingest.IngestDocument; import org.opensearch.ml.common.dataset.remote.RemoteInferenceInputDataSet; import org.opensearch.ml.common.output.model.MLResultDataType; @@ -52,6 +55,9 @@ public class MLInferenceIngestProcessorTests extends OpenSearchTestCase { private ScriptService scriptService; @Mock private BiConsumer handler; + + @Mock + NamedXContentRegistry xContentRegistry; private static final String PROCESSOR_TAG = "inference"; private static final String DESCRIPTION = "inference_test"; private IngestDocument ingestDocument; @@ -74,30 +80,53 @@ public void setup() { } private MLInferenceIngestProcessor createMLInferenceProcessor( - String model_id, - Map model_config, - List> input_map, - List> output_map, + String modelId, + List> inputMaps, + List> outputMaps, + Map modelConfigMaps, boolean ignoreMissing, - boolean ignoreFailure + String functionName, + boolean fullResponsePath, + boolean ignoreFailure, + boolean override, + String modelInput ) { + functionName = functionName != null ? functionName : "remote"; + modelInput = modelInput != null ? 
modelInput : "{ \"parameters\": ${ml_inference.parameters} }"; + return new MLInferenceIngestProcessor( - model_id, - input_map, - output_map, - model_config, + modelId, + inputMaps, + outputMaps, + modelConfigMaps, RANDOM_MULTIPLIER, PROCESSOR_TAG, DESCRIPTION, ignoreMissing, + functionName, + fullResponsePath, ignoreFailure, + override, + modelInput, scriptService, - client + client, + xContentRegistry ); } public void testExecute_Exception() throws Exception { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); try { IngestDocument document = processor.execute(ingestDocument); } catch (UnsupportedOperationException e) { @@ -111,9 +140,20 @@ public void testExecute_Exception() throws Exception { */ public void testExecute_nestedObjectStringDocumentSuccess() { - List> inputMap = getInputMapsForNestedObjectChunks("chunks.chunk"); - - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -137,10 +177,21 @@ public void testExecute_nestedObjectStringDocumentSuccess() { * test nested object document with array of Map, * the value Object is a Map */ - public void testExecute_nestedObjectMapDocumentSuccess() { + public void testExecute_nestedObjectMapDocumentSuccess() throws IOException { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -181,9 +232,10 @@ public void testExecute_nestedObjectMapDocumentSuccess() { embedding_text.add("this is first"); embedding_text.add("this is second"); inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -224,10 +276,21 @@ public void 
testExecute_jsonPathWithMissingLeaves() { * test nested object document with array of Map, * the value Object is a also a nested object, */ - public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() { + public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSuccess() throws IOException { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3, 4))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -254,9 +317,10 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInOneFieldSucce embedding_text.add("this is third"); embedding_text.add("this is fourth"); inputParameters.put("inputs", modelExecutor.toString(embedding_text)); + String modelInput = "{ \"parameters\": ${ml_inference.parameters} }"; MLPredictionTaskRequest expectedRequest = (MLPredictionTaskRequest) modelExecutor - .getRemoteModelInferenceRequest(inputParameters, "model1"); + .getMLModelInferenceRequest(xContentRegistry, inputParameters, null, inputParameters, "model1", "remote", modelInput); MLPredictionTaskRequest actualRequest = argumentCaptor.getValue(); RemoteInferenceInputDataSet expectedRemoteInputDataset = (RemoteInferenceInputDataSet) expectedRequest @@ -278,7 +342,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArraySuccess( List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -311,7 +386,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -345,7 +431,18 @@ public void testExecute_nestedObjectAndNestedObjectDocumentOutputInArrayMissingL } public void testExecute_InferenceException() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); when(client.execute(any(), any())).thenThrow(new 
RuntimeException("Executing Model failed with exception")); try { processor.execute(ingestDocument, handler); @@ -355,7 +452,18 @@ public void testExecute_InferenceException() { } public void testExecute_InferenceOnFailure() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); RuntimeException inferenceFailure = new RuntimeException("Executing Model failed with exception"); doAnswer(invocation -> { @@ -375,7 +483,18 @@ public void testExecute_AppendFieldValueExceptionOnResponse() throws Exception { String originalOutPutFieldName = "response1"; output.put("text_embedding", originalOutPutFieldName); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); @@ -409,7 +528,18 @@ public void testExecute_whenInputFieldNotFound_ExceptionWithIgnoreMissingFalse() Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -429,7 +559,42 @@ public void testExecute_whenInputFieldNotFound_SuccessWithIgnoreMissingTrue() { Map output = new HashMap<>(); output.put("text_embedding", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); + + processor.execute(ingestDocument, handler); + } + + public void testExecute_localModelInputFieldNotFound_SuccessWithIgnoreMissingTrue() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + Map model_config = new HashMap<>(); + model_config.put("return_number", "true"); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "text_embedding", + true, + false, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); processor.execute(ingestDocument, handler); } @@ -447,7 +612,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingFalse() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + 
null + ); try { processor.execute(ingestDocument, handler); @@ -469,7 +645,18 @@ public void testExecute_whenEmptyInputField_ExceptionWithIgnoreMissingTrue() { Map model_config = new HashMap<>(); model_config.put("position_embedding_type", "absolute"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + true, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); @@ -491,7 +678,18 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin ObjectMapper mapper = mock(ObjectMapper.class); when(mapper.readValue(Mockito.anyString(), eq(Object.class))).thenThrow(JsonProcessingException.class); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", model_config, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + model_config, + false, + "remote", + false, + false, + false, + null + ); try { processor.execute(ingestDocument, handler); @@ -501,8 +699,56 @@ public void testExecute_IOExceptionWithIgnoreMissingFalse() throws JsonProcessin } public void testExecute_NoModelInput_Exception() { - MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor("model1", null, null, null, true, false); - MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor("model1", null, null, null, false, false); + MLInferenceIngestProcessor processorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); + MLInferenceIngestProcessor processorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "remote", + false, + false, + false, + null + ); + + MLInferenceIngestProcessor localModelProcessorIgnoreMissingFalse = createMLInferenceProcessor( + "model1", + null, + null, + null, + false, + "text_embedding", + false, + false, + false, + null + ); + + MLInferenceIngestProcessor localModelProcessorIgnoreMissingTrue = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "text_embedding", + false, + false, + false, + null + ); Map sourceAndMetadata = new HashMap<>(); IngestDocument emptyIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); @@ -517,10 +763,32 @@ public void testExecute_NoModelInput_Exception() { assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); } + try { + localModelProcessorIgnoreMissingTrue.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. The model input cannot be empty.", e.getMessage()); + } + try { + localModelProcessorIgnoreMissingFalse.execute(emptyIngestDocument, handler); + } catch (IllegalArgumentException e) { + assertEquals("wrong input. 
The model input cannot be empty.", e.getMessage()); + } + } public void testExecute_AppendModelOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", Arrays.asList(1, 2, 3))).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); @@ -543,7 +811,18 @@ public void testExecute_AppendModelOutputSuccess() { } public void testExecute_SingleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); Float[] value = new Float[] { 1.0f, 2.0f, 3.0f }; List outputs = new ArrayList<>(); @@ -578,7 +857,18 @@ public void testExecute_SingleTensorInDataOutputSuccess() { } public void testExecute_MultipleTensorInDataOutputSuccess() { - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, null, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + null, + null, + true, + "remote", + false, + false, + false, + null + ); List outputs = new ArrayList<>(); Float[] value = new Float[] { 1.0f }; @@ -640,7 +930,18 @@ public void testExecute_getModelOutputFieldWithFieldNameSuccess() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -671,7 +972,18 @@ public void testExecute_getModelOutputFieldWithDotPathSuccess() { output.put("language_identification", "response.language"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", List.of("en", "en"), "score", "0.9876"))) @@ -703,7 +1015,18 @@ public void testExecute_getModelOutputFieldWithInvalidDotPathSuccess() { output.put("language_identification", "response.lan"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -733,7 +1056,18 @@ public void 
testExecute_getModelOutputFieldWithInvalidDotPathException() { output.put("response.lan", "language_identification"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -768,7 +1102,18 @@ public void testExecute_getModelOutputFieldInNestedWithInvalidDotPathException() output.put("chunks.*.chunk.text.*.context_embedding", "response.language1"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -803,7 +1148,18 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { output.put("key1", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, null, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + null, + outputMap, + null, + false, + "remote", + false, + false, + true, + null + ); ModelTensor modelTensor = ModelTensor .builder() .dataAsMap(ImmutableMap.of("response", ImmutableMap.of("language", "en", "score", "0.9876"))) @@ -818,17 +1174,13 @@ public void testExecute_getModelOutputFieldWithExistedFieldNameException() { }).when(client).execute(any(), any(), any()); processor.execute(ingestDocument, handler); - verify(handler) - .accept( - eq(null), - argThat( - exception -> exception - .getMessage() - .equals( - "document already has field name key1. Not allow to overwrite the same field name, please check output_map." 
- ) - ) - ); + + Map sourceAndMetadata = new HashMap<>(); + sourceAndMetadata.put("key1", ImmutableMap.of("language", "en", "score", "0.9876")); + sourceAndMetadata.put("key2", "value2"); + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(ingestDocument, ingestDocument1); } public void testExecute_documentNotExistedFieldNameException() { @@ -842,7 +1194,18 @@ public void testExecute_documentNotExistedFieldNameException() { output.put("classification", "response"); outputMap.add(output); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -852,7 +1215,18 @@ public void testExecute_documentNotExistedFieldNameException() { public void testExecute_nestedDocumentNotExistedFieldNameException() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context1"); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, null, false, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + null, + null, + false, + "remote", + false, + false, + false, + null + ); processor.execute(ingestDocument, handler); verify(handler) @@ -871,7 +1245,18 @@ public void testExecute_getModelOutputFieldDifferentLengthException() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -907,7 +1292,18 @@ public void testExecute_getModelOutputFieldDifferentLengthIgnoreFailureSuccess() List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ArrayList> modelPredictionOutput = new ArrayList<>(); modelPredictionOutput.add(Arrays.asList(1)); modelPredictionOutput.add(Arrays.asList(2)); @@ -937,7 +1333,18 @@ public void testExecute_getMlModelTensorsIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "remote", + false, + false, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -959,12 +1366,74 @@ public void testExecute_getMlModelTensorsIsNull() { } + public void testExecute_localMLModelTensorsIsNull() { + List> inputMap = new ArrayList<>(); + Map 
input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + false, + "text_embedding", + true, + false, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + Map sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource(); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + + processor.execute(nestedObjectIngestDocument, handler); + + verify(handler) + .accept( + eq(null), + argThat( + exception -> exception + .getMessage() + .equals( + "An unexpected error occurred: model inference output " + + "cannot find such json path: $.inference_results[0].output[0].data" + ) + ) + ); + + } + public void testExecute_getMlModelTensorsIsNullIgnoreFailure() { List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); doAnswer(invocation -> { @@ -985,7 +1454,18 @@ public void testExecute_modelTensorOutputIsNull() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", null, inputMap, outputMap, true, false); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + false, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -997,7 +1477,11 @@ public void testExecute_modelTensorOutputIsNull() { IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); processor.execute(nestedObjectIngestDocument, handler); - verify(handler).accept(eq(null), argThat(exception -> exception.getMessage().equals("model inference output cannot be null"))); + verify(handler) + .accept( + eq(null), + argThat(exception -> exception.getMessage().equals("An unexpected error occurred: Model outputs are null or empty.")) + ); } @@ -1006,7 +1490,18 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { List> outputMap = getOutputMapsForNestedObjectChunks(); - MLInferenceIngestProcessor processor = createMLInferenceProcessor("model1", 
null, inputMap, outputMap, true, true); + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "remote", + false, + true, + false, + null + ); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); doAnswer(invocation -> { ActionListener actionListener = invocation.getArgument(2); @@ -1021,6 +1516,238 @@ public void testExecute_modelTensorOutputIsNullIgnoreFailureSuccess() { verify(handler).accept(eq(nestedObjectIngestDocument), isNull()); } + /** + * Test processor configuration with nested object document + * and array of Map, where the value Object is a List + */ + public void testExecute_localModelSuccess() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "_ingest._value.title"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("_ingest._value.title_embedding", "$.inference_results[0].output[0].data"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output + List modelPredictionOutput = Arrays.asList(1, 2, 3, 4); + ModelTensor modelTensor = ModelTensor + .builder() + .dataAsMap( + ImmutableMap + .of( + "inference_results", + Arrays + .asList( + ImmutableMap + .of( + "output", + Arrays.asList(ImmutableMap.of("name", "sentence_embedding", "data", modelPredictionOutput)) + ) + ) + ) + ) + .build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + // Setting up the ingest document + Map sourceAndMetadata = new HashMap<>(); + List> books = new ArrayList<>(); + Map book1 = new HashMap<>(); + book1.put("title", Arrays.asList("first book")); + book1.put("description", "This is first book"); + Map book2 = new HashMap<>(); + book2.put("title", Arrays.asList("second book")); + book2.put("description", "This is second book"); + books.add(book1); + books.add(book2); + sourceAndMetadata.put("books", books); + + Map ingestMetadata = new HashMap<>(); + ingestMetadata.put("pipeline", "test_pipeline"); + ingestMetadata.put("timestamp", ZonedDateTime.now()); + Map ingestValue = new HashMap<>(); + ingestValue.put("title", Arrays.asList("first book")); + ingestValue.put("description", "This is first book"); + ingestMetadata.put("_value", ingestValue); + sourceAndMetadata.put("_ingest", ingestMetadata); + + IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>()); + processor.execute(nestedObjectIngestDocument, handler); + + // Validate the document + List> updatedBooks = new ArrayList<>(); + Map updatedBook1 = new HashMap<>(); + updatedBook1.put("title", Arrays.asList("first book")); + updatedBook1.put("description", "This is first book"); + updatedBook1.put("title_embedding", modelPredictionOutput); + Map updatedBook2 = new HashMap<>(); + updatedBook2.put("title", 
Arrays.asList("second book")); + updatedBook2.put("description", "This is second book"); + updatedBook2.put("title_embedding", modelPredictionOutput); + updatedBooks.add(updatedBook1); + updatedBooks.add(updatedBook2); + sourceAndMetadata.put("books", updatedBooks); + + IngestDocument ingestDocument1 = new IngestDocument(sourceAndMetadata, new HashMap<>()); + verify(handler).accept(eq(ingestDocument1), isNull()); + assertEquals(nestedObjectIngestDocument, ingestDocument1); + } + + public void testExecute_localSparseEncodingModelMultipleModelTensors() { + + // Processor configuration + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put("text_docs", "chunks.*.chunk.text.*.context"); + inputMap.add(input); + + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put("chunks.*.chunk.text.*.context_embedding", "$.inference_results.*.output.*.dataAsMap.response"); + outputMap.add(output); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model_1", + inputMap, + outputMap, + null, + true, + "sparse_encoding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + + // Mocking the model output with simple values + List> modelEmbeddings = new ArrayList<>(); + Map embedding = ImmutableMap.of("response", Arrays.asList(1.0, 2.0, 3.0, 4.0)); + for (int i = 1; i <= 4; i++) { + modelEmbeddings.add(embedding); + } + + List modelTensors = new ArrayList<>(); + for (Map embeddings : modelEmbeddings) { + modelTensors.add(ModelTensor.builder().dataAsMap(embeddings).build()); + } + + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput + .builder() + .mlModelOutputs(Collections.singletonList(ModelTensors.builder().mlModelTensors(modelTensors).build())) + .build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + IngestDocument ingestDocument = new IngestDocument(getNestedObjectWithAnotherNestedObjectSource(), new HashMap<>()); + processor.execute(ingestDocument, handler); + verify(handler).accept(eq(ingestDocument), isNull()); + + List> chunks = (List>) ingestDocument.getFieldValue("chunks", List.class); + + List> firstChunkTexts = (List>) ((Map) chunks.get(0).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(0).get("response"), firstChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(1).get("response"), firstChunkTexts.get(1).get("context_embedding")); + + List> secondChunkTexts = (List>) ((Map) chunks.get(1).get("chunk")) + .get("text"); + Assert.assertEquals(modelEmbeddings.get(2).get("response"), secondChunkTexts.get(0).get("context_embedding")); + Assert.assertEquals(modelEmbeddings.get(3).get("response"), secondChunkTexts.get(1).get("context_embedding")); + + } + + public void testExecute_localModelOutputIsNullIgnoreFailureSuccess() { + List> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context"); + + List> outputMap = getOutputMapsForNestedObjectChunks(); + + MLInferenceIngestProcessor processor = createMLInferenceProcessor( + "model1", + inputMap, + outputMap, + null, + true, + "text_embedding", + true, + true, + false, + "{ \"text_docs\": ${ml_inference.text_docs} }" + ); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build(); + doAnswer(invocation -> { + ActionListener 
+    public void testExecute_localModelOutputIsNullIgnoreFailureSuccess() {
+        List<Map<String, String>> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context");
+
+        List<Map<String, String>> outputMap = getOutputMapsForNestedObjectChunks();
+
+        MLInferenceIngestProcessor processor = createMLInferenceProcessor(
+            "model1",
+            inputMap,
+            outputMap,
+            null,
+            true,
+            "text_embedding",
+            true,
+            true,
+            false,
+            "{ \"text_docs\": ${ml_inference.text_docs} }"
+        );
+        ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(null).build();
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
+            actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+            return null;
+        }).when(client).execute(any(), any(), any());
+        Map<String, Object> sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource();
+
+        IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
+
+        processor.execute(nestedObjectIngestDocument, handler);
+        verify(handler).accept(eq(nestedObjectIngestDocument), isNull());
+    }
+
+    public void testExecute_localModelTensorsIsNullIgnoreFailure() {
+        List<Map<String, String>> inputMap = getInputMapsForNestedObjectChunks("chunks.*.chunk.text.*.context");
+
+        List<Map<String, String>> outputMap = getOutputMapsForNestedObjectChunks();
+
+        MLInferenceIngestProcessor processor = createMLInferenceProcessor(
+            "model1",
+            inputMap,
+            outputMap,
+            null,
+            true,
+            "remote",
+            false,
+            true,
+            false,
+            "{ \"text_docs\": ${ml_inference.text_docs} }"
+        );
+        ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(null).build();
+        ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build();
+        doAnswer(invocation -> {
+            ActionListener<MLTaskResponse> actionListener = invocation.getArgument(2);
+            actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build());
+            return null;
+        }).when(client).execute(any(), any(), any());
+        Map<String, Object> sourceAndMetadata = getNestedObjectWithAnotherNestedObjectSource();
+
+        IngestDocument nestedObjectIngestDocument = new IngestDocument(sourceAndMetadata, new HashMap<>());
+
+        processor.execute(nestedObjectIngestDocument, handler);
+        verify(handler).accept(eq(nestedObjectIngestDocument), isNull());
+    }
+
     public void testParseGetDataInTensor_IntegerDataType() {
         ModelTensor mockTensor = mock(ModelTensor.class);
         when(mockTensor.getDataType()).thenReturn(MLResultDataType.INT8);
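Before moving from the unit tests to the integration tests, it helps to state what function_name plus model_input imply inside the processor: for a non-remote function, the filled-in template is parsed back into an MLInput for that algorithm, which is why the local-model tests must supply model_input. The following is a rough sketch under that assumption, not the actual ModelExecutor code; the parser setup and the MLInput.parse call are the parts being assumed here.

    import java.io.IOException;

    import org.opensearch.common.xcontent.LoggingDeprecationHandler;
    import org.opensearch.common.xcontent.XContentType;
    import org.opensearch.core.xcontent.NamedXContentRegistry;
    import org.opensearch.core.xcontent.XContentParser;
    import org.opensearch.ml.common.FunctionName;
    import org.opensearch.ml.common.input.MLInput;

    public class LocalModelInputSketch {
        // payload is the model_input template with its placeholders already filled in
        static MLInput toMlInput(String payload, NamedXContentRegistry xContentRegistry) throws IOException {
            XContentParser parser = XContentType.JSON
                .xContent()
                .createParser(xContentRegistry, LoggingDeprecationHandler.INSTANCE, payload);
            parser.nextToken();
            return MLInput.parse(parser, FunctionName.TEXT_EMBEDDING.name());
        }
    }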
model service\",\n" @@ -350,6 +360,192 @@ public void testMLInferenceProcessorWithForEachProcessor() throws Exception { Assert.assertEquals(1536, embedding2.size()); } + public void testMLInferenceProcessorLocalModelObjectField() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " \"description\": \"test ml model ingest processor\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"diary\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"diary_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"diary_embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"id\": 1,\n" + + " \"diary\": [\"happy\",\"first day at school\"],\n" + + " \"weather\": \"rainy\"\n" + + " }"; + String index_name = "daily_index"; + createPipelineProcessor(createPipelineRequestBody, "diary_embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + List embeddingList = JsonPath.parse(document).read("_source.diary_embedding"); + Assert.assertEquals(2, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.diary_embedding[0]"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.42101282, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.diary_embedding[1]"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49191704, (Double) embedding2.get(0), 0.005); + } + + // TODO: add tests for other local model types such as sparse/cross encoders + public void testMLInferenceProcessorLocalModelNestedField() throws Exception { + + String taskId = registerModel(TestHelper.toJsonString(registerModelInput())); + waitForTask(taskId, MLTaskState.COMPLETED); + getTask(client(), taskId, response -> { + assertNotNull(response.get(MODEL_ID_FIELD)); + this.localModelId = (String) response.get(MODEL_ID_FIELD); + try { + String deployTaskID = deployModel(this.localModelId); + waitForTask(deployTaskID, MLTaskState.COMPLETED); + + getModel(client(), this.localModelId, model -> { assertEquals("DEPLOYED", model.get("model_state")); }); + } catch (IOException | InterruptedException e) { + throw new RuntimeException(e); + } + }); + + String createPipelineRequestBody = "{\n" + + " 
\"description\": \"ingest reviews and generate embedding\",\n" + + " \"processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"function_name\": \"text_embedding\",\n" + + " \"full_response_path\": true,\n" + + " \"model_id\": \"" + + this.localModelId + + "\",\n" + + " \"model_input\": \"{ \\\"text_docs\\\": ${ml_inference.text_docs} }\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"text_docs\": \"book.*.chunk.text.*.context\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"book.*.chunk.text.*.context_embedding\": \"$.inference_results.*.output.*.data\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": true,\n" + + " \"ignore_failure\": true\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String createIndexRequestBody = "{\n" + + " \"settings\": {\n" + + " \"index\": {\n" + + " \"default_pipeline\": \"embedding_pipeline\"\n" + + " }\n" + + " }\n" + + " }"; + String uploadDocumentRequestBody = "{\n" + + " \"book\": [\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the first part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"first chapter\",\n" + + " \"context\": \"this is the second part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " },\n" + + " {\n" + + " \"chunk\": {\n" + + " \"text\": [\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the third part\"\n" + + " },\n" + + " {\n" + + " \"chapter\": \"second chapter\",\n" + + " \"context\": \"this is the fourth part\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + String index_name = "book_index"; + createPipelineProcessor(createPipelineRequestBody, "embedding_pipeline"); + createIndex(index_name, createIndexRequestBody); + + uploadDocument(index_name, "1", uploadDocumentRequestBody); + Map document = getDocument(index_name, "1"); + + List embeddingList = JsonPath.parse(document).read("_source.book[*].chunk.text[*].context_embedding"); + Assert.assertEquals(4, embeddingList.size()); + + List embedding1 = JsonPath.parse(document).read("_source.book[0].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding1.size()); + Assert.assertEquals(0.48988956, (Double) embedding1.get(0), 0.005); + + List embedding2 = JsonPath.parse(document).read("_source.book[0].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding2.size()); + Assert.assertEquals(0.49552172, (Double) embedding2.get(0), 0.005); + + List embedding3 = JsonPath.parse(document).read("_source.book[1].chunk.text[0].context_embedding"); + Assert.assertEquals(768, embedding3.size()); + Assert.assertEquals(0.5004309, (Double) embedding3.get(0), 0.005); + + List embedding4 = JsonPath.parse(document).read("_source.book[1].chunk.text[1].context_embedding"); + Assert.assertEquals(768, embedding4.size()); + Assert.assertEquals(0.47907734, (Double) embedding4.get(0), 0.005); + } + protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception { Response pipelineCreateResponse = TestHelper .makeRequest( @@ -378,7 +574,6 @@ protected void createIndex(String indexName, String requestBody) throws Exceptio protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException { Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true"); - request.setJsonEntity(jsonBody); client().performRequest(request); } @@ -390,4 +585,25 @@ protected Map getDocument(final String index, final 
     protected void createPipelineProcessor(String requestBody, final String pipelineName) throws Exception {
         Response pipelineCreateResponse = TestHelper
             .makeRequest(
@@ -378,7 +574,6 @@ protected void createIndex(String indexName, String requestBody) throws Exception {
 
     protected void uploadDocument(final String index, final String docId, final String jsonBody) throws IOException {
         Request request = new Request("PUT", "/" + index + "/_doc/" + docId + "?refresh=true");
-
         request.setJsonEntity(jsonBody);
         client().performRequest(request);
     }
@@ -390,4 +585,25 @@ protected Map<String, Object> getDocument(final String index, final String docId) throws Exception {
         return parseResponseToMap(docResponse);
     }
 
+    protected MLRegisterModelInput registerModelInput() throws IOException, InterruptedException {
+
+        MLModelConfig modelConfig = TextEmbeddingModelConfig
+            .builder()
+            .modelType("bert")
+            .frameworkType(TextEmbeddingModelConfig.FrameworkType.SENTENCE_TRANSFORMERS)
+            .embeddingDimension(768)
+            .build();
+        return MLRegisterModelInput
+            .builder()
+            .modelName("test_model_name")
+            .version("1.0.0")
+            .functionName(FunctionName.TEXT_EMBEDDING)
+            .modelFormat(MLModelFormat.TORCH_SCRIPT)
+            .modelConfig(modelConfig)
+            .url(SENTENCE_TRANSFORMER_MODEL_URL)
+            .deployModel(false)
+            .hashValue("e13b74006290a9d0f58c1376f9629d4ebc05a0f9385f40db837452b167ae9021")
+            .build();
+    }
+
 }
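For completeness, here is the register-and-deploy handshake each local-model test performs with the helper above, condensed from the test bodies in this file (error handling elided; the helper signatures are as used throughout the tests):

    String taskId = registerModel(TestHelper.toJsonString(registerModelInput()));
    waitForTask(taskId, MLTaskState.COMPLETED);
    getTask(client(), taskId, response -> {
        this.localModelId = (String) response.get(MODEL_ID_FIELD);
        // deployModel(localModelId) plus another waitForTask(...) follow,
        // after which the ingest pipeline can reference localModelId
    });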