From 87d1dbf9525119325d7cbf89fed6a28ae02621e9 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Fri, 30 Aug 2024 12:14:42 -0700 Subject: [PATCH 1/3] fix custom prompt issues Signed-off-by: Mingshi Liu --- .../ml/common/connector/HttpConnector.java | 2 + .../ml/common/utils/StringUtils.java | 55 +++++- .../common/connector/HttpConnectorTest.java | 49 ++++- .../ml/common/utils/StringUtilsTest.java | 186 ++++++++++++++++++ .../MLInferenceSearchResponseProcessor.java | 4 +- ...InferenceSearchResponseProcessorTests.java | 100 +++++++++- ...tMLInferenceSearchResponseProcessorIT.java | 136 +++++++++++-- 7 files changed, 504 insertions(+), 28 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 287fbb8127..7bd94662c9 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -10,6 +10,7 @@ import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol; import static org.opensearch.ml.common.utils.StringUtils.getParameterMap; import static org.opensearch.ml.common.utils.StringUtils.isJson; +import static org.opensearch.ml.common.utils.StringUtils.parseParameters; import java.io.IOException; import java.time.Instant; @@ -322,6 +323,7 @@ public T createPayload(String action, Map parameters) { if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) { String payload = connectorAction.get().getRequestBody(); payload = fillNullParameters(parameters, payload); + parseParameters(parameters); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); if (!isJson(payload)) { diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index e71636e01b..86f6f46319 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -19,15 +19,14 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import org.apache.commons.lang3.BooleanUtils; -import org.json.JSONArray; -import org.json.JSONException; -import org.json.JSONObject; - import com.google.gson.Gson; import com.google.gson.JsonElement; import com.google.gson.JsonParser; import com.google.gson.JsonSyntaxException; +import org.apache.commons.lang3.BooleanUtils; +import org.json.JSONArray; +import org.json.JSONException; +import org.json.JSONObject; import lombok.extern.log4j.Log4j2; @@ -50,6 +49,7 @@ public class StringUtils { static { gson = new Gson(); } + public static final String TO_STRING_FUNCTION_NAME = ".toString()"; public static boolean isValidJsonString(String Json) { try { @@ -233,4 +233,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea return errorMessage + " Model ID: " + modelId; } } + + /** + * Collects the prefixes of the toString() method calls present in the values of the given map. + * + * @param map A map containing key-value pairs where the values may contain toString() method calls. + * @return A list of prefixes for the toString() method calls found in the map values. + */ + public static List collectToStringPrefixes(Map map) { + List prefixes = new ArrayList<>(); + for (String key : map.keySet()) { + String value = map.get(key); + if (value != null) { + Pattern pattern = Pattern.compile("\\$\\{(\\w+\\.\\w+)\\.toString\\(\\)\\}"); + Matcher matcher = pattern.matcher(value); + while (matcher.find()) { + String prefix = matcher.group(1); + prefixes.add(prefix.substring(prefix.lastIndexOf('.') + 1)); + } + } + } + return prefixes; + } + + /** + * Parses the given parameters map and processes the values containing toString() method calls. + * + * @param parameters A map containing key-value pairs where the values may contain toString() method calls. + * @return A new map with the processed values for the toString() method calls. + */ + public static Map parseParameters(Map parameters) { + if (parameters != null) { + List toStringParametersPrefixes = collectToStringPrefixes(parameters); + + if (!toStringParametersPrefixes.isEmpty()) { + for (String prefix : toStringParametersPrefixes) { + String value = parameters.get(prefix); + if (value != null) { + parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value)); + } + } + } + } + return parameters; + } + } diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index a84652791f..8aafa4e1fd 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -197,6 +197,53 @@ public void createPayloadWithString() { Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload); } + @Test + public void createPayloadWithInferenceProcessor() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + parameters.put("context", "[\"value 0\",\"value 1\",\"value 2\",\"value 3\",\"value 4\"]"); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + Assert + .assertEquals( + "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", + predictPayload + ); + } + + @Test + public void createPayloadWithInferenceProcessorContextInList() { + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); + Map parameters = new HashMap<>(); + + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + ArrayList NestedListOfDocuments = new ArrayList<>(); + NestedListOfDocuments.add("document2"); + listOfDocuments.add(toJson(NestedListOfDocuments)); + parameters.put("context", toJson(listOfDocuments)); + String predictPayload = connector.createPayload(PREDICT.name(), parameters); + connector.validatePayload(predictPayload); + Assert + .assertEquals( + "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", + predictPayload + ); + } + @Test public void createPayloadWithList() { String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; @@ -216,7 +263,7 @@ public void createPayloadWithNestedList() { String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); + parameters.put("prompt", "please replace \"\n\" with abc: ${parameters.context}"); ArrayList listOfDocuments = new ArrayList<>(); listOfDocuments.add("document1"); ArrayList NestedListOfDocuments = new ArrayList<>(); diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index cf112d6ca3..e1a8fd928f 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -6,7 +6,12 @@ package org.opensearch.ml.common.utils; import static org.junit.Assert.assertEquals; +import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME; +import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes; +import static org.opensearch.ml.common.utils.StringUtils.parseParameters; +import static org.opensearch.ml.common.utils.StringUtils.toJson; +import java.util.ArrayList; import java.util.Arrays; import java.util.HashMap; import java.util.HashSet; @@ -14,6 +19,7 @@ import java.util.Map; import java.util.Set; +import org.apache.commons.text.StringSubstitutor; import org.junit.Assert; import org.junit.Test; @@ -218,4 +224,184 @@ public void testGetErrorMessageWhenHiddenNull() { // Assert assertEquals(expected, result); } + + /** + * Tests the collectToStringPrefixes method with a map containing toString() method calls + * in the values. Verifies that the method correctly extracts the prefixes of the toString() + * method calls. + */ + @Test + public void testGetToStringPrefix() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + parameters.put("context", "${parameters.text.toString()}"); + + List prefixes = collectToStringPrefixes(parameters); + List expectPrefixes = new ArrayList<>(); + expectPrefixes.add("text"); + expectPrefixes.add("context"); + expectPrefixes.add("history"); + assertEquals(prefixes, expectPrefixes); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key. Verifies that the method correctly processes the list and adds + * the processed value to the map with the expected key. Also tests the string substitution + * using the processed values. + */ + @Test + public void testParseParametersListToString() { + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + System.out.println(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + System.out.println(requestBody); + assertEquals(requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\"]\"}"); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key, and the "prompt" value containing escaped characters. Verifies + * that the method correctly processes the list and adds the processed value to the map + * with the expected key. Also tests the string substitution using the processed values. + */ + @Test + public void testParseParametersListToStringWithEscapedPrompt() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + System.out.println(parameters.get("context" + TO_STRING_FUNCTION_NAME)); + System.out.println(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + System.out.println(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a nested list of strings as the + * value for the "context" key. Verifies that the method correctly processes the nested + * list and adds the processed value to the map with the expected key. Also tests the + * string substitution using the processed values. + */ + @Test + public void testParseParametersNestedListToString() { + Map parameters = new HashMap<>(); + parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}"); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + ArrayList NestedListOfDocuments = new ArrayList<>(); + NestedListOfDocuments.add("document2"); + listOfDocuments.add(toJson(NestedListOfDocuments)); + parameters.put("context", toJson(listOfDocuments)); + + parseParameters(parameters); + System.out.println(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + System.out.println(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a map of strings as the value + * for the "context" key. Verifies that the method correctly processes the map and adds + * the processed value to the map with the expected key. Also tests the string substitution + * using the processed values. + */ + @Test + public void testParseParametersMapToString() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + parameters.put("context", toJson(mapOfDocuments)); + parameters.put("history", "hello\n"); + parseParameters(parameters); + System.out.println(parameters); + assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"name\\\":\\\"John\\\"}"); + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + System.out.println(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: {\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a nested map of strings as the + * value for the "context" key. Verifies that the method correctly processes the nested + * map and adds the processed value to the map with the expected key. Also tests the + * string substitution using the processed values. + */ + @Test + public void testParseParametersNestedMapToString() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}" + ); + Map mapOfDocuments = new HashMap<>(); + mapOfDocuments.put("name", "John"); + Map nestedMapOfDocuments = new HashMap<>(); + nestedMapOfDocuments.put("city", "New York"); + mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); + parameters.put("context", toJson(mapOfDocuments)); + parameters.put("history", "hello\n"); + parseParameters(parameters); + System.out.println(parameters); + assertEquals( + parameters.get("context" + TO_STRING_FUNCTION_NAME), + "{\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"}" + ); + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); + System.out.println(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" + ); + } } diff --git a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java index f3da7c77bc..38e62528f3 100644 --- a/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java +++ b/plugin/src/main/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessor.java @@ -384,8 +384,8 @@ private void processPredictions( } } } - - modelParameters = StringUtils.getParameterMap(modelInputParameters); + Map modelParametersInString = StringUtils.getParameterMap(modelInputParameters); + modelParameters.putAll(modelParametersInString); Set inputMapKeys = new HashSet<>(modelParameters.keySet()); inputMapKeys.removeAll(modelConfigs.keySet()); diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 850e466ba6..6efa60b9c3 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -169,10 +169,10 @@ public void onFailure(Exception e) { /** * Tests create processor with one_to_one is true * with custom prompt - * with many to one prediction, 5 documents in hits are calling 1 prediction tasks + * with one to one prediction, 5 documents in hits are calling 5 prediction tasks * @throws Exception if an error occurs during the test */ - public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { + public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { String newDocumentField = "context"; String modelOutputField = "response"; @@ -202,6 +202,102 @@ public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { "{ \"prompt\": \"${model_config.prompt}\"}", client, TEST_XCONTENT_REGISTRY_FOR_QUERY, + true + ); + + SearchRequest request = getSearchRequest(); + String fieldName = "text"; + SearchResponse response = getSearchResponse(5, true, fieldName); + + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there are 1 values")).build(); + ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); + ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); + + doAnswer(invocation -> { + ActionListener actionListener = invocation.getArgument(2); + actionListener.onResponse(MLTaskResponse.builder().output(mlModelTensorOutput).build()); + return null; + }).when(client).execute(any(), any(), any()); + + ActionListener listener = new ActionListener<>() { + @Override + public void onResponse(SearchResponse newSearchResponse) { + assertEquals(newSearchResponse.getHits().getHits().length, 5); + assertEquals( + newSearchResponse.getHits().getHits()[0].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[1].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[2].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[3].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + assertEquals( + newSearchResponse.getHits().getHits()[4].getSourceAsMap().get(newDocumentField).toString(), + "there is 1 value" + ); + } + + @Override + public void onFailure(Exception e) { + throw new RuntimeException(e); + } + + }; + responseProcessor.processResponseAsync(request, response, responseContext, listener); + verify(client, times(5)).execute(any(), any(), any()); + } + + /** + * Tests create processor with one_to_one is false + * with custom prompt + * with many to one prediction, 5 documents in hits are calling 1 prediction tasks + * @throws Exception if an error occurs during the test + */ + public void testProcessResponseManyToOneWithCustomPrompt() throws Exception { + + String documentField = "text"; + String modelInputField = "context"; + List> inputMap = new ArrayList<>(); + Map input = new HashMap<>(); + input.put(modelInputField, documentField); + inputMap.add(input); + + String newDocumentField = "llm_response"; + String modelOutputField = "response"; + List> outputMap = new ArrayList<>(); + Map output = new HashMap<>(); + output.put(newDocumentField, modelOutputField); + outputMap.add(output); + Map modelConfig = new HashMap<>(); + modelConfig + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + MLInferenceSearchResponseProcessor responseProcessor = new MLInferenceSearchResponseProcessor( + "model1", + inputMap, + outputMap, + modelConfig, + DEFAULT_MAX_PREDICTION_TASKS, + PROCESSOR_TAG, + DESCRIPTION, + false, + "remote", + false, + false, + false, + "{ \"parameters\": ${ml_inference.parameters} }", + client, + TEST_XCONTENT_REGISTRY_FOR_QUERY, false ); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java index 64a9306691..9c82547623 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestMLInferenceSearchResponseProcessorIT.java @@ -36,6 +36,7 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC private String openAIChatModelId; private String bedrockEmbeddingModelId; private String localModelId; + private String bedrockClaudeModelId; private final String completionModelConnectorEntity = "{\n" + " \"name\": \"OpenAI text embedding model Connector\",\n" + " \"description\": \"The connector to public OpenAI text embedding model service\",\n" @@ -106,6 +107,47 @@ public class RestMLInferenceSearchResponseProcessorIT extends MLCommonsRestTestC + " ]\n" + "}"; + private final String bedrockClaudeModelConnectorEntity = "{\n" + + " \"name\": \"BedRock Claude instant-v1 Connector\",\n" + + " \"description\": \"The connector to bedrock for claude model\",\n" + + " \"version\": 1,\n" + + " \"protocol\": \"aws_sigv4\",\n" + + " \"parameters\": {\n" + + " \"region\": \"" + + GITHUB_CI_AWS_REGION + + "\",\n" + + " \"service_name\": \"bedrock\",\n" + + " \"anthropic_version\": \"bedrock-2023-05-31\",\n" + + " \"max_tokens_to_sample\": 8000,\n" + + " \"temperature\": 0.0001,\n" + + " \"response_filter\": \"$.completion\",\n" + + " \"stop_sequences\": [\"\\n\\nHuman:\",\"\\nObservation:\",\"\\n\\tObservation:\",\"\\nObservation\",\"\\n\\tObservation\",\"\\n\\nQuestion\"]\n" + + " },\n" + + " \"credential\": {\n" + + " \"access_key\": \"" + + AWS_ACCESS_KEY_ID + + "\",\n" + + " \"secret_key\": \"" + + AWS_SECRET_ACCESS_KEY + + "\",\n" + + " \"session_token\": \"" + + AWS_SESSION_TOKEN + + "\"\n" + + " },\n" + + " \"actions\": [\n" + + " {\n" + + " \"action_type\": \"predict\",\n" + + " \"method\": \"POST\",\n" + + " \"url\": \"https://bedrock-runtime.${parameters.region}.amazonaws.com/model/anthropic.claude-instant-v1/invoke\",\n" + + " \"headers\": {\n" + + " \"content-type\": \"application/json\",\n" + + " \"x-amz-content-sha256\": \"required\"\n" + + " },\n" + + " \"request_body\": \"{\\\"prompt\\\":\\\"${parameters.prompt}\\\", \\\"stop_sequences\\\": ${parameters.stop_sequences}, \\\"max_tokens_to_sample\\\":${parameters.max_tokens_to_sample}, \\\"temperature\\\":${parameters.temperature}, \\\"anthropic_version\\\":\\\"${parameters.anthropic_version}\\\" }\"\n" + + " }\n" + + " ]\n" + + "}"; + /** * Registers two remote models and creates an index and documents before running the tests. * @@ -119,7 +161,8 @@ public void setup() throws Exception { this.openAIChatModelId = registerRemoteModel(completionModelConnectorEntity, openAIChatModelName, true); String bedrockEmbeddingModelName = "bedrock embedding model " + randomAlphaOfLength(5); this.bedrockEmbeddingModelId = registerRemoteModel(bedrockEmbeddingModelConnectorEntity, bedrockEmbeddingModelName, true); - + String bedrockClaudeModelName = "bedrock claude model " + randomAlphaOfLength(5); + this.bedrockClaudeModelId = registerRemoteModel(bedrockClaudeModelConnectorEntity, bedrockClaudeModelName, true); String index_name = "daily_index"; String createIndexRequestBody = "{\n" + " \"mappings\": {\n" @@ -152,13 +195,14 @@ public void setup() throws Exception { /** * Tests the MLInferenceSearchResponseProcessor with a remote model and an object field as input. * It creates a search pipeline with the processor configured to use the remote model, - * performs a search using the pipeline, and verifies the inference results. - * + * performs a search using the pipeline, gathering search documents into context and added in a custom prompt + * Using a toString() in placeholder to specify the context needs to cast as string + * and verifies the inference results. * @throws Exception if any error occurs during the test */ - public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + public void testMLInferenceProcessorRemoteModelCustomPrompt() throws Exception { // Skip test if key is null - if (OPENAI_KEY == null) { + if (AWS_ACCESS_KEY_ID == null) { return; } String createPipelineRequestBody = "{\n" @@ -168,20 +212,26 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + " \"tag\": \"ml_inference\",\n" + " \"description\": \"This processor is going to run ml inference during search request\",\n" + " \"model_id\": \"" - + this.openAIChatModelId + + this.bedrockClaudeModelId + "\",\n" + + " \"function_name\": \"REMOTE\",\n" + " \"input_map\": [\n" + " {\n" - + " \"input\": \"weather\"\n" + + " \"context\": \"weather\"\n" + " }\n" + " ],\n" + " \"output_map\": [\n" + " {\n" - + " \"weather_embedding\": \"data[*].embedding\"\n" + + " \"llm_response\":\"$.response\"\n" + + " \n" + " }\n" + " ],\n" - + " \"ignore_missing\": false,\n" + + " \"model_config\": {\n" + + " \"prompt\":\"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"\n" + + " },\n" + + " \"ignore_missing\":false,\n" + " \"ignore_failure\": false\n" + + " \n" + " }\n" + " }\n" + " ]\n" @@ -190,18 +240,13 @@ public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; String index_name = "daily_index"; - String pipelineName = "weather_embedding_pipeline"; + String pipelineName = "qa_pipeline"; createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); Map response = searchWithPipeline(client(), index_name, pipelineName, query); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); - Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); - List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); - Assert.assertEquals(embeddingList.size(), 1536); - Assert.assertEquals((Double) embeddingList.get(0), 0.00020525085, 0.005); - Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); + System.out.println(response); + Assert.assertNotNull(JsonPath.parse(response).read("$.hits.hits[0]._source.llm_response")); + Assert.assertNotNull(JsonPath.parse(response).read("$.hits.hits[1]._source.llm_response")); } /** @@ -312,6 +357,61 @@ public void testMLInferenceProcessorRemoteModelNestedListField() throws Exceptio Assert.assertEquals((Double) embeddingList.get(1), -0.012508746, 0.005); } + /** + * Tests the MLInferenceSearchResponseProcessor with a remote model and an object field as input. + * It creates a search pipeline with the processor configured to use the remote model, + * performs a search using the pipeline, and verifies the inference results. + * + * @throws Exception if any error occurs during the test + */ + public void testMLInferenceProcessorRemoteModelObjectField() throws Exception { + // Skip test if key is null + if (OPENAI_KEY == null) { + return; + } + String createPipelineRequestBody = "{\n" + + " \"response_processors\": [\n" + + " {\n" + + " \"ml_inference\": {\n" + + " \"tag\": \"ml_inference\",\n" + + " \"description\": \"This processor is going to run ml inference during search request\",\n" + + " \"model_id\": \"" + + this.openAIChatModelId + + "\",\n" + + " \"input_map\": [\n" + + " {\n" + + " \"input\": \"weather\"\n" + + " }\n" + + " ],\n" + + " \"output_map\": [\n" + + " {\n" + + " \"weather_embedding\": \"data[*].embedding\"\n" + + " }\n" + + " ],\n" + + " \"ignore_missing\": false,\n" + + " \"ignore_failure\": false\n" + + " }\n" + + " }\n" + + " ]\n" + + "}"; + + String query = "{\"query\":{\"term\":{\"weather\":{\"value\":\"sunny\"}}}}"; + + String index_name = "daily_index"; + String pipelineName = "weather_embedding_pipeline"; + createSearchPipelineProcessor(createPipelineRequestBody, pipelineName); + + Map response = searchWithPipeline(client(), index_name, pipelineName, query); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary_embedding_size"), "1536"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.weather"), "sunny"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[0]"), "happy"); + Assert.assertEquals(JsonPath.parse(response).read("$.hits.hits[0]._source.diary[1]"), "first day at school"); + List embeddingList = (List) JsonPath.parse(response).read("$.hits.hits[0]._source.weather_embedding"); + Assert.assertEquals(embeddingList.size(), 1536); + Assert.assertEquals((Double) embeddingList.get(0), 0.00020525085, 0.005); + Assert.assertEquals((Double) embeddingList.get(1), -0.0071890163, 0.005); + } + /** * Tests the ML inference processor with a local model. * It registers, deploys, and gets a local model, creates a search pipeline with the ML inference processor From 120ad6acf00817cca4ae9215f5b8c6c8c4c63156 Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Fri, 30 Aug 2024 22:15:35 -0700 Subject: [PATCH 2/3] match any placeholder starts from parameters and end with toString() Signed-off-by: Mingshi Liu --- .../ml/common/utils/StringUtils.java | 13 +++--- .../ml/common/utils/StringUtilsTest.java | 41 ++++++++++++++----- ...InferenceSearchResponseProcessorTests.java | 2 +- 3 files changed, 38 insertions(+), 18 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java index 86f6f46319..57c24c22fd 100644 --- a/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java +++ b/common/src/main/java/org/opensearch/ml/common/utils/StringUtils.java @@ -19,15 +19,16 @@ import java.util.regex.Matcher; import java.util.regex.Pattern; -import com.google.gson.Gson; -import com.google.gson.JsonElement; -import com.google.gson.JsonParser; -import com.google.gson.JsonSyntaxException; import org.apache.commons.lang3.BooleanUtils; import org.json.JSONArray; import org.json.JSONException; import org.json.JSONObject; +import com.google.gson.Gson; +import com.google.gson.JsonElement; +import com.google.gson.JsonParser; +import com.google.gson.JsonSyntaxException; + import lombok.extern.log4j.Log4j2; @Log4j2 @@ -245,11 +246,11 @@ public static List collectToStringPrefixes(Map map) { for (String key : map.keySet()) { String value = map.get(key); if (value != null) { - Pattern pattern = Pattern.compile("\\$\\{(\\w+\\.\\w+)\\.toString\\(\\)\\}"); + Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}"); Matcher matcher = pattern.matcher(value); while (matcher.find()) { String prefix = matcher.group(1); - prefixes.add(prefix.substring(prefix.lastIndexOf('.') + 1)); + prefixes.add(prefix); } } } diff --git a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java index e1a8fd928f..a4b1460f39 100644 --- a/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java +++ b/common/src/test/java/org/opensearch/ml/common/utils/StringUtilsTest.java @@ -263,13 +263,11 @@ public void testParseParametersListToString() { parameters.put("context", toJson(listOfDocuments)); parseParameters(parameters); - System.out.println(parameters); assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); requestBody = substitutor.replace(requestBody); - System.out.println(requestBody); assertEquals(requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\"]\"}"); } @@ -292,14 +290,41 @@ public void testParseParametersListToStringWithEscapedPrompt() { parameters.put("context", toJson(listOfDocuments)); parseParameters(parameters); - System.out.println(parameters.get("context" + TO_STRING_FUNCTION_NAME)); - System.out.println(parameters); assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); requestBody = substitutor.replace(requestBody); - System.out.println(requestBody); + assertEquals( + requestBody, + "{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}" + ); + } + + /** + * Tests the parseParameters method with a map containing a list of strings as the value + * for the "context" key, and the "prompt" value containing escaped characters. Verifies + * that the method correctly processes the list and adds the processed value to the map + * with the expected key. Also tests the string substitution using the processed values. + */ + @Test + public void testParseParametersListToStringModelConfig() { + Map parameters = new HashMap<>(); + parameters + .put( + "prompt", + "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.model_config.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" + ); + ArrayList listOfDocuments = new ArrayList<>(); + listOfDocuments.add("document1"); + parameters.put("model_config.context", toJson(listOfDocuments)); + + parseParameters(parameters); + assertEquals(parameters.get("model_config.context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]"); + + String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; + StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); + requestBody = substitutor.replace(requestBody); assertEquals( requestBody, "{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}" @@ -324,13 +349,11 @@ public void testParseParametersNestedListToString() { parameters.put("context", toJson(listOfDocuments)); parseParameters(parameters); - System.out.println(parameters); assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]"); String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); requestBody = substitutor.replace(requestBody); - System.out.println(requestBody); assertEquals( requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]\"}" @@ -356,12 +379,10 @@ public void testParseParametersMapToString() { parameters.put("context", toJson(mapOfDocuments)); parameters.put("history", "hello\n"); parseParameters(parameters); - System.out.println(parameters); assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"name\\\":\\\"John\\\"}"); String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); requestBody = substitutor.replace(requestBody); - System.out.println(requestBody); assertEquals( requestBody, "{\"prompt\": \"answer question based on context: {\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" @@ -390,7 +411,6 @@ public void testParseParametersNestedMapToString() { parameters.put("context", toJson(mapOfDocuments)); parameters.put("history", "hello\n"); parseParameters(parameters); - System.out.println(parameters); assertEquals( parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"}" @@ -398,7 +418,6 @@ public void testParseParametersNestedMapToString() { String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); requestBody = substitutor.replace(requestBody); - System.out.println(requestBody); assertEquals( requestBody, "{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}" diff --git a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java index 6efa60b9c3..62b397f84b 100644 --- a/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java +++ b/plugin/src/test/java/org/opensearch/ml/processor/MLInferenceSearchResponseProcessorTests.java @@ -209,7 +209,7 @@ public void testProcessResponseOneToOneWithCustomPrompt() throws Exception { String fieldName = "text"; SearchResponse response = getSearchResponse(5, true, fieldName); - ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there are 1 values")).build(); + ModelTensor modelTensor = ModelTensor.builder().dataAsMap(ImmutableMap.of("response", "there is 1 value")).build(); ModelTensors modelTensors = ModelTensors.builder().mlModelTensors(Arrays.asList(modelTensor)).build(); ModelTensorOutput mlModelTensorOutput = ModelTensorOutput.builder().mlModelOutputs(Arrays.asList(modelTensors)).build(); From 247b955b5a014adf36dafb713cf6be15f8d53fcb Mon Sep 17 00:00:00 2001 From: Mingshi Liu Date: Sat, 31 Aug 2024 15:02:40 -0700 Subject: [PATCH 3/3] Revert "Support list in response body (#2811)" This reverts commit f64e3f3f Signed-off-by: Mingshi Liu --- .../ml/common/connector/HttpConnector.java | 29 +--- .../common/connector/HttpConnectorTest.java | 156 ------------------ 2 files changed, 2 insertions(+), 183 deletions(-) diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 7bd94662c9..edf26b954d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -326,38 +326,13 @@ public T createPayload(String action, Map parameters) { parseParameters(parameters); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); + if (!isJson(payload)) { - String payloadAfterEscape = connectorAction.get().getRequestBody(); - Map escapedParameters = escapeMapValues(parameters); - StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}"); - payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape); - if (!isJson(payloadAfterEscape)) { - throw new IllegalArgumentException("Invalid payload: " + payload); - } else { - payload = payloadAfterEscape; - } + throw new IllegalArgumentException("Invalid payload: " + payload); } return (T) payload; } return (T) parameters.get("http_body"); - - } - - public static Map escapeMapValues(Map parameters) { - Map escapedMap = new HashMap<>(); - if (parameters != null) { - for (Map.Entry entry : parameters.entrySet()) { - String key = entry.getKey(); - String value = entry.getValue(); - String escapedValue = escapeValue(value); - escapedMap.put(key, escapedValue); - } - } - return escapedMap; - } - - private static String escapeValue(String value) { - return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t"); } protected String fillNullParameters(Map parameters, String payload) { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 8aafa4e1fd..0115ac1376 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -6,7 +6,6 @@ package org.opensearch.ml.common.connector; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.utils.StringUtils.toJson; import java.io.IOException; import java.util.ArrayList; @@ -184,161 +183,6 @@ public void createPayload_InvalidJson() { connector.validatePayload(predictPayload); } - @Test - public void createPayloadWithString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - parameters.put("context", "document1"); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload); - } - - @Test - public void createPayloadWithInferenceProcessor() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters - .put( - "prompt", - "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" - ); - parameters.put("context", "[\"value 0\",\"value 1\",\"value 2\",\"value 3\",\"value 4\"]"); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert - .assertEquals( - "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", - predictPayload - ); - } - - @Test - public void createPayloadWithInferenceProcessorContextInList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters - .put( - "prompt", - "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" - ); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - NestedListOfDocuments.add("document2"); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert - .assertEquals( - "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", - predictPayload - ); - } - - @Test - public void createPayloadWithList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - listOfDocuments.add("document2"); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "please replace \"\n\" with abc: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - NestedListOfDocuments.add("document2"); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithMap() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedListOfMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - @Test public void createPayload() { HttpConnector connector = createHttpConnector();