def _should_bind_tools(model: LanguageModelLike, tools: Sequence[BaseTool]) -> bool:
    """Return True if ``tools`` still need to be bound to ``model``.

    The caller does ``if _should_bind_tools(...): model.bind_tools(tools)``,
    so this must answer "does the model lack these tools?":

    * a model that is not a ``RunnableBinding`` has nothing bound yet;
    * a binding whose kwargs carry no ``"tools"`` entry has none either;
    * otherwise the tools are already bound — validate that they match
      ``tools`` and report that no (re-)binding is needed.

    Args:
        model: The language model (possibly already ``.bind_tools(...)``-ed).
        tools: The tools passed to ``create_react_agent``.

    Raises:
        ValueError: If tools are already bound but do not match ``tools``
            (different count, or a tool name missing from the binding).
    """
    # Not a binding at all -> tools were never bound; caller should bind.
    if not isinstance(model, RunnableBinding):
        return True

    # A binding without a "tools" kwarg has no tools attached either.
    if "tools" not in model.kwargs:
        return True

    bound_tools = model.kwargs["tools"]
    if len(tools) != len(bound_tools):
        raise ValueError(
            "Number of tools in the model.bind_tools() and tools passed to create_react_agent must match"
        )

    tool_names = {tool.name for tool in tools}
    bound_tool_names = set()
    for bound_tool in bound_tools:
        # OpenAI-style tool
        if bound_tool.get("type") == "function":
            bound_tool_name = bound_tool["function"]["name"]
        # Anthropic-style tool
        elif bound_tool.get("name"):
            bound_tool_name = bound_tool["name"]
        else:
            # unknown tool type so we'll ignore it
            continue

        bound_tool_names.add(bound_tool_name)

    if missing_tools := tool_names - bound_tool_names:
        raise ValueError(f"Missing tools '{missing_tools}' in the model.bind_tools()")

    # Tools are already bound and consistent -- no need to bind again.
    return False
@pytest.mark.parametrize("tool_style", ["openai", "anthropic"])
def test_model_with_tools(tool_style: str):
    """create_react_agent accepts a model with tools already bound via
    ``bind_tools`` and validates that the bound tools match the ``tools``
    argument (both OpenAI- and Anthropic-style tool specs)."""
    model = FakeToolCallingModel(tool_style=tool_style)

    @dec_tool
    def tool1(some_val: int) -> str:
        """Tool 1 docstring."""
        return f"Tool 1: {some_val}"

    @dec_tool
    def tool2(some_val: int) -> str:
        """Tool 2 docstring."""
        return f"Tool 2: {some_val}"

    # check valid agent constructor: bound tools match the passed tools
    agent = create_react_agent(model.bind_tools([tool1, tool2]), [tool1, tool2])
    result = agent.nodes["tools"].invoke(
        {
            "messages": [
                AIMessage(
                    "hi?",
                    tool_calls=[
                        {
                            "name": "tool1",
                            "args": {"some_val": 2},
                            "id": "some 1",
                        },
                        {
                            "name": "tool2",
                            "args": {"some_val": 2},
                            "id": "some 2",
                        },
                    ],
                )
            ]
        }
    )
    # the slice yields a list of ToolMessages, not a single ToolMessage
    tool_messages: List[ToolMessage] = result["messages"][-2:]
    for tool_message in tool_messages:
        assert tool_message.type == "tool"
        assert tool_message.content in {"Tool 1: 2", "Tool 2: 2"}
        assert tool_message.tool_call_id in {"some 1", "some 2"}

    # test mismatching tool lengths
    with pytest.raises(ValueError):
        create_react_agent(model.bind_tools([tool1]), [tool1, tool2])

    # test missing bound tools
    with pytest.raises(ValueError):
        create_react_agent(model.bind_tools([tool1]), [tool2])