Skip to content

Commit

Permalink
try add escape when payload is not json
Browse files Browse the repository at this point in the history
Signed-off-by: Mingshi Liu <[email protected]>

try help escape when the payload is not valid

Signed-off-by: Mingshi Liu <[email protected]>
  • Loading branch information
mingshl committed Aug 14, 2024
1 parent b77744e commit 83572b5
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 23 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -322,41 +322,43 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
parameters = formatArrayParameters(parameters);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
boolean isJson = isJson(payload);
if (!isJson) {
String manuallyFixedJson = connectorAction.get().getRequestBody();
Map<String, String> escapedParameters = escapeMapValues(parameters);
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
manuallyFixedJson = escapedSubstitutor.replace(manuallyFixedJson);
if (!isJson(manuallyFixedJson)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else {
payload = manuallyFixedJson;
}
}
return (T) payload;
}
return (T) parameters.get("http_body");
}

private Map<String,String> formatArrayParameters(Map<String, String> parameters) {
Map<String,String> newParameters = new HashMap<>();
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeJsonArrayIfNeeded(value);
newParameters.put(key,escapedValue);
}
return newParameters;
}
protected String escapeJsonArrayIfNeeded(String value) {
if (isJsonArray(value)) {
return value.replaceAll("([^\\\\])\"", "$1\\\\\"");

public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
Map<String, String> escapedMap = new HashMap<>();
if (parameters != null) {
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeValue(value);
escapedMap.put(key, escapedValue);
}
}
return value;
return escapedMap;
}

protected boolean isJsonArray(String value) {
Pattern jsonArrayPattern = Pattern.compile("^\\[.*\\]$");
return jsonArrayPattern.matcher(value).matches();
private static String escapeValue(String value) {
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
}


protected String fillNullParameters(Map<String, String> parameters, String payload) {
List<String> bodyParams = findStringParametersWithNullDefaultValue(payload);
String newPayload = payload;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -202,7 +203,7 @@ public void createPayloadWithList() {
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments= new ArrayList<>();
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
listOfDocuments.add("document2");
parameters.put("context", toJson(listOfDocuments));
Expand Down

0 comments on commit 83572b5

Please sign in to comment.