Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix custom prompt substitute with List issue #2871

Merged
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import static org.opensearch.ml.common.connector.ConnectorProtocols.validateProtocol;
import static org.opensearch.ml.common.utils.StringUtils.getParameterMap;
import static org.opensearch.ml.common.utils.StringUtils.isJson;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;

import java.io.IOException;
import java.time.Instant;
Expand Down Expand Up @@ -322,6 +323,7 @@ public <T> T createPayload(String action, Map<String, String> parameters) {
if (connectorAction.isPresent() && connectorAction.get().getRequestBody() != null) {
String payload = connectorAction.get().getRequestBody();
payload = fillNullParameters(parameters, payload);
parseParameters(parameters);
mingshl marked this conversation as resolved.
Show resolved Hide resolved
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
payload = substitutor.replace(payload);
if (!isJson(payload)) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
import java.util.regex.Matcher;
import java.util.regex.Pattern;

import org.apache.commons.lang3.BooleanUtils;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import com.google.gson.Gson;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import com.google.gson.JsonSyntaxException;
import org.apache.commons.lang3.BooleanUtils;
import org.json.JSONArray;
import org.json.JSONException;
import org.json.JSONObject;

import lombok.extern.log4j.Log4j2;

Expand All @@ -50,6 +49,7 @@ public class StringUtils {
static {
gson = new Gson();
}
public static final String TO_STRING_FUNCTION_NAME = ".toString()";

public static boolean isValidJsonString(String Json) {
try {
Expand Down Expand Up @@ -233,4 +233,49 @@ public static String getErrorMessage(String errorMessage, String modelId, Boolea
return errorMessage + " Model ID: " + modelId;
}
}

/**
* Collects the prefixes of the toString() method calls present in the values of the given map.
*
* @param map A map containing key-value pairs where the values may contain toString() method calls.
* @return A list of prefixes for the toString() method calls found in the map values.
*/
public static List<String> collectToStringPrefixes(Map<String, String> map) {
List<String> prefixes = new ArrayList<>();
for (String key : map.keySet()) {
String value = map.get(key);
if (value != null) {
Pattern pattern = Pattern.compile("\\$\\{(\\w+\\.\\w+)\\.toString\\(\\)\\}");
mingshl marked this conversation as resolved.
Show resolved Hide resolved
Matcher matcher = pattern.matcher(value);
while (matcher.find()) {
String prefix = matcher.group(1);
prefixes.add(prefix.substring(prefix.lastIndexOf('.') + 1));
}
}
}
return prefixes;
}

/**
* Parses the given parameters map and processes the values containing toString() method calls.
*
* @param parameters A map containing key-value pairs where the values may contain toString() method calls.
* @return A new map with the processed values for the toString() method calls.
*/
public static Map<String, String> parseParameters(Map<String, String> parameters) {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Add prefix to method parameter list ? For connector, it will be parameters. For other usecase, it maybe different, like local model could be model_config

Suggested change
public static Map<String, String> parseParameters(Map<String, String> parameters) {
public static Map<String, String> parseParameters(String prefix, Map<String, String> parameters) {

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

all parameters will be prefix with parameters in here

StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");

I added the test to test when it's parameters.model_config.context, then it would be taking the prefix as model_config.context properly, please see this UT

public void testParseParametersListToStringModelConfig() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.model_config.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("model_config.context", toJson(listOfDocuments));
parseParameters(parameters);
System.out.println(parameters.get("context" + TO_STRING_FUNCTION_NAME));
System.out.println(parameters);
assertEquals(parameters.get("model_config.context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}"
);
}

if (parameters != null) {
List<String> toStringParametersPrefixes = collectToStringPrefixes(parameters);

if (!toStringParametersPrefixes.isEmpty()) {
for (String prefix : toStringParametersPrefixes) {
String value = parameters.get(prefix);
if (value != null) {
parameters.put(prefix + TO_STRING_FUNCTION_NAME, processTextDoc(value));
}
}
}
}
return parameters;
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,53 @@ public void createPayloadWithString() {
Assert.assertEquals("{\"prompt\": \"answer question based on context: document1\"}", predictPayload);
}

@Test
public void createPayloadWithInferenceProcessor() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
parameters.put("context", "[\"value 0\",\"value 1\",\"value 2\",\"value 3\",\"value 4\"]");
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert
.assertEquals(
"{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"value 0\\\",\\\"value 1\\\",\\\"value 2\\\",\\\"value 3\\\",\\\"value 4\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}",
predictPayload
);
}

@Test
public void createPayloadWithInferenceProcessorContextInList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();

parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));
String predictPayload = connector.createPayload(PREDICT.name(), parameters);
connector.validatePayload(predictPayload);
Assert
.assertEquals(
"{\"prompt\": \"\\\\n\\\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]. \\\\n\\\\n Human: please summarize the documents \\\\n\\\\n Assistant:\"}",
predictPayload
);
}

@Test
public void createPayloadWithList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
Expand All @@ -216,7 +263,7 @@ public void createPayloadWithNestedList() {
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
HttpConnector connector = createHttpConnectorWithRequestBody(requestBody);
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context}");
parameters.put("prompt", "please replace \"\n\" with abc: ${parameters.context}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,20 @@
package org.opensearch.ml.common.utils;

import static org.junit.Assert.assertEquals;
import static org.opensearch.ml.common.utils.StringUtils.TO_STRING_FUNCTION_NAME;
import static org.opensearch.ml.common.utils.StringUtils.collectToStringPrefixes;
import static org.opensearch.ml.common.utils.StringUtils.parseParameters;
import static org.opensearch.ml.common.utils.StringUtils.toJson;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.apache.commons.text.StringSubstitutor;
import org.junit.Assert;
import org.junit.Test;

Expand Down Expand Up @@ -218,4 +224,184 @@ public void testGetErrorMessageWhenHiddenNull() {
// Assert
assertEquals(expected, result);
}

/**
* Tests the collectToStringPrefixes method with a map containing toString() method calls
* in the values. Verifies that the method correctly extracts the prefixes of the toString()
* method calls.
*/
@Test
public void testGetToStringPrefix() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
parameters.put("context", "${parameters.text.toString()}");

List<String> prefixes = collectToStringPrefixes(parameters);
List<String> expectPrefixes = new ArrayList<>();
expectPrefixes.add("text");
expectPrefixes.add("context");
expectPrefixes.add("history");
assertEquals(prefixes, expectPrefixes);
}

/**
* Tests the parseParameters method with a map containing a list of strings as the value
* for the "context" key. Verifies that the method correctly processes the list and adds
* the processed value to the map with the expected key. Also tests the string substitution
* using the processed values.
*/
@Test
public void testParseParametersListToString() {
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
System.out.println(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(requestBody, "{\"prompt\": \"answer question based on context: [\\\"document1\\\"]\"}");
}

/**
* Tests the parseParameters method with a map containing a list of strings as the value
* for the "context" key, and the "prompt" value containing escaped characters. Verifies
* that the method correctly processes the list and adds the processed value to the map
* with the expected key. Also tests the string substitution using the processed values.
*/
@Test
public void testParseParametersListToStringWithEscapedPrompt() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: ${parameters.context.toString()}. \\n\\n Human: please summarize the documents \\n\\n Assistant:"
);
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
System.out.println(parameters.get("context" + TO_STRING_FUNCTION_NAME));
mingshl marked this conversation as resolved.
Show resolved Hide resolved
System.out.println(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"\\n\\nHuman: You are a professional data analyst. You will always answer question based on the given context first. If the answer is not directly shown in the context, you will analyze the data and find the answer. If you don't know the answer, just say I don't know. Context: [\\\"document1\\\"]. \\n\\n Human: please summarize the documents \\n\\n Assistant:\"}"
);
}

/**
* Tests the parseParameters method with a map containing a nested list of strings as the
* value for the "context" key. Verifies that the method correctly processes the nested
* list and adds the processed value to the map with the expected key. Also tests the
* string substitution using the processed values.
*/
@Test
public void testParseParametersNestedListToString() {
Map<String, String> parameters = new HashMap<>();
parameters.put("prompt", "answer question based on context: ${parameters.context.toString()}");
ArrayList<String> listOfDocuments = new ArrayList<>();
listOfDocuments.add("document1");
ArrayList<String> NestedListOfDocuments = new ArrayList<>();
NestedListOfDocuments.add("document2");
listOfDocuments.add(toJson(NestedListOfDocuments));
parameters.put("context", toJson(listOfDocuments));

parseParameters(parameters);
System.out.println(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "[\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]");

String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: [\\\"document1\\\",\\\"[\\\\\\\"document2\\\\\\\"]\\\"]\"}"
);
}

/**
* Tests the parseParameters method with a map containing a map of strings as the value
* for the "context" key. Verifies that the method correctly processes the map and adds
* the processed value to the map with the expected key. Also tests the string substitution
* using the processed values.
*/
@Test
public void testParseParametersMapToString() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
parameters.put("context", toJson(mapOfDocuments));
parameters.put("history", "hello\n");
parseParameters(parameters);
System.out.println(parameters);
assertEquals(parameters.get("context" + TO_STRING_FUNCTION_NAME), "{\\\"name\\\":\\\"John\\\"}");
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: {\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}"
);
}

/**
* Tests the parseParameters method with a map containing a nested map of strings as the
* value for the "context" key. Verifies that the method correctly processes the nested
* map and adds the processed value to the map with the expected key. Also tests the
* string substitution using the processed values.
*/
@Test
public void testParseParametersNestedMapToString() {
Map<String, String> parameters = new HashMap<>();
parameters
.put(
"prompt",
"answer question based on context: ${parameters.context.toString()} and conversation history based on history: ${parameters.history.toString()}"
);
Map<String, String> mapOfDocuments = new HashMap<>();
mapOfDocuments.put("name", "John");
Map<String, String> nestedMapOfDocuments = new HashMap<>();
nestedMapOfDocuments.put("city", "New York");
mapOfDocuments.put("hometown", toJson(nestedMapOfDocuments));
parameters.put("context", toJson(mapOfDocuments));
parameters.put("history", "hello\n");
parseParameters(parameters);
System.out.println(parameters);
assertEquals(
parameters.get("context" + TO_STRING_FUNCTION_NAME),
"{\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"}"
);
String requestBody = "{\"prompt\": \"${parameters.prompt}\"}";
StringSubstitutor substitutor = new StringSubstitutor(parameters, "${parameters.", "}");
requestBody = substitutor.replace(requestBody);
System.out.println(requestBody);
assertEquals(
requestBody,
"{\"prompt\": \"answer question based on context: {\\\"hometown\\\":\\\"{\\\\\\\"city\\\\\\\":\\\\\\\"New York\\\\\\\"}\\\",\\\"name\\\":\\\"John\\\"} and conversation history based on history: hello\\n\"}"
);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -384,8 +384,8 @@ private void processPredictions(
}
}
}

modelParameters = StringUtils.getParameterMap(modelInputParameters);
Map<String, String> modelParametersInString = StringUtils.getParameterMap(modelInputParameters);
modelParameters.putAll(modelParametersInString);

Set<String> inputMapKeys = new HashSet<>(modelParameters.keySet());
inputMapKeys.removeAll(modelConfigs.keySet());
Expand Down
Loading
Loading