♻️ refactor: refactor agent runtime stream implementation and ZHIPU provider (lobehub#4323)

* ♻️ refactor: refactor stream implementation

* ♻️ refactor: refactor Zhipu AI to the OpenAI-compatible runtime

* 🎨 chore: clean wenxin code

* 🐛 fix: fix ZHIPU payload
arvinxx authored Oct 11, 2024
1 parent e68f81f commit 59661a1
Showing 14 changed files with 59 additions and 186 deletions.
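
The common thread in the stream-utility diffs below: every provider carried its own copy of a pass-through chatStreamable async generator wrapped in readableFromAsyncIterable from the ai package. This commit hoists that pair into a single convertIterableToStream helper in protocol.ts. A minimal sketch of the call-site change (provider details elided):

import { readableFromAsyncIterable } from 'ai';

import { convertIterableToStream } from './protocol';

// Before: each provider re-declared this pass-through generator locally.
const chatStreamable = async function* <T>(stream: AsyncIterable<T>) {
  for await (const response of stream) {
    yield response;
  }
};
const toStreamBefore = <T>(stream: AsyncIterable<T>) =>
  readableFromAsyncIterable(chatStreamable(stream));

// After: one shared helper performs the AsyncIterable -> ReadableStream conversion.
const toStreamAfter = <T>(stream: AsyncIterable<T>) => convertIterableToStream(stream);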
2 changes: 1 addition & 1 deletion src/libs/agent-runtime/AgentRuntime.ts
@@ -174,7 +174,7 @@ class AgentRuntime {
}

case ModelProvider.ZhiPu: {
- runtimeModel = await LobeZhipuAI.fromAPIKey(params.zhipu);
+ runtimeModel = new LobeZhipuAI(params.zhipu);
break;
}

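
Note the construction change above: LobeZhipuAI previously required an async fromAPIKey factory (it awaited API-token generation, per the zhipu test changes at the bottom of this diff), so AgentRuntime had to await it. After the move to the OpenAI-compatible runtime, construction is synchronous like every other provider. A minimal sketch of the difference:

// Before: async factory — construction awaited token generation and could reject.
const legacyRuntime = await LobeZhipuAI.fromAPIKey({ apiKey: 'zhipu-api-key' });

// After: plain constructor, uniform with the other OpenAI-compatible providers.
const runtime = new LobeZhipuAI({ apiKey: 'zhipu-api-key' });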
4 changes: 2 additions & 2 deletions src/libs/agent-runtime/google/index.ts
@@ -27,7 +27,7 @@ import { ModelProvider } from '../types/type';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
- import { GoogleGenerativeAIStream, googleGenAIResultToStream } from '../utils/streams';
+ import { GoogleGenerativeAIStream, convertIterableToStream } from '../utils/streams';
import { parseDataUri } from '../utils/uriParser';

enum HarmCategory {
@@ -97,7 +97,7 @@ export class LobeGoogleAI implements LobeRuntimeAI {
tools: this.buildGoogleTools(payload.tools),
});

- const googleStream = googleGenAIResultToStream(geminiStreamResult);
+ const googleStream = convertIterableToStream(geminiStreamResult.stream);
const [prod, useForDebug] = googleStream.tee();

if (process.env.DEBUG_GOOGLE_CHAT_COMPLETION === '1') {
10 changes: 2 additions & 8 deletions src/libs/agent-runtime/utils/streams/anthropic.ts
@@ -1,13 +1,13 @@
import Anthropic from '@anthropic-ai/sdk';
import type { Stream } from '@anthropic-ai/sdk/streaming';
- import { readableFromAsyncIterable } from 'ai';

import { ChatStreamCallbacks } from '../../types';
import {
StreamProtocolChunk,
StreamProtocolToolCallChunk,
StreamStack,
StreamToolCallChunkData,
+ convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
} from './protocol';
@@ -96,20 +96,14 @@ export const transformAnthropicStream = (
}
};

- const chatStreamable = async function* (stream: AsyncIterable<Anthropic.MessageStreamEvent>) {
-   for await (const response of stream) {
-     yield response;
-   }
- };

export const AnthropicStream = (
stream: Stream<Anthropic.MessageStreamEvent> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => {
const streamStack: StreamStack = { id: '' };

const readableStream =
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);

return readableStream
.pipeThrough(createSSEProtocolTransformer(transformAnthropicStream, streamStack))
10 changes: 2 additions & 8 deletions src/libs/agent-runtime/utils/streams/azureOpenai.ts
@@ -1,5 +1,4 @@
import { ChatCompletions, ChatCompletionsFunctionToolCall } from '@azure/openai';
- import { readableFromAsyncIterable } from 'ai';
import OpenAI from 'openai';
import type { Stream } from 'openai/streaming';

@@ -9,6 +8,7 @@ import {
StreamProtocolToolCallChunk,
StreamStack,
StreamToolCallChunkData,
+ convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
} from './protocol';
@@ -69,19 +69,13 @@ const transformOpenAIStream = (chunk: ChatCompletions, stack: StreamStack): StreamProtocolChunk => {
};
};

- const chatStreamable = async function* (stream: AsyncIterable<OpenAI.ChatCompletionChunk>) {
-   for await (const response of stream) {
-     yield response;
-   }
- };

export const AzureOpenAIStream = (
stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => {
const stack: StreamStack = { id: '' };
const readableStream =
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);

return readableStream
.pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, stack))
13 changes: 1 addition & 12 deletions src/libs/agent-runtime/utils/streams/google-ai.ts
@@ -1,8 +1,4 @@
- import {
-   EnhancedGenerateContentResponse,
-   GenerateContentStreamResult,
- } from '@google/generative-ai';
- import { readableFromAsyncIterable } from 'ai';
+ import { EnhancedGenerateContentResponse } from '@google/generative-ai';

import { nanoid } from '@/utils/uuid';

@@ -11,7 +7,6 @@ import {
StreamProtocolChunk,
StreamStack,
StreamToolCallChunkData,
- chatStreamable,
createCallbacksTransformer,
createSSEProtocolTransformer,
generateToolCallId,
@@ -50,12 +45,6 @@ const transformGoogleGenerativeAIStream = (
};
};

- // only use for debug
- export const googleGenAIResultToStream = (stream: GenerateContentStreamResult) => {
-   // make the response to the streamable format
-   return readableFromAsyncIterable(chatStreamable(stream.stream));
- };

export const GoogleGenerativeAIStream = (
rawStream: ReadableStream<EnhancedGenerateContentResponse>,
callbacks?: ChatStreamCallbacks,
10 changes: 2 additions & 8 deletions src/libs/agent-runtime/utils/streams/ollama.ts
@@ -1,4 +1,3 @@
- import { readableFromAsyncIterable } from 'ai';
import { ChatResponse } from 'ollama/browser';

import { ChatStreamCallbacks } from '@/libs/agent-runtime';
@@ -7,6 +6,7 @@ import { nanoid } from '@/utils/uuid';
import {
StreamProtocolChunk,
StreamStack,
+ convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
} from './protocol';
@@ -20,19 +20,13 @@ const transformOllamaStream = (chunk: ChatResponse, stack: StreamStack): StreamProtocolChunk => {
return { data: chunk.message.content, id: stack.id, type: 'text' };
};

- const chatStreamable = async function* (stream: AsyncIterable<ChatResponse>) {
-   for await (const response of stream) {
-     yield response;
-   }
- };

export const OllamaStream = (
res: AsyncIterable<ChatResponse>,
cb?: ChatStreamCallbacks,
): ReadableStream<string> => {
const streamStack: StreamStack = { id: 'chat_' + nanoid() };

- return readableFromAsyncIterable(chatStreamable(res))
+ return convertIterableToStream(res)
.pipeThrough(createSSEProtocolTransformer(transformOllamaStream, streamStack))
.pipeThrough(createCallbacksTransformer(cb));
};
10 changes: 2 additions & 8 deletions src/libs/agent-runtime/utils/streams/openai.ts
@@ -1,4 +1,3 @@
- import { readableFromAsyncIterable } from 'ai';
import OpenAI from 'openai';
import type { Stream } from 'openai/streaming';

@@ -10,6 +9,7 @@ import {
StreamProtocolToolCallChunk,
StreamStack,
StreamToolCallChunkData,
+ convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
generateToolCallId,
@@ -105,20 +105,14 @@ export const transformOpenAIStream = (
}
};

- const chatStreamable = async function* (stream: AsyncIterable<OpenAI.ChatCompletionChunk>) {
-   for await (const response of stream) {
-     yield response;
-   }
- };

export const OpenAIStream = (
stream: Stream<OpenAI.ChatCompletionChunk> | ReadableStream,
callbacks?: ChatStreamCallbacks,
) => {
const streamStack: StreamStack = { id: '' };

const readableStream =
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);

return readableStream
.pipeThrough(createSSEProtocolTransformer(transformOpenAIStream, streamStack))
7 changes: 7 additions & 0 deletions src/libs/agent-runtime/utils/streams/protocol.ts
@@ -1,3 +1,5 @@
+ import { readableFromAsyncIterable } from 'ai';

import { ChatStreamCallbacks } from '@/libs/agent-runtime';

export interface StreamStack {
@@ -42,6 +44,11 @@ export const chatStreamable = async function* <T>(stream: AsyncIterable<T>) {
}
};

+ // make the response to the streamable format
+ export const convertIterableToStream = <T>(stream: AsyncIterable<T>) => {
+   return readableFromAsyncIterable(chatStreamable(stream));
+ };

export const createSSEProtocolTransformer = (
transformer: (chunk: any, stack: StreamStack) => StreamProtocolChunk,
streamStack?: StreamStack,
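
A quick usage sketch of the new helper: any async iterable of chunks becomes a web ReadableStream that the SSE and callback transformers can consume (the chunk shape below is illustrative, not a real SDK type):

import { convertIterableToStream } from './protocol';

// Illustrative chunk source; real providers yield SDK-specific chunk types.
async function* mockChunks() {
  yield { content: 'Hello' };
  yield { content: ' world' };
}

const stream = convertIterableToStream(mockChunks());

// Consume it like any web ReadableStream.
const reader = stream.getReader();
for (let chunk = await reader.read(); !chunk.done; chunk = await reader.read()) {
  console.log(chunk.value); // { content: 'Hello' }, then { content: ' world' }
}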
5 changes: 2 additions & 3 deletions src/libs/agent-runtime/utils/streams/qwen.ts
@@ -1,4 +1,3 @@
- import { readableFromAsyncIterable } from 'ai';
import { ChatCompletionContentPartText } from 'ai/prompts';
import OpenAI from 'openai';
import { ChatCompletionContentPart } from 'openai/resources/index.mjs';
@@ -9,7 +8,7 @@ import {
StreamProtocolChunk,
StreamProtocolToolCallChunk,
StreamToolCallChunkData,
- chatStreamable,
+ convertIterableToStream,
createCallbacksTransformer,
createSSEProtocolTransformer,
generateToolCallId,
@@ -86,7 +85,7 @@ export const QwenAIStream = (
callbacks?: ChatStreamCallbacks,
) => {
const readableStream =
- stream instanceof ReadableStream ? stream : readableFromAsyncIterable(chatStreamable(stream));
+ stream instanceof ReadableStream ? stream : convertIterableToStream(stream);

return readableStream
.pipeThrough(createSSEProtocolTransformer(transformQwenStream))
10 changes: 7 additions & 3 deletions src/libs/agent-runtime/utils/streams/wenxin.test.ts
@@ -2,8 +2,9 @@ import { describe, expect, it, vi } from 'vitest';

import * as uuidModule from '@/utils/uuid';

+ import { convertIterableToStream } from '../../utils/streams/protocol';
import { ChatResp } from '../../wenxin/type';
- import { WenxinResultToStream, WenxinStream } from './wenxin';
+ import { WenxinStream } from './wenxin';

const dataStream = [
{
@@ -95,7 +96,7 @@
},
};

- const stream = WenxinResultToStream(mockWenxinStream);
+ const stream = convertIterableToStream(mockWenxinStream);

const onStartMock = vi.fn();
const onTextMock = vi.fn();
@@ -142,7 +143,10 @@

expect(onStartMock).toHaveBeenCalledTimes(1);
expect(onTextMock).toHaveBeenNthCalledWith(1, '"当然可以,"');
- expect(onTextMock).toHaveBeenNthCalledWith(2, '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"');
+ expect(onTextMock).toHaveBeenNthCalledWith(
+   2,
+   '"以下是一些建议的自驾游路线,它们涵盖了各种不同的风景和文化体验:\\n\\n1. **西安-敦煌历史文化之旅**:\\n\\n\\n\\t* 路线:西安"',
+ );
expect(onTokenMock).toHaveBeenCalledTimes(6);
expect(onCompletionMock).toHaveBeenCalledTimes(1);
});
8 changes: 0 additions & 8 deletions src/libs/agent-runtime/utils/streams/wenxin.ts
@@ -1,13 +1,10 @@
- import { readableFromAsyncIterable } from 'ai';

import { ChatStreamCallbacks } from '@/libs/agent-runtime';
import { nanoid } from '@/utils/uuid';

import { ChatResp } from '../../wenxin/type';
import {
StreamProtocolChunk,
StreamStack,
- chatStreamable,
createCallbacksTransformer,
createSSEProtocolTransformer,
} from './protocol';
@@ -29,11 +26,6 @@ const transformERNIEBotStream = (chunk: ChatResp): StreamProtocolChunk => {
};
};

- export const WenxinResultToStream = (stream: AsyncIterable<ChatResp>) => {
-   // make the response to the streamable format
-   return readableFromAsyncIterable(chatStreamable(stream));
- };

export const WenxinStream = (
rawStream: ReadableStream<ChatResp>,
callbacks?: ChatStreamCallbacks,
5 changes: 3 additions & 2 deletions src/libs/agent-runtime/wenxin/index.ts
@@ -10,7 +10,8 @@ import { ChatCompetitionOptions, ChatStreamPayload } from '../types';
import { AgentRuntimeError } from '../utils/createError';
import { debugStream } from '../utils/debugStream';
import { StreamingResponse } from '../utils/response';
- import { WenxinResultToStream, WenxinStream } from '../utils/streams/wenxin';
+ import { convertIterableToStream } from '../utils/streams';
+ import { WenxinStream } from '../utils/streams/wenxin';
import { ChatResp } from './type';

interface ChatErrorCode {
@@ -46,7 +47,7 @@ export class LobeWenxinAI implements LobeRuntimeAI {
payload.model,
);

- const wenxinStream = WenxinResultToStream(result as AsyncIterable<ChatResp>);
+ const wenxinStream = convertIterableToStream(result as AsyncIterable<ChatResp>);

const [prod, useForDebug] = wenxinStream.tee();

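
With Wenxin converted, every provider now shares the same pipeline: SDK async iterable → convertIterableToStream → SSE protocol transformer → callbacks transformer, with the result wrapped in StreamingResponse. A schematic sketch of that shared shape (toSSEReadable is illustrative, not a helper that exists in the repo):

import { ChatStreamCallbacks } from '@/libs/agent-runtime';

import {
  StreamProtocolChunk,
  StreamStack,
  convertIterableToStream,
  createCallbacksTransformer,
  createSSEProtocolTransformer,
} from './protocol';

// Illustrative composition of the shared stages; each provider supplies its own
// chunk type and transformer (transformOpenAIStream, transformOllamaStream, ...).
const toSSEReadable = <T>(
  sdkStream: AsyncIterable<T>,
  transformer: (chunk: T, stack: StreamStack) => StreamProtocolChunk,
  callbacks?: ChatStreamCallbacks,
) =>
  convertIterableToStream(sdkStream) // 1. AsyncIterable -> ReadableStream
    .pipeThrough(createSSEProtocolTransformer(transformer, { id: '' })) // 2. normalize chunks to the SSE protocol
    .pipeThrough(createCallbacksTransformer(callbacks)); // 3. fire onStart/onText/onCompletion callbacks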
31 changes: 7 additions & 24 deletions src/libs/agent-runtime/zhipu/index.test.ts
@@ -2,7 +2,7 @@
import { OpenAI } from 'openai';
import { Mock, afterEach, beforeEach, describe, expect, it, vi } from 'vitest';

- import { ChatStreamCallbacks, LobeOpenAI } from '@/libs/agent-runtime';
+ import { ChatStreamCallbacks, LobeOpenAI, LobeOpenAICompatibleRuntime } from '@/libs/agent-runtime';
import * as debugStreamModule from '@/libs/agent-runtime/utils/debugStream';

import * as authTokenModule from './authToken';
@@ -24,28 +24,11 @@ describe('LobeZhipuAI', () => {
vi.restoreAllMocks();
});

- describe('fromAPIKey', () => {
-   it('should correctly initialize with an API key', async () => {
-     const lobeZhipuAI = await LobeZhipuAI.fromAPIKey({ apiKey: 'test_api_key' });
-     expect(lobeZhipuAI).toBeInstanceOf(LobeZhipuAI);
-     expect(lobeZhipuAI.baseURL).toEqual('https://open.bigmodel.cn/api/paas/v4');
-   });
-
-   it('should throw an error if API key is invalid', async () => {
-     vi.spyOn(authTokenModule, 'generateApiToken').mockRejectedValue(new Error('Invalid API Key'));
-     try {
-       await LobeZhipuAI.fromAPIKey({ apiKey: 'asd' });
-     } catch (e) {
-       expect(e).toEqual({ errorType: invalidErrorType });
-     }
-   });
- });

describe('chat', () => {
- let instance: LobeZhipuAI;
+ let instance: LobeOpenAICompatibleRuntime;

beforeEach(async () => {
- instance = await LobeZhipuAI.fromAPIKey({
+ instance = new LobeZhipuAI({
apiKey: 'test_api_key',
});

@@ -131,9 +114,9 @@
const calledWithParams = spyOn.mock.calls[0][0];

expect(calledWithParams.messages[1].content).toEqual([{ type: 'text', text: 'Hello again' }]);
- expect(calledWithParams.temperature).toBeUndefined(); // temperature 0 should be undefined
+ expect(calledWithParams.temperature).toBe(0); // temperature 0 is now passed through
expect((calledWithParams as any).do_sample).toBeTruthy(); // temperature 0 should enable do_sample
- expect(calledWithParams.top_p).toEqual(0.99); // top_p should be transformed correctly
+ expect(calledWithParams.top_p).toEqual(1); // top_p should be transformed correctly
});

describe('Error', () => {
@@ -175,7 +158,7 @@

it('should throw AgentRuntimeError with NoOpenAIAPIKey if no apiKey is provided', async () => {
try {
- await LobeZhipuAI.fromAPIKey({ apiKey: '' });
+ new LobeZhipuAI({ apiKey: '' });
} catch (e) {
expect(e).toEqual({ errorType: invalidErrorType });
}
@@ -221,7 +204,7 @@
};
const apiError = new OpenAI.APIError(400, errorInfo, 'module error', {});

- instance = await LobeZhipuAI.fromAPIKey({
+ instance = new LobeZhipuAI({
apiKey: 'test',

baseURL: 'https://abc.com/v2',
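
The refactored src/libs/agent-runtime/zhipu/index.ts itself is not expanded in this view, but the test changes imply it is now produced by the repository's OpenAI-compatible factory rather than hand-rolled around generateApiToken. A hedged sketch of the likely shape — the factory name, import paths, and payload handling below are assumptions, not the commit's verbatim code:

import { ChatStreamPayload } from '../types';
import { ModelProvider } from '../types/type';
import { LobeOpenAICompatibleFactory } from '../utils/openaiCompatibleFactory';

// Sketch only. The baseURL matches the one asserted in the removed fromAPIKey
// test; do_sample mirrors the payload expectation in the tests above.
export const LobeZhipuAI = LobeOpenAICompatibleFactory({
  baseURL: 'https://open.bigmodel.cn/api/paas/v4',
  chatCompletion: {
    handlePayload: (payload: ChatStreamPayload) =>
      ({
        ...payload,
        // ZHIPU appears to need sampling made explicit when temperature is 0
        // (assumed from the do_sample assertion in the tests).
        ...(payload.temperature === 0 ? { do_sample: true } : {}),
      }) as any,
  },
  provider: ModelProvider.ZhiPu,
});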
