Skip to content

Commit

Permalink
langgraph: check if model passed as runnable binding with tools in create_react_agent (#1647)
Browse files Browse the repository at this point in the history
  • Loading branch information
vbarda authored Sep 30, 2024
1 parent e332869 commit 78f3ee9
Show file tree
Hide file tree
Showing 2 changed files with 128 additions and 8 deletions.
56 changes: 49 additions & 7 deletions libs/langgraph/langgraph/prebuilt/chat_agent_executor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,13 @@
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union
from typing import Callable, Literal, Optional, Sequence, Type, TypeVar, Union, cast

from langchain_core.language_models import BaseChatModel
from langchain_core.language_models import BaseChatModel, LanguageModelLike
from langchain_core.messages import AIMessage, BaseMessage, SystemMessage, ToolMessage
from langchain_core.runnables import Runnable, RunnableConfig, RunnableLambda
from langchain_core.runnables import (
Runnable,
RunnableBinding,
RunnableConfig,
RunnableLambda,
)
from langchain_core.tools import BaseTool
from typing_extensions import Annotated, TypedDict

Expand Down Expand Up @@ -115,9 +120,43 @@ def _get_model_preprocessing_runnable(
return _get_state_modifier_runnable(state_modifier)


def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
    """Return True if ``create_react_agent`` still needs to bind ``tools`` to ``model``.

    A model produced by ``model.bind_tools(...)`` is a ``RunnableBinding``
    carrying a ``"tools"`` kwarg. In that case the already-bound tools are
    validated against ``tools`` and no re-binding is needed. Any other model
    (or a binding without tools) must have the tools bound by the caller.

    Raises:
        ValueError: if the model already has tools bound but their count or
            names do not match ``tools``.
    """
    # A bare model (not the result of .bind(...)/.bind_tools(...)) cannot
    # have tools bound yet, so the caller must bind them.
    if not isinstance(model, RunnableBinding):
        return True

    # A binding without a "tools" kwarg (e.g. plain .bind(stop=...)) also
    # still needs the tools bound.
    if "tools" not in model.kwargs:
        return True

    bound_tools = model.kwargs["tools"]
    if len(tools) != len(bound_tools):
        raise ValueError(
            "Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
        )

    tool_names = {tool.name for tool in tools}
    bound_tool_names = set()
    for bound_tool in bound_tools:
        # OpenAI-style tool
        if bound_tool.get("type") == "function":
            bound_tool_name = bound_tool["function"]["name"]
        # Anthropic-style tool
        elif bound_tool.get("name"):
            bound_tool_name = bound_tool["name"]
        else:
            # unknown tool type so we'll ignore it
            continue

        bound_tool_names.add(bound_tool_name)

    if missing_tools := tool_names - bound_tool_names:
        raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")

    # All requested tools are already bound — do not bind again.
    return False


@deprecated_parameter("messages_modifier", "0.1.9", "state_modifier", removal="0.3.0")
def create_react_agent(
model: BaseChatModel,
model: LanguageModelLike,
tools: Union[ToolExecutor, Sequence[BaseTool], ToolNode],
*,
state_schema: Optional[StateSchemaType] = None,
Expand Down Expand Up @@ -412,9 +451,12 @@ class Agent,Tools otherClass
tool_classes = list(tools.tools_by_name.values())
tool_node = tools
else:
tool_classes = tools
tool_node = ToolNode(tool_classes)
model = model.bind_tools(tool_classes)
tool_node = ToolNode(tools)
# get the tool functions wrapped in a tool class from the ToolNode
tool_classes = list(tool_node.tools_by_name.values())

if _should_bind_tools(model, tool_classes):
model = cast(BaseChatModel, model).bind_tools(tool_classes)

# Define the function that determines whether to continue or not
def should_continue(state: AgentState) -> Literal["tools", "__end__"]:
Expand Down
80 changes: 79 additions & 1 deletion libs/langgraph/tests/test_prebuilt.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
Callable,
Dict,
List,
Literal,
Optional,
Sequence,
Type,
Expand Down Expand Up @@ -49,6 +50,7 @@
class FakeToolCallingModel(BaseChatModel):
tool_calls: Optional[list[list[ToolCall]]] = None
index: int = 0
tool_style: Literal["openai", "anthropic"] = "openai"

def _generate(
self,
Expand Down Expand Up @@ -79,7 +81,31 @@ def bind_tools(
tools: Sequence[Union[Dict[str, Any], Type[BaseModel], Callable, BaseTool]],
**kwargs: Any,
) -> Runnable[LanguageModelInput, BaseMessage]:
return self
tool_dicts = []
for tool in tools:
if not isinstance(tool, BaseTool):
raise TypeError(
"Only BaseTool is supported by FakeToolCallingModel.bind_tools"
)

# NOTE: this is a simplified tool spec for testing purposes only
if self.tool_style == "openai":
tool_dicts.append(
{
"type": "function",
"function": {
"name": tool.name,
},
}
)
elif self.tool_style == "anthropic":
tool_dicts.append(
{
"name": tool.name,
}
)

return self.bind(tools=tool_dicts)


@pytest.mark.parametrize("checkpointer_name", ALL_CHECKPOINTERS_SYNC)
Expand Down Expand Up @@ -242,6 +268,58 @@ def test_runnable_state_modifier():
assert response == expected_response


@pytest.mark.parametrize("tool_style", ["openai", "anthropic"])
def test_model_with_tools(tool_style: str):
    """create_react_agent accepts a model that already has tools bound and
    validates the pre-bound tools against the ``tools`` argument."""
    model = FakeToolCallingModel(tool_style=tool_style)

    @dec_tool
    def tool1(some_val: int) -> str:
        """Tool 1 docstring."""
        return f"Tool 1: {some_val}"

    @dec_tool
    def tool2(some_val: int) -> str:
        """Tool 2 docstring."""
        return f"Tool 2: {some_val}"

    # check valid agent constructor: pre-bound model with matching tools
    agent = create_react_agent(model.bind_tools([tool1, tool2]), [tool1, tool2])
    result = agent.nodes["tools"].invoke(
        {
            "messages": [
                AIMessage(
                    "hi?",
                    tool_calls=[
                        {
                            "name": "tool1",
                            "args": {"some_val": 2},
                            "id": "some 1",
                        },
                        {
                            "name": "tool2",
                            "args": {"some_val": 2},
                            "id": "some 2",
                        },
                    ],
                )
            ]
        }
    )
    # the last two messages are the ToolMessages produced by the tool node,
    # one per tool call above
    tool_messages: list[ToolMessage] = result["messages"][-2:]
    for tool_message in tool_messages:
        assert tool_message.type == "tool"
        assert tool_message.content in {"Tool 1: 2", "Tool 2: 2"}
        assert tool_message.tool_call_id in {"some 1", "some 2"}

    # test mismatching tool lengths
    with pytest.raises(ValueError):
        create_react_agent(model.bind_tools([tool1]), [tool1, tool2])

    # test missing bound tools
    with pytest.raises(ValueError):
        create_react_agent(model.bind_tools([tool1]), [tool2])


async def test_tool_node():
def tool1(some_val: int, some_other_val: str) -> str:
"""Tool 1 docstring."""
Expand Down

2 comments on commit 78f3ee9

@LeanVel
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not working properly when using ChatBedrock

@vbarda
Copy link
Collaborator Author

@vbarda vbarda commented on 78f3ee9 Sep 30, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is not working properly when using ChatBedrock

@LeanVel There was a bug that should be fixed in 0.2.31. If you're still running into issues, feel free to open an issue / discussion!

Please sign in to comment.