From 767e73ba3c1670293504b993f9acaece1ee9f035 Mon Sep 17 00:00:00 2001 From: Kamil Sobol Date: Wed, 25 Sep 2024 16:34:03 -0700 Subject: [PATCH] process history --- .../conversation_message_history_retriever.ts | 59 +++++++++++++++++++ .../conversation_handler_project.ts | 1 + .../amplify/data/resource.ts | 1 + 3 files changed, 61 insertions(+) diff --git a/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts index f1bcbc1f4b..27c469e768 100644 --- a/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts +++ b/packages/ai-constructs/src/conversation/runtime/conversation_message_history_retriever.ts @@ -4,6 +4,8 @@ import { GraphqlRequestExecutor } from './graphql_request_executor'; export type ConversationHistoryMessageItem = ConversationMessage & { id: string; conversationId: string; + associatedUserMessageId?: string; + aiContext?: unknown; }; export type GetQueryInput = { @@ -40,6 +42,8 @@ export type ListQueryOutput = { const messageItemSelectionSet = ` id conversationId + associatedUserMessageId + aiContext role content { text @@ -121,6 +125,61 @@ export class ConversationMessageHistoryRetriever { messages.push(currentMessage); } + // Index assistant messages by corresponding user message. + const assistantMessageByUserMessageId: Map< + string, + ConversationHistoryMessageItem + > = new Map(); + messages.forEach((message) => { + if (message.role === 'assistant' && message.associatedUserMessageId) { + assistantMessageByUserMessageId.set( + message.associatedUserMessageId, + message + ); + } + }); + + // Reconcile history and inject aiContext + messages.reduce((acc, current) => { + // Bedrock expects that message history is user->assistant->user->assistant->... and so on. + // The chronological order doesn't assure this ordering if there were any concurrent messages sent. + // Therefore, conversation is ordered by user's messages only and corresponding assistant messages are inserted + // into right place regardless of their createdAt value. + // This algorithm assumes that GQL query returns messages sorted by createdAt. + if (current.role === 'assistant') { + // Initially, skip assistant messages, these might be out of chronological order. + return acc; + } + if ( + current.role === 'user' && + !assistantMessageByUserMessageId.has(current.id) && + current.id !== this.event.currentMessageId + ) { + // Skip user messages that didn't get answer from assistant yet. + // These might be still "in-flight", i.e. assistant is still working on them in separate invocation. + // Except current message, we want to process that one. + return acc; + } + const aiContext = current.aiContext; + const content = aiContext + ? [...current.content, { text: JSON.stringify(aiContext) }] + : current.content; + + acc.push({ role: current.role, content }); + + // Find and insert corresponding assistant message. + const correspondingAssistantMessage = assistantMessageByUserMessageId.get( + current.id + ); + if (correspondingAssistantMessage) { + acc.push({ + role: correspondingAssistantMessage.role, + content: correspondingAssistantMessage.content, + }); + } + return acc; + }, [] as Array); + return messages; }; diff --git a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts index a444fbf138..1ab7ab2175 100644 --- a/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts +++ b/packages/integration-tests/src/test-project-setup/conversation_handler_project.ts @@ -54,6 +54,7 @@ type ConversationTurnAppSyncResponse = { type CreateConversationMessageChatInput = ConversationMessage & { conversationId: string; id: string; + associatedUserMessageId?: string; }; const commonEventProperties = { diff --git a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts index c241f11c99..f698574310 100644 --- a/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts +++ b/packages/integration-tests/src/test-projects/conversation-handler/amplify/data/resource.ts @@ -97,6 +97,7 @@ const schema = a.schema({ ConversationMessageChat: a .model({ conversationId: a.id(), + associatedUserMessageId: a.id(), role: a.ref('MockConversationParticipantRole'), content: a.ref('MockContentBlock').array(), aiContext: a.json(),