Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Use message history instead of event payload for conversation handler #2047

Merged
merged 15 commits into from
Sep 26, 2024
5 changes: 5 additions & 0 deletions .changeset/plenty-wombats-fry.md
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
---
'@aws-amplify/ai-constructs': minor
---

Use message history instead of event payload for conversational route
9 changes: 8 additions & 1 deletion packages/ai-constructs/API.md
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,14 @@ type ConversationTurnEvent = {
authorization: string;
};
};
messages: Array<ConversationMessage>;
messages?: Array<ConversationMessage>;
messageHistoryQuery: {
getQueryName: string;
getQueryInputTypeName: string;
listQueryName: string;
listQueryInputTypeName: string;
listQueryLimit?: number;
};
toolsConfiguration?: {
dataTools?: Array<ToolDefinition & {
graphqlRequestInputDescriptor: {
Expand Down
Original file line number Diff line number Diff line change
@@ -1,6 +1,11 @@
import { describe, it, mock } from 'node:test';
import assert from 'node:assert';
import { ConversationTurnEvent, ExecutableTool, ToolDefinition } from './types';
import {
ConversationMessage,
ConversationTurnEvent,
ExecutableTool,
ToolDefinition,
} from './types';
import { BedrockConverseAdapter } from './bedrock_converse_adapter';
import {
BedrockRuntimeClient,
Expand All @@ -13,22 +18,19 @@ import {
} from '@aws-sdk/client-bedrock-runtime';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
import { randomBytes, randomUUID } from 'node:crypto';
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';

void describe('Bedrock converse adapter', () => {
const commonEvent: Readonly<ConversationTurnEvent> = {
conversationId: '',
currentMessageId: '',
graphqlApiEndpoint: '',
messages: [
{
role: 'user',
content: [
{
text: 'event message',
},
],
},
],
messageHistoryQuery: {
getQueryName: '',
getQueryInputTypeName: '',
listQueryName: '',
listQueryInputTypeName: '',
},
modelConfiguration: {
modelId: 'testModelId',
systemPrompt: 'testSystemPrompt',
Expand All @@ -46,6 +48,27 @@ void describe('Bedrock converse adapter', () => {
},
};

const messages: Array<ConversationMessage> = [
{
role: 'user',
content: [
{
text: 'event message',
},
],
},
];
const messageHistoryRetriever = new ConversationMessageHistoryRetriever(
commonEvent
);
const messageHistoryRetrieverMockGetEventMessages = mock.method(
messageHistoryRetriever,
'getMessageHistory',
() => {
return Promise.resolve(messages);
}
);

void it('calls bedrock to get conversation response', async () => {
const event: ConversationTurnEvent = {
...commonEvent,
Expand Down Expand Up @@ -78,7 +101,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand All @@ -90,7 +115,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
modelId: event.modelConfiguration.modelId,
inferenceConfig: event.modelConfiguration.inferenceConfiguration,
system: [
Expand Down Expand Up @@ -211,7 +236,8 @@ void describe('Bedrock converse adapter', () => {
event,
[additionalTool],
bedrockClient,
eventToolsProvider
eventToolsProvider,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -251,7 +277,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest1 = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput1: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest1.input, expectedBedrockInput1);
Expand All @@ -264,7 +290,7 @@ void describe('Bedrock converse adapter', () => {
);
const expectedBedrockInput2: ConverseCommandInput = {
messages: [
...(event.messages as Array<Message>),
...(messages as Array<Message>),
additionalToolUseBedrockResponse.output?.message,
{
role: 'user',
Expand Down Expand Up @@ -447,7 +473,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[tool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -543,7 +571,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[tool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(
Expand Down Expand Up @@ -645,7 +675,9 @@ void describe('Bedrock converse adapter', () => {
const responseContent = await new BedrockConverseAdapter(
event,
[additionalTool],
bedrockClient
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.deepStrictEqual(responseContent, [clientToolUseBlock]);
Expand Down Expand Up @@ -682,7 +714,7 @@ void describe('Bedrock converse adapter', () => {
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
.arguments[0] as unknown as ConverseCommand;
const expectedBedrockInput: ConverseCommandInput = {
messages: event.messages as Array<Message>,
messages: messages as Array<Message>,
...expectedBedrockInputCommonProperties,
};
assert.deepStrictEqual(bedrockRequest.input, expectedBedrockInput);
Expand All @@ -695,21 +727,27 @@ void describe('Bedrock converse adapter', () => {

const fakeImagePayload = randomBytes(32);

event.messages = [
{
role: 'user',
content: [
messageHistoryRetrieverMockGetEventMessages.mock.mockImplementationOnce(
() => {
return Promise.resolve([
{
image: {
format: 'png',
source: {
bytes: fakeImagePayload.toString('base64'),
id: '',
conversationId: '',
role: 'user',
content: [
{
image: {
format: 'png',
source: {
bytes: fakeImagePayload.toString('base64'),
},
},
},
},
],
},
],
},
];
]);
}
);

const bedrockClient = new BedrockRuntimeClient();
const bedrockResponse: ConverseCommandOutput = {
Expand All @@ -735,7 +773,13 @@ void describe('Bedrock converse adapter', () => {
Promise.resolve(bedrockResponse)
);

await new BedrockConverseAdapter(event, [], bedrockClient).askBedrock();
await new BedrockConverseAdapter(
event,
[],
bedrockClient,
undefined,
messageHistoryRetriever
).askBedrock();

assert.strictEqual(bedrockClientSendMock.mock.calls.length, 1);
const bedrockRequest = bedrockClientSendMock.mock.calls[0]
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@ import {
ToolDefinition,
} from './types.js';
import { ConversationTurnEventToolsProvider } from './event-tools-provider';
import { ConversationMessageHistoryRetriever } from './conversation_message_history_retriever';

/**
* This class is responsible for interacting with Bedrock Converse API
Expand All @@ -36,7 +37,10 @@ export class BedrockConverseAdapter {
private readonly bedrockClient: BedrockRuntimeClient = new BedrockRuntimeClient(
{ region: event.modelConfiguration.region }
),
eventToolsProvider = new ConversationTurnEventToolsProvider(event)
eventToolsProvider = new ConversationTurnEventToolsProvider(event),
private readonly messageHistoryRetriever = new ConversationMessageHistoryRetriever(
event
)
) {
this.executableTools = [
...eventToolsProvider.getEventTools(),
Expand Down Expand Up @@ -73,7 +77,8 @@ export class BedrockConverseAdapter {
const { modelId, systemPrompt, inferenceConfiguration } =
this.event.modelConfiguration;

const messages: Array<Message> = this.getEventMessagesAsBedrockMessages();
const messages: Array<Message> =
await this.getEventMessagesAsBedrockMessages();

let bedrockResponse: ConverseCommandOutput;
do {
Expand Down Expand Up @@ -124,9 +129,13 @@ export class BedrockConverseAdapter {
* 1. Makes a copy so that we don't mutate event.
* 2. Decodes Base64 encoded images.
*/
private getEventMessagesAsBedrockMessages = (): Array<Message> => {
private getEventMessagesAsBedrockMessages = async (): Promise<
Array<Message>
> => {
const messages: Array<Message> = [];
for (const message of this.event.messages) {
const eventMessages =
await this.messageHistoryRetriever.getMessageHistory();
for (const message of eventMessages) {
const messageContent: Array<ContentBlock> = [];
for (const contentElement of message.content) {
if (typeof contentElement.image?.source?.bytes === 'string') {
Expand Down
Loading
Loading