diff --git a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
index 1f4d92a4..17fa4265 100644
--- a/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
+++ b/packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
@@ -51,7 +51,7 @@ def _stream(
         time.sleep(1)
         yield GenerationChunk(
             text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
-            generation_info={"test_metadata_field":"foobar"}
+            generation_info={"test_metadata_field": "foobar"},
         )
         for i in range(1, 6):
             time.sleep(0.2)
diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
index 5590fdcd..4567ecba 100644
--- a/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
@@ -3,4 +3,4 @@
 Not to be confused with Jupyter AI chat handlers.
 """
 
-from .metadata import MetadataCallbackHandler
\ No newline at end of file
+from .metadata import MetadataCallbackHandler
diff --git a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
index 249739aa..fa4407af 100644
--- a/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
+++ b/packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
@@ -1,6 +1,7 @@
 from langchain_core.callbacks import BaseCallbackHandler
 from langchain_core.outputs import LLMResult
 
+
 class MetadataCallbackHandler(BaseCallbackHandler):
     """
     When passed as a callback handler, this stores the LLMResult's
diff --git a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
index 0d4a479b..dd67a082 100644
--- a/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
+++ b/packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
@@ -82,16 +82,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
 
         return stream_id
 
-    def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False, metadata: Dict[str, Any] = {}):
+    def _send_stream_chunk(
+        self,
+        stream_id: str,
+        content: str,
+        complete: bool = False,
+        metadata: Dict[str, Any] = {},
+    ):
         """
         Sends an `agent-stream-chunk` message containing content that should be
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id,
-            content=content,
-            stream_complete=complete,
-            metadata=metadata
+            id=stream_id, content=content, stream_complete=complete, metadata=metadata
         )
 
         for handler in self._root_chat_handlers.values():
@@ -116,7 +119,7 @@ async def process_message(self, message: HumanChatMessage):
             {"input": message.body},
             config={
                 "configurable": {"last_human_msg": message},
-                "callbacks": [metadata_handler]
+                "callbacks": [metadata_handler],
             },
         ):
             if not received_first_chunk:
@@ -135,4 +138,6 @@ async def process_message(self, message: HumanChatMessage):
                 break
 
         # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True, metadata=metadata_handler.jai_metadata)
+        self._send_stream_chunk(
+            stream_id, "", complete=True, metadata=metadata_handler.jai_metadata
+        )
diff --git a/packages/jupyter-ai/jupyter_ai/models.py b/packages/jupyter-ai/jupyter_ai/models.py
index 593ada1c..b0a06771 100644
--- a/packages/jupyter-ai/jupyter_ai/models.py
+++ b/packages/jupyter-ai/jupyter_ai/models.py
@@ -86,7 +86,7 @@ class BaseAgentMessage(BaseModel):
     The persona of the selected provider. If the selected provider is
     `None`, this defaults to a description of `JupyternautPersona`.
     """
-    
+
     metadata: Dict[str, Any] = {}
     """
     Message metadata set by a provider after fully processing an input. The
@@ -122,7 +122,6 @@ class AgentStreamChunkMessage(BaseModel):
     """
 
 
-
 class HumanChatMessage(BaseModel):
     type: Literal["human"] = "human"
     id: str
diff --git a/packages/jupyter-ai/src/components/chat-messages.tsx b/packages/jupyter-ai/src/components/chat-messages.tsx
index aaae93a1..961b884b 100644
--- a/packages/jupyter-ai/src/components/chat-messages.tsx
+++ b/packages/jupyter-ai/src/components/chat-messages.tsx
@@ -74,6 +74,10 @@ function sortMessages(
 export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
   const collaborators = useCollaboratorsContext();
 
+  if (props.message.type === 'agent-stream' && props.message.complete) {
+    console.log(props.message.metadata);
+  }
+
   const sharedStyles: SxProps = {
     height: '24px',
     width: '24px'