Add metadata field to agent messages #1013

Merged · 5 commits · Sep 25, 2024
9 changes: 5 additions & 4 deletions packages/jupyter-ai-test/jupyter_ai_test/test_llms.py
```diff
@@ -48,10 +48,11 @@ def _stream(
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> Iterator[GenerationChunk]:
-        time.sleep(5)
+        time.sleep(1)
         yield GenerationChunk(
-            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 100.\n\n"
+            text="Hello! This is a dummy response from a test LLM. I will now count from 1 to 5.\n\n",
+            generation_info={"test_metadata_field": "foobar"},
        )
-        for i in range(1, 101):
-            time.sleep(0.5)
+        for i in range(1, 6):
+            time.sleep(0.2)
             yield GenerationChunk(text=f"{i}, ")
```
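For context, a minimal sketch of how a caller could observe the `generation_info` attached to the first chunk above. The class name is assumed here; calling the private `_stream()` directly is only for illustration:

```python
from jupyter_ai_test.test_llms import TestLLMWithStreaming  # class name assumed

llm = TestLLMWithStreaming()
metadata = {}
for chunk in llm._stream("count for me"):
    # only the first chunk carries `generation_info`; later chunks leave it None
    if chunk.generation_info:
        metadata.update(chunk.generation_info)
    print(chunk.text, end="")

print(metadata)  # expected: {'test_metadata_field': 'foobar'}
```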
6 changes: 6 additions & 0 deletions packages/jupyter-ai/jupyter_ai/callback_handlers/__init__.py
```diff
@@ -0,0 +1,6 @@
+"""
+Provides classes which extend `langchain_core.callbacks:BaseCallbackHandler`.
+Not to be confused with Jupyter AI chat handlers.
+"""
+
+from .metadata import MetadataCallbackHandler
```
26 changes: 26 additions & 0 deletions packages/jupyter-ai/jupyter_ai/callback_handlers/metadata.py
```diff
@@ -0,0 +1,26 @@
+from langchain_core.callbacks import BaseCallbackHandler
+from langchain_core.outputs import LLMResult
+
+
+class MetadataCallbackHandler(BaseCallbackHandler):
+    """
+    When passed as a callback handler, this stores the LLMResult's
+    `generation_info` dictionary in the `self.jai_metadata` instance attribute
+    after the provider fully processes an input.
+
+    If used in a streaming chat handler: the `metadata` field of the final
+    `AgentStreamChunkMessage` should be set to `self.jai_metadata`.
+
+    If used in a non-streaming chat handler: the `metadata` field of the
+    returned `AgentChatMessage` should be set to `self.jai_metadata`.
+    """
+
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.jai_metadata = {}
+
+    def on_llm_end(self, response: LLMResult, **kwargs) -> None:
+        if not (len(response.generations) and len(response.generations[0])):
+            return
+
+        self.jai_metadata = response.generations[0][0].generation_info or {}
```
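A minimal usage sketch of the new handler, constructing an `LLMResult` by hand to show how `on_llm_end` populates `jai_metadata` (the metadata values are made up):

```python
from jupyter_ai.callback_handlers import MetadataCallbackHandler
from langchain_core.outputs import Generation, LLMResult

handler = MetadataCallbackHandler()

# Simulate what LangChain passes to callbacks when a provider finishes a run.
response = LLMResult(
    generations=[[Generation(text="42", generation_info={"finish_reason": "stop"})]]
)
handler.on_llm_end(response)

print(handler.jai_metadata)  # -> {'finish_reason': 'stop'}
```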
25 changes: 19 additions & 6 deletions packages/jupyter-ai/jupyter_ai/chat_handlers/default.py
```diff
@@ -1,8 +1,9 @@
 import asyncio
 import time
-from typing import Dict, Type
+from typing import Any, Dict, Type
 from uuid import uuid4
 
+from jupyter_ai.callback_handlers import MetadataCallbackHandler
 from jupyter_ai.models import (
     AgentStreamChunkMessage,
     AgentStreamMessage,
@@ -85,13 +86,19 @@ def _start_stream(self, human_msg: HumanChatMessage) -> str:
 
         return stream_id
 
-    def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
+    def _send_stream_chunk(
+        self,
+        stream_id: str,
+        content: str,
+        complete: bool = False,
+        metadata: Dict[str, Any] = {},
+    ):
         """
         Sends an `agent-stream-chunk` message containing content that should be
         appended to an existing `agent-stream` message with ID `stream_id`.
         """
         stream_chunk_msg = AgentStreamChunkMessage(
-            id=stream_id, content=content, stream_complete=complete
+            id=stream_id, content=content, stream_complete=complete, metadata=metadata
         )
 
         for handler in self._root_chat_handlers.values():
@@ -104,6 +111,7 @@ def _send_stream_chunk(self, stream_id: str, content: str, complete: bool = False):
     async def process_message(self, message: HumanChatMessage):
         self.get_llm_chain()
         received_first_chunk = False
+        assert self.llm_chain
 
         inputs = {"input": message.body}
         if "context" in self.prompt_template.input_variables:
@@ -121,10 +129,13 @@ async def process_message(self, message: HumanChatMessage):
             # stream response in chunks. this works even if a provider does not
             # implement streaming, as `astream()` defaults to yielding `_call()`
             # when `_stream()` is not implemented on the LLM class.
-            assert self.llm_chain
+            metadata_handler = MetadataCallbackHandler()
             async for chunk in self.llm_chain.astream(
                 inputs,
-                config={"configurable": {"last_human_msg": message}},
+                config={
+                    "configurable": {"last_human_msg": message},
+                    "callbacks": [metadata_handler],
+                },
             ):
                 if not received_first_chunk:
                     # when receiving the first chunk, close the pending message and
@@ -142,7 +153,9 @@ async def process_message(self, message: HumanChatMessage):
                     break
 
         # complete stream after all chunks have been streamed
-        self._send_stream_chunk(stream_id, "", complete=True)
+        self._send_stream_chunk(
+            stream_id, "", complete=True, metadata=metadata_handler.jai_metadata
+        )
 
     async def make_context_prompt(self, human_msg: HumanChatMessage) -> str:
         return "\n\n".join(
```
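The same pattern can be reproduced outside `DefaultChatHandler` with any LangChain runnable. A hedged sketch (the chain and inputs are placeholders, not code from this PR):

```python
from jupyter_ai.callback_handlers import MetadataCallbackHandler

async def stream_with_metadata(chain, inputs: dict) -> dict:
    """Stream `chain` and return provider metadata captured after the run."""
    handler = MetadataCallbackHandler()
    async for chunk in chain.astream(inputs, config={"callbacks": [handler]}):
        ...  # forward each chunk to the client here
    # `on_llm_end` has fired by the time the stream is exhausted
    return handler.jai_metadata
```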
18 changes: 17 additions & 1 deletion packages/jupyter-ai/jupyter_ai/models.py
```diff
@@ -87,6 +87,14 @@ class BaseAgentMessage(BaseModel):
     this defaults to a description of `JupyternautPersona`.
     """
 
+    metadata: Dict[str, Any] = {}
+    """
+    Message metadata set by a provider after fully processing an input. The
+    contents of this dictionary are provider-dependent, and can be any
+    dictionary with string keys. This field is not to be displayed directly to
+    the user, and is intended solely for developer purposes.
+    """
+
 
 class AgentChatMessage(BaseAgentMessage):
     type: Literal["agent"] = "agent"
@@ -101,9 +109,17 @@ class AgentStreamMessage(BaseAgentMessage):
 class AgentStreamChunkMessage(BaseModel):
     type: Literal["agent-stream-chunk"] = "agent-stream-chunk"
     id: str
+    """ID of the parent `AgentStreamMessage`."""
     content: str
+    """The string to append to the `AgentStreamMessage` referenced by `id`."""
     stream_complete: bool
-    """Indicates whether this chunk message completes the referenced stream."""
+    """Indicates whether this chunk completes the stream referenced by `id`."""
+    metadata: Dict[str, Any] = {}
+    """
+    The metadata of the stream referenced by `id`. Metadata from the latest
+    chunk should override any metadata from previous chunks. See the docstring
+    on `BaseAgentMessage.metadata` for information.
+    """
 
 
 class HumanChatMessage(BaseModel):
```
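To illustrate the new model fields, a small sketch constructing the final chunk of a stream the way `_send_stream_chunk` does (field values are made up):

```python
from jupyter_ai.models import AgentStreamChunkMessage

final_chunk = AgentStreamChunkMessage(
    id="a1b2c3",  # ID of the parent `AgentStreamMessage` (made up)
    content="",
    stream_complete=True,
    metadata={"test_metadata_field": "foobar"},
)

# `metadata` defaults to an empty dict on intermediate chunks:
mid_chunk = AgentStreamChunkMessage(id="a1b2c3", content="5, ", stream_complete=False)
assert mid_chunk.metadata == {}
```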
1 change: 1 addition & 0 deletions packages/jupyter-ai/src/chat_handler.ts
```diff
@@ -170,6 +170,7 @@ export class ChatHandler implements IDisposable {
       }
 
       streamMessage.body += newMessage.content;
+      streamMessage.metadata = newMessage.metadata;
       if (newMessage.stream_complete) {
         streamMessage.complete = true;
       }
```
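For illustration, the frontend merge above behaves like this Python sketch: each chunk appends its content, and the latest chunk's metadata replaces whatever was stored before (a hypothetical mirror, not code from this PR):

```python
def apply_chunk(stream_message: dict, chunk: dict) -> None:
    """Apply an `agent-stream-chunk` payload to its parent stream message."""
    stream_message["body"] += chunk["content"]
    # the latest chunk's metadata overrides any previously stored metadata
    stream_message["metadata"] = chunk["metadata"]
    if chunk["stream_complete"]:
        stream_message["complete"] = True
```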
4 changes: 4 additions & 0 deletions packages/jupyter-ai/src/components/chat-messages.tsx
```diff
@@ -74,6 +74,10 @@ function sortMessages(
 export function ChatMessageHeader(props: ChatMessageHeaderProps): JSX.Element {
   const collaborators = useCollaboratorsContext();
 
+  if (props.message.type === 'agent-stream' && props.message.complete) {
+    console.log(props.message.metadata);
+  }
+
   const sharedStyles: SxProps<Theme> = {
     height: '24px',
     width: '24px'
```
3 changes: 2 additions & 1 deletion packages/jupyter-ai/src/components/pending-messages.tsx
```diff
@@ -60,7 +60,8 @@ export function PendingMessages(
       time: lastMessage.time,
       body: '',
       reply_to: '',
-      persona: lastMessage.persona
+      persona: lastMessage.persona,
+      metadata: {}
     });
 
     // timestamp format copied from ChatMessage
```
2 changes: 2 additions & 0 deletions packages/jupyter-ai/src/handler.ts
```diff
@@ -114,6 +114,7 @@ export namespace AiService {
     body: string;
     reply_to: string;
     persona: Persona;
+    metadata: Record<string, any>;
   };
 
   export type HumanChatMessage = {
@@ -172,6 +173,7 @@ export namespace AiService {
     id: string;
     content: string;
     stream_complete: boolean;
+    metadata: Record<string, any>;
   };
 
   export type Request = ChatRequest | ClearRequest;
```
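These type additions mirror the server-side models. An example of a serialized `agent-stream-chunk` payload that now satisfies the updated type, written as a Python dict (values are illustrative):

```python
# Illustrative websocket payload matching the updated TypeScript type.
payload = {
    "type": "agent-stream-chunk",
    "id": "a1b2c3",
    "content": "",
    "stream_complete": True,
    "metadata": {"test_metadata_field": "foobar"},
}
```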