From 67425a109e652dd74f63b48fb604b3b55804086f Mon Sep 17 00:00:00 2001 From: Pavan Yekbote Date: Mon, 14 Oct 2024 17:58:27 -0700 Subject: [PATCH] Fix: Gracefully handle error when generative_qa_parameters is not provided (#3100) * fix: gracefully handle error when generative_qa_parameters is not provided Signed-off-by: Pavan Yekbote * fix: spotless apply Signed-off-by: Pavan Yekbote * docs: adding documentation link to error message Signed-off-by: Pavan Yekbote * tests: adding UT to test null params Signed-off-by: Pavan Yekbote --------- Signed-off-by: Pavan Yekbote (cherry picked from commit 0f7481e3fba9833548074660fc79b69aef3ce527) --- .../GenerativeQAProcessorConstants.java | 4 ++ .../GenerativeQAResponseProcessor.java | 4 ++ .../GenerativeQAResponseProcessorTests.java | 72 +++++++++++++++++++ 3 files changed, 80 insertions(+) diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java index ff71cadba2..7b3cf07db8 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAProcessorConstants.java @@ -43,4 +43,8 @@ public class GenerativeQAProcessorConstants { .boolSetting("plugins.ml_commons.rag_pipeline_feature_enabled", true, Setting.Property.NodeScope, Setting.Property.Dynamic); public static final String FEATURE_NOT_ENABLED_ERROR_MSG = RAG_PIPELINE_FEATURE_ENABLED.getKey() + " is not enabled."; + + public static final String RAG_NULL_GEN_QA_PARAMS_ERROR_MSG = "generative_qa_parameters not found." + + " Please provide ext.generative_qa_parameters to proceed." + + " For more info, refer: https://opensearch.org/docs/latest/search-plugins/conversational-search/#step-6-use-the-pipeline-for-rag"; } diff --git a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java index 7b1814c2a5..5ac106fb51 100644 --- a/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java +++ b/search-processors/src/main/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessor.java @@ -18,6 +18,7 @@ package org.opensearch.searchpipelines.questionanswering.generative; import static org.opensearch.ingest.ConfigurationUtils.newConfigurationException; +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; import java.time.Duration; import java.time.Instant; @@ -126,6 +127,9 @@ public void processResponseAsync( } GenerativeQAParameters params = GenerativeQAParamUtil.getGenerativeQAParameters(request); + if (params == null) { + throw new IllegalArgumentException(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); + } Integer t = params.getTimeout(); if (t == null || t == GenerativeQAParameters.SIZE_NULL_VALUE) { diff --git a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java index a89b5c1731..4295ab450e 100644 --- a/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java +++ b/search-processors/src/test/java/org/opensearch/searchpipelines/questionanswering/generative/GenerativeQAResponseProcessorTests.java @@ -24,6 +24,7 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.verify; import static org.mockito.Mockito.when; +import static org.opensearch.searchpipelines.questionanswering.generative.GenerativeQAProcessorConstants.RAG_NULL_GEN_QA_PARAMS_ERROR_MSG; import java.time.Instant; import java.util.Collections; @@ -646,6 +647,77 @@ public void testProcessResponseNullValueInteractions() throws Exception { })); } + public void testProcessResponseIllegalArgumentForNullParams() throws Exception { + exceptionRule.expect(IllegalArgumentException.class); + exceptionRule.expectMessage(RAG_NULL_GEN_QA_PARAMS_ERROR_MSG); + + Client client = mock(Client.class); + Map config = new HashMap<>(); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_MODEL_ID, "dummy-model"); + config.put(GenerativeQAProcessorConstants.CONFIG_NAME_CONTEXT_FIELD_LIST, List.of("text")); + + GenerativeQAResponseProcessor processor = (GenerativeQAResponseProcessor) new GenerativeQAResponseProcessor.Factory( + client, + alwaysOn + ).create(null, "tag", "desc", true, config, null); + + ConversationalMemoryClient memoryClient = mock(ConversationalMemoryClient.class); + List chatHistory = List + .of( + new Interaction( + "0", + Instant.now(), + "1", + "question", + "", + "answer", + "foo", + Collections.singletonMap("meta data", "some meta") + ) + ); + doAnswer(invocation -> { + ((ActionListener>) invocation.getArguments()[2]).onResponse(chatHistory); + return null; + }).when(memoryClient).getInteractions(any(), anyInt(), any()); + processor.setMemoryClient(memoryClient); + + SearchRequest request = new SearchRequest(); + SearchSourceBuilder sourceBuilder = new SearchSourceBuilder(); + GenerativeQAParamExtBuilder extBuilder = new GenerativeQAParamExtBuilder(); + extBuilder.setParams(null); + request.source(sourceBuilder); + sourceBuilder.ext(List.of(extBuilder)); + + int numHits = 10; + SearchHit[] hitsArray = new SearchHit[numHits]; + for (int i = 0; i < numHits; i++) { + XContentBuilder sourceContent = JsonXContent + .contentBuilder() + .startObject() + .field("_id", String.valueOf(i)) + .field("text", "passage" + i) + .field("title", "This is the title for document " + i) + .endObject(); + hitsArray[i] = new SearchHit(i, "doc" + i, Map.of(), Map.of()); + hitsArray[i].sourceRef(BytesReference.bytes(sourceContent)); + } + + SearchHits searchHits = new SearchHits(hitsArray, null, 1.0f); + SearchResponseSections internal = new SearchResponseSections(searchHits, null, null, false, false, null, 0); + SearchResponse response = new SearchResponse(internal, null, 1, 1, 0, 1, null, null, null); + + Llm llm = mock(Llm.class); + processor.setLlm(llm); + + processor + .processResponseAsync( + request, + response, + null, + ActionListener.wrap(r -> { assertTrue(r instanceof GenerativeSearchResponse); }, e -> {}) + ); + } + public void testProcessResponseIllegalArgument() throws Exception { exceptionRule.expect(IllegalArgumentException.class); exceptionRule.expectMessage("llm_model cannot be null.");