Skip to content

Commit

Permalink
Fix custom prompt substitute with List issue in ml inference search r…
Browse files Browse the repository at this point in the history
…esponse processor (#2871)
  • Loading branch information
mingshl authored Sep 2, 2024
1 parent 88fd3e7 commit 49d4a01
Show file tree
Hide file tree
Showing 7 changed files with 473 additions and 158 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -322,40 +323,16 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
parseParameters(parameters);
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);

if (!isJson(payload)) {
String payloadAfterEscape = connectorAction.get().getRequestBody();
Map<String, String> escapedParameters = escapeMapValues(parameters);
StringSubstitutor escapedSubstitutor = new StringSubstitutor(escapedParameters, "${parameters.", "}");
payloadAfterEscape = escapedSubstitutor.replace(payloadAfterEscape);
if (!isJson(payloadAfterEscape)) {
throw new IllegalArgumentException("Invalid payload: " + payload);
} else {
payload = payloadAfterEscape;
}
throw new IllegalArgumentException("Invalid payload: " + payload);
}
return (T) payload;
}
return (T) parameters.get("http_body");

}

public static Map<String, String> escapeMapValues(Map<String, String> parameters) {
Map<String, String> escapedMap = new HashMap<>();
if (parameters != null) {
for (Map.Entry<String, String> entry : parameters.entrySet()) {
String key = entry.getKey();
String value = entry.getValue();
String escapedValue = escapeValue(value);
escapedMap.put(key, escapedValue);
}
}
return escapedMap;
}

private static String escapeValue(String value) {
return value.replace("\\", "\\\\").replace("\"", "\\\"").replace("\n", "\\n").replace("\r", "\\r").replace("\t", "\\t");
}

protected String fillNullParameters(Map<String, String> parameters, String payload) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -233,4 +234,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{parameters\\.(.+?)\\.toString\\(\\)\\}");
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix);
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@
package org.opensearch.ml.common.connector;

import static org.opensearch.ml.common.connector.ConnectorAction.ActionType.PREDICT;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.io.IOException;
import java.util.ArrayList;
Expand Down Expand Up @@ -184,114 +183,6 @@ public void createPayload_InvalidJson() {
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters.put("prompt", "answer question based on context: ${parameters.context}");
parameters.put("context", "document1");
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
}

@Test
public void createPayloadWithList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
listOfDocuments.add("document2");
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithMap() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfString() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
parameters.put("context", toJson(mapOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayloadWithNestedListOfMapOfObject() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<Object> NestedListOfDocuments = new ArrayList<>();
Map<String, Object> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", nestedMapOfDocuments);
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
}

@Test
public void createPayload() {
HttpConnector connector = createHttpConnector();
Expand Down
Loading

0 comments on commit 49d4a01

Please sign in to comment.