Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support list in response body #2811

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -324,13 +324,38 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
payload = fillNullParameters(parameters, payload);
mingshl marked this conversation as resolved.
Show resolved Hide resolved
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
String payloadAfterEscape = connectorAction.get().getRequestBody();
Map<String, String> escapedParameters = escapeMapValues(parameters);
Copy link
Collaborator

@ylwu-amzn ylwu-amzn Aug 17, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If payload is wrong will escape all parameters again. Escape good parameters again may generate wrong result. For example, the original prompt is "please replace all \n of the context with --", if we escape again , it becomes "please replace all \\n of the context with --"

Copy link
Collaborator Author

@mingshl mingshl Aug 19, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

if a string is already partial escaped, meaning the string escaped for the next line symbol but doesn't escape for double quote, this change will still not fix the payload. Please let me know if there is better way to identify partial escape cases.

StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
if (!isJson(payloadAfterEscape)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else {
payload = payloadAfterEscape;
}
}
return (T) payload;
}
return (T) parameters.get("http_body");

}

public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
Map<String, String> escapedMap = new HashMap<>();
if (parameters != null) {
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeValue(value);
escapedMap.put(key, escapedValue);
}
}
return escapedMap;
}

private static String escapeValue(String value) {
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not use existing escape method? This escape method will replace "\n" in "This is a \new book" , which is not necessary. It may confuse user why same sentence will generate different embeddings after this PR

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

it won't be a break change, because before this change, if it's an invalid payload, it will throw exception directly, it won't generate embedding at all.

}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -183,6 +184,114 @@ public void createPayload_InvalidJson() {
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters.put("prompt", "answer question based on context: ${parameters.context}");
parameters.put("context", "document1");
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
}

@Test
public void createPayloadWithList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
listOfDocuments.add("document2");
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithMap() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedListOfMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayload() {
HttpConnector connector = createHttpConnector();
Expand Down
Loading