diff --git a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java index 7bd94662c9..edf26b954d 100644 --- a/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java +++ b/common/src/main/java/org/opensearch/ml/common/connector/HttpConnector.java @@ -326,38 +326,13 @@ public T createPayload(String action, Map parameters) { parseParameters(parameters); StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}"); payload = substitutor.replace(payload); + if (!isJson(payload)) { - String payloadAfterEscape = connectorAction.get().getRequestBody(); - Map escapedParameters = escapeMapValues(parameters); - StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}"); - payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape); - if (!isJson(payloadAfterEscape)) { - throw new IllegalArgumentException("Invalid payload: " + payload); - } else { - payload = payloadAfterEscape; - } + throw new IllegalArgumentException("Invalid payload: " + payload); } return (T) payload; } return (T) parameters.get("http_body"); - - } - - public static Map escapeMapValues(Map parameters) { - Map escapedMap = new HashMap<>(); - if (parameters != null) { - for (Map.Entry entry : parameters.entrySet()) { - String key = entry.getKey(); - String value = entry.getValue(); - String escapedValue = escapeValue(value); - escapedMap.put(key, escapedValue); - } - } - return escapedMap; - } - - private static String escapeValue(String value) { - return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t"); } protected String fillNullParameters(Map parameters, String payload) { diff --git a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java index 8aafa4e1fd..0115ac1376 100644 --- a/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java +++ b/common/src/test/java/org/opensearch/ml/common/connector/HttpConnectorTest.java @@ -6,7 +6,6 @@ package org.opensearch.ml.common.connector; import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT; -import static org.opensearch.ml.common.utils.StringUtils.toJson; import java.io.IOException; import java.util.ArrayList; @@ -184,161 +183,6 @@ public void createPayload_InvalidJson() { connector.validatePayload(predictPayload); } - @Test - public void createPayloadWithString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - parameters.put("context", "document1"); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload); - } - - @Test - public void createPayloadWithInferenceProcessor() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters - .put( - "prompt", - "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" - ); - parameters.put("context", "[\"value 0\",\"value 1\",\"value 2\",\"value 3\",\"value 4\"]"); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert - .assertEquals( - "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", - predictPayload - ); - } - - @Test - public void createPayloadWithInferenceProcessorContextInList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - - parameters - .put( - "prompt", - "\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:" - ); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - NestedListOfDocuments.add("document2"); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - Assert - .assertEquals( - "{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}", - predictPayload - ); - } - - @Test - public void createPayloadWithList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - listOfDocuments.add("document2"); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedList() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "please replace \"\n\" with abc: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - NestedListOfDocuments.add("document2"); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithMap() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfString() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments)); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - parameters.put("context", toJson(mapOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - - @Test - public void createPayloadWithNestedListOfMapOfObject() { - String requestBody = "{\"prompt\": \"${parameters.prompt}\"}"; - HttpConnector connector = createHttpConnectorWithRequestBody(requestBody); - Map parameters = new HashMap<>(); - parameters.put("prompt", "answer question based on context: ${parameters.context}"); - ArrayList listOfDocuments = new ArrayList<>(); - listOfDocuments.add("document1"); - ArrayList NestedListOfDocuments = new ArrayList<>(); - Map mapOfDocuments = new HashMap<>(); - mapOfDocuments.put("name", "John"); - Map nestedMapOfDocuments = new HashMap<>(); - nestedMapOfDocuments.put("city", "New York"); - mapOfDocuments.put("hometown", nestedMapOfDocuments); - listOfDocuments.add(toJson(NestedListOfDocuments)); - parameters.put("context", toJson(listOfDocuments)); - String predictPayload = connector.createPayload(PREDICT.name(), parameters); - connector.validatePayload(predictPayload); - } - @Test public void createPayload() { HttpConnector connector = createHttpConnector();