From 3d9ec6849dfd9325f32a6d6039b4891f83191abc Mon Sep 17 00:00:00 2001
From: "David L. Qiu"
Date: Wed, 25 Sep 2024 10:23:26 -0700
Subject: [PATCH 1/5] add `metadata` field to agent messages

Currently only set when the default chat handler is used.
---
 .../jupyter_ai/callback_handlers/__init__.py  |  6 +++++
 .../jupyter_ai/callback_handlers/metadata.py  | 25 +++++++++++++++++++
 .../jupyter_ai/chat_handlers/default.py       | 20 ++++++++++-----
 packages/jupyter-ai/jupyter_ai/models.py      | 19 +++++++++++++-
 packages/jupyter-ai/src/chat_handler.ts       |  1 +
 .../src/components/pending-messages.tsx       |  3 ++-
 packages/jupyter-ai/src/handler.ts            |  2 ++
 7 files changed, 68 insertions(+), 8 deletions(-)
 create mode 100644 packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
 create mode 100644 packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py

diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
new file mode 100644
index 00000000..5590fdcd
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
@@ -0,0 +1,6 @@
+"""
+Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`.
+Not to be confused with Jupyter AI chat handlers.
+"""
+
+from .metadata import MetadataCallbackHandler
\ No newline at end of file
diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
new file mode 100644
index 00000000..249739aa
--- /dev/null
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
@@ -0,0 +1,25 @@
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+class MetadataCallbackHandler(BaseCallbackHandler):
+    """
+    When passed as a callback handler, this stores the LLMResult's
+    `generation_info` dictionary in the `self.jai_metadata` instance attribute
+    after the provider fully processes an input.
+
+    If used in a streaming chat handler: the `metadata` field of the final
+    `AgentStreamChunkMessage` should be set to `self.jai_metadata`.
+
+    If used in a non-streaming chat handler: the `metadata` field of the
+    returned `AgentChatMessage` should be set to `self.jai_metadata`.
+ """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.jai_metadata = {} + + def on_llm_end(self, response: LLMResult, **kwargs) -> None: + if not (len(response.generations) and len(response.generations[0])): + return + + self.jai_metadata = response.generations[0][0].generation_info diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index dc6753b5..447593dd 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -1,8 +1,9 @@ import asyncio import time -from typing import Dict, Type +from typing import Any, Dict, Type from uuid import uuid4 +from jupyter_ai.callback_handlers import MetadataCallbackHandler from jupyter_ai.models import ( AgentStreamChunkMessage, AgentStreamMessage, @@ -85,13 +86,16 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str: return stream_id - def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False): + def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False, metadata: Dict[str, Any] = {}): """ Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, content=content, stream_complete=complete + id=stream_id, + content=content, + stream_complete=complete, + metadata=metadata ) for handler in self._root_chat_handlers.values(): @@ -104,6 +108,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = Fals async def process_message(self, message: HumanChatMessage): self.get_llm_chain() received_first_chunk = False + assert self.llm_chain inputs = {"input": message.body} if "context" in self.prompt_template.input_variables: @@ -121,10 +126,13 @@ async def process_message(self, message: HumanChatMessage): # stream response in chunks. this works even if a provider does not # implement streaming, as `astream()` defaults to yielding `_call()` # when `_stream()` is not implemented on the LLM class. - assert self.llm_chain + metadata_handler = MetadataCallbackHandler() async for chunk in self.llm_chain.astream( inputs, - config={"configurable": {"last_human_msg": message}}, + config={ + "configurable": {"last_human_msg": message}, + "callbacks": [metadata_handler] + }, ): if not received_first_chunk: # when receiving the first chunk, close the pending message and @@ -142,7 +150,7 @@ async def process_message(self, message: HumanChatMessage): break # complete stream after all chunks have been streamed - self._send_stream_chunk(stream_id, "", complete=True) + self._send_stream_chunk(stream_id, "", complete=True, metadata=metadata_handler.jai_metadata) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index e951ac6e..686aa15a 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -86,6 +86,14 @@ class BaseAgentMessage(BaseModel): The persona of the selected provider. If the selected provider is `None`, this defaults to a description of `JupyternautPersona`. """ + + metadata: Dict[str, Any] = {} + """ + Message metadata set by a provider after fully processing an input. The + contents of this dictionary are provider-dependent, and can be any + dictionary with string keys. 
This field is not to be displayed directly to + the user, and is intended solely for developer purposes. + """ class AgentChatMessage(BaseAgentMessage): @@ -101,9 +109,18 @@ class AgentStreamMessage(BaseAgentMessage): class AgentStreamChunkMessage(BaseModel): type: Literal["agent-stream-chunk"] = "agent-stream-chunk" id: str + """ID of the parent `AgentStreamMessage`.""" content: str + """The string to append to the `AgentStreamMessage` referenced by `id`.""" stream_complete: bool - """Indicates whether this chunk message completes the referenced stream.""" + """Indicates whether this chunk completes the stream referenced by `id`.""" + metadata: Dict[str, Any] = {} + """ + The metadata of the stream referenced by `id`. Metadata from the latest + chunk should override any metadata from previous chunks. See the docstring + on `BaseAgentMessage.metadata` for information. + """ + class HumanChatMessage(BaseModel): diff --git a/packages/jupyter-ai/src/chat_handler.ts b/packages/jupyter-ai/src/chat_handler.ts index 76c93a85..e1b1e332 100644 --- a/packages/jupyter-ai/src/chat_handler.ts +++ b/packages/jupyter-ai/src/chat_handler.ts @@ -170,6 +170,7 @@ export class ChatHandler implements IDisposable { } streamMessage.body += newMessage.content; + streamMessage.metadata = newMessage.metadata; if (newMessage.stream_complete) { streamMessage.complete = true; } diff --git a/packages/jupyter-ai/src/components/pending-messages.tsx b/packages/jupyter-ai/src/components/pending-messages.tsx index 3635e41e..c258c295 100644 --- a/packages/jupyter-ai/src/components/pending-messages.tsx +++ b/packages/jupyter-ai/src/components/pending-messages.tsx @@ -60,7 +60,8 @@ export function PendingMessages( time: lastMessage.time, body: '', reply_to: '', - persona: lastMessage.persona + persona: lastMessage.persona, + metadata: {} }); // timestamp format copied from ChatMessage diff --git a/packages/jupyter-ai/src/handler.ts b/packages/jupyter-ai/src/handler.ts index 43ab4520..b653015f 100644 --- a/packages/jupyter-ai/src/handler.ts +++ b/packages/jupyter-ai/src/handler.ts @@ -114,6 +114,7 @@ export namespace AiService { body: string; reply_to: string; persona: Persona; + metadata: Record; }; export type HumanChatMessage = { @@ -172,6 +173,7 @@ export namespace AiService { id: string; content: string; stream_complete: boolean; + metadata: Record; }; export type Request = ChatRequest | ClearRequest; From 57e1cdab348d3b85a0fcfcbdb0c86eeb777f64e3 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 23 Sep 2024 15:35:43 -0700 Subject: [PATCH 2/5] include message metadata in TestLLMWithStreaming --- packages/jupyter-ai-test/jupyter_ai_test/test_llms.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py index c7c72666..8fa63241 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py @@ -48,10 +48,11 @@ def _stream( run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> Iterator[GenerationChunk]: - time.sleep(5) + time.sleep(1) yield GenerationChunk( - text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n" + text="Hello! This is a dummy response from a test LLM. 
I will now count from 1 to 20.\n\n", + generation_info={"test_metadata_field":"foobar"} ) - for i in range(1, 101): + for i in range(1, 21): time.sleep(0.5) yield GenerationChunk(text=f"{i}, ") From fca92dc03f70afda524487c75eabc571eec0a4f3 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 23 Sep 2024 15:37:24 -0700 Subject: [PATCH 3/5] tweak TestLLMWithStreaming parameters --- packages/jupyter-ai-test/jupyter_ai_test/test_llms.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py index 8fa63241..1f4d92a4 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py @@ -50,9 +50,9 @@ def _stream( ) -> Iterator[GenerationChunk]: time.sleep(1) yield GenerationChunk( - text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 20.\n\n", + text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n", generation_info={"test_metadata_field":"foobar"} ) - for i in range(1, 21): - time.sleep(0.5) + for i in range(1, 6): + time.sleep(0.2) yield GenerationChunk(text=f"{i}, ") From 9c3f169877ddf5bfad2f275d22a9b9ad0f4ab6ee Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Wed, 25 Sep 2024 10:24:05 -0700 Subject: [PATCH 4/5] pre-commit --- .../jupyter_ai_test/test_llms.py | 2 +- .../jupyter_ai/callback_handlers/__init__.py | 2 +- .../jupyter_ai/callback_handlers/metadata.py | 1 + .../jupyter_ai/chat_handlers/default.py | 19 ++++++++++++------- packages/jupyter-ai/jupyter_ai/models.py | 3 +-- .../src/components/chat-messages.tsx | 4 ++++ 6 files changed, 20 insertions(+), 11 deletions(-) diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py index 1f4d92a4..17fa4265 100644 --- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py +++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py @@ -51,7 +51,7 @@ def _stream( time.sleep(1) yield GenerationChunk( text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n", - generation_info={"test_metadata_field":"foobar"} + generation_info={"test_metadata_field": "foobar"}, ) for i in range(1, 6): time.sleep(0.2) diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py index 5590fdcd..4567ecba 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py @@ -3,4 +3,4 @@ Not to be confused with Jupyter AI chat handlers. 
""" -from .metadata import MetadataCallbackHandler \ No newline at end of file +from .metadata import MetadataCallbackHandler diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index 249739aa..fa4407af 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -1,6 +1,7 @@ from langchain_core.callbacks import BaseCallbackHandler from langchain_core.outputs import LLMResult + class MetadataCallbackHandler(BaseCallbackHandler): """ When passed as a callback handler, this stores the LLMResult's diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py index 447593dd..607dc92f 100644 --- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py +++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py @@ -86,16 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str: return stream_id - def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False, metadata: Dict[str, Any] = {}): + def _send_stream_chunk( + self, + stream_id: str, + content: str, + complete: bool = False, + metadata: Dict[str, Any] = {}, + ): """ Sends an `agent-stream-chunk` message containing content that should be appended to an existing `agent-stream` message with ID `stream_id`. """ stream_chunk_msg = AgentStreamChunkMessage( - id=stream_id, - content=content, - stream_complete=complete, - metadata=metadata + id=stream_id, content=content, stream_complete=complete, metadata=metadata ) for handler in self._root_chat_handlers.values(): @@ -131,7 +134,7 @@ async def process_message(self, message: HumanChatMessage): inputs, config={ "configurable": {"last_human_msg": message}, - "callbacks": [metadata_handler] + "callbacks": [metadata_handler], }, ): if not received_first_chunk: @@ -150,7 +153,9 @@ async def process_message(self, message: HumanChatMessage): break # complete stream after all chunks have been streamed - self._send_stream_chunk(stream_id, "", complete=True, metadata=metadata_handler.jai_metadata) + self._send_stream_chunk( + stream_id, "", complete=True, metadata=metadata_handler.jai_metadata + ) async def make_context_prompt(self, human_msg: HumanChatMessage) -> str: return "\n\n".join( diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py index 686aa15a..9bd59ca2 100644 --- a/packages/jupyter-ai/jupyter_ai/models.py +++ b/packages/jupyter-ai/jupyter_ai/models.py @@ -86,7 +86,7 @@ class BaseAgentMessage(BaseModel): The persona of the selected provider. If the selected provider is `None`, this defaults to a description of `JupyternautPersona`. """ - + metadata: Dict[str, Any] = {} """ Message metadata set by a provider after fully processing an input. 
The @@ -122,7 +122,6 @@ class AgentStreamChunkMessage(BaseModel): """ - class HumanChatMessage(BaseModel): type: Literal["human"] = "human" id: str diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx index aaae93a1..961b884b 100644 --- a/packages/jupyter-ai/src/components/chat-messages.tsx +++ b/packages/jupyter-ai/src/components/chat-messages.tsx @@ -74,6 +74,10 @@ function sortMessages( export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element { const collaborators = useCollaboratorsContext(); + if (props.message.type === 'agent-stream' && props.message.complete) { + console.log(props.message.metadata); + } + const sharedStyles: SxProps = { height: '24px', width: '24px' From 037ba3dcdfbabd0bbe307f71e4d96d0b7b6e4209 Mon Sep 17 00:00:00 2001 From: "David L. Qiu" Date: Mon, 23 Sep 2024 15:57:09 -0700 Subject: [PATCH 5/5] default to empty dict if no generation_info --- packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py index fa4407af..145cab31 100644 --- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py +++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py @@ -23,4 +23,4 @@ def on_llm_end(self, response: LLMResult, **kwargs) -> None: if not (len(response.generations) and len(response.generations[0])): return - self.jai_metadata = response.generations[0][0].generation_info + self.jai_metadata = response.generations[0][0].generation_info or {}
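
Usage sketch (not part of the patches above): a minimal, self-contained
example of how `MetadataCallbackHandler` is expected to surface a provider's
`generation_info`. It follows the same pattern as the `TestLLMWithStreaming`
provider patched above, but `ToyStreamingLLM` is a hypothetical stand-in
written for illustration, assuming `jupyter_ai` and `langchain_core` are
importable.

from typing import Any, Iterator, List, Optional

from jupyter_ai.callback_handlers import MetadataCallbackHandler
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk


class ToyStreamingLLM(LLM):
    """Hypothetical LLM that attaches `generation_info` to its first chunk."""

    @property
    def _llm_type(self) -> str:
        return "toy-streaming-llm"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Non-streaming fallback: join the streamed chunks.
        return "".join(
            chunk.text for chunk in self._stream(prompt, stop, run_manager, **kwargs)
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        # Only the first chunk needs to carry `generation_info`: LangChain
        # merges the `generation_info` of all chunks into the final LLMResult.
        yield GenerationChunk(
            text="Hello!", generation_info={"test_metadata_field": "foobar"}
        )
        yield GenerationChunk(text=" Goodbye.")


handler = MetadataCallbackHandler()

# Passing the handler via `config`, as `DefaultChatHandler.process_message()`
# does with `llm_chain.astream()`, routes the final LLMResult through
# `on_llm_end()` once the stream is exhausted.
for _ in ToyStreamingLLM().stream("hi", config={"callbacks": [handler]}):
    pass

print(handler.jai_metadata)  # {'test_metadata_field': 'foobar'}

Because `on_llm_end()` only fires after the provider fully processes the
input, `jai_metadata` is empty until the stream finishes, which is why the
patches attach `metadata` only to the final (`stream_complete=True`)
`AgentStreamChunkMessage`.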