Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Fix: Gracefully handle error when generative_qa_parameters is not provided #3100

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants {
.boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic);

public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled.";

public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found."
+ " Please provide ext.generative_qa_parameters to proceed."
+ " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag";
}
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
package org.opensearch.searchpipelines.questionanswering.generative;

import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Duration;
import java.time.Instant;
Expand Down Expand Up @@ -126,6 +127,9 @@ public void processResponseAsync(
}

GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request);
if (params == null) {
throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);
}

Integer t = params.getTimeout();
if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.when;
import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG;

import java.time.Instant;
import java.util.Collections;
Expand Down Expand Up @@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception {
}));
}

public void testProcessResponseIllegalArgumentForNullParams() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG);

Client client = mock(Client.class);
Map<String, Object> config = new HashMap<>();
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model");
config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text"));

GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory(
client,
alwaysOn
).create(null, "tag", "desc", true, config, null);

ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class);
List<Interaction> chatHistory = List
.of(
new Interaction(
"0",
Instant.now(),
"1",
"question",
"",
"answer",
"foo",
Collections.singletonMap("meta data", "some meta")
)
);
doAnswer(invocation -> {
((ActionListener<List<Interaction>>) invocation.getArguments()[2]).onResponse(chatHistory);
return null;
}).when(memoryClient).getInteractions(any(), anyInt(), any());
processor.setMemoryClient(memoryClient);

SearchRequest request = new SearchRequest();
SearchSourceBuilder sourceBuilder = new SearchSourceBuilder();
GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder();
extBuilder.setParams(null);
request.source(sourceBuilder);
sourceBuilder.ext(List.of(extBuilder));

int numHits = 10;
SearchHit[] hitsArray = new SearchHit[numHits];
for (int i = 0; i < numHits; i++) {
XContentBuilder sourceContent = JsonXContent
.contentBuilder()
.startObject()
.field("_id", String.valueOf(i))
.field("text", "passage" + i)
.field("title", "This is the title for document " + i)
.endObject();
hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of());
hitsArray[i].sourceRef(BytesReference.bytes(sourceContent));
}

SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f);
SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0);
SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null);

Llm llm = mock(Llm.class);
processor.setLlm(llm);

processor
.processResponseAsync(
request,
response,
null,
ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {})
);
}

public void testProcessResponseIllegalArgument() throws Exception {
exceptionRule.expect(IllegalArgumentException.class);
exceptionRule.expectMessage("llm_model cannot be null.");
Expand Down
Loading