Skip to content

Commit

Permalink
Support list in response body (#2811)
Browse files Browse the repository at this point in the history
(cherry picked from commit f64e3f3)
  • Loading branch information
mingshl authored and github-actions[bot] committed Aug 15, 2024
1 parent 8d998ab commit 9758291
Show file tree
Hide file tree
Showing 2 changed files with 136 additions and 2 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,38 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
payload = fillNullParameters(parameters, payload);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
String payloadAfterEscape = connectorAction.get().getRequestBody();
Map<String, String> escapedParameters = escapeMapValues(parameters);
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
if (!isJson(payloadAfterEscape)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else {
payload = payloadAfterEscape;
}
}
return (T) payload;
}
return (T) parameters.get("http_body");

}

public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
Map<String, String> escapedMap = new HashMap<>();
if (parameters != null) {
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeValue(value);
escapedMap.put(key, escapedValue);
}
}
return escapedMap;
}

private static String escapeValue(String value) {
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -183,6 +184,114 @@ public void createPayload_InvalidJson() {
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters.put("prompt", "answer question based on context: ${parameters.context}");
parameters.put("context", "document1");
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
}

@Test
public void createPayloadWithList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
listOfDocuments.add("document2");
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithMap() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedListOfMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayload() {
HttpConnector connector = createHttpConnector();
Expand Down

0 comments on commit 9758291

Please sign in to comment.