From 4b9570aed447b4897812da7c6730f2836c3e0bf9 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:28:46 -0400 Subject: [PATCH 1/5] x --- libs/core/langchain_core/utils/pydantic.py | 7 +++++-- libs/core/tests/unit_tests/test_tools.py | 15 +++++++++++++++ 2 files changed, 20 insertions(+), 2 deletions(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index 1352bcfafffb2..ea92a167c2b25 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -266,7 +266,7 @@ def _create_subset_model_v2( fn_description: Optional[str] = None, ) -> type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" - from pydantic import create_model + from pydantic import create_model, ConfigDict from pydantic.fields import FieldInfo descriptions_ = descriptions or {} @@ -278,7 +278,10 @@ def _create_subset_model_v2( if field.metadata: field_info.metadata = field.metadata fields[field_name] = (field.annotation, field_info) - rtn = create_model(name, **fields) # type: ignore + + rtn = create_model( + name, **fields, model_config=ConfigDict(arbitrary_types_allowed=True) + ) # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. # This is done to preserve __annotations__ when working with pydantic 2.x diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index a61cead53c23f..06eae1af4a366 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2090,3 +2090,18 @@ class FooSchema(BaseModel): with pytest.raises(NotImplementedError): assert tool.invoke("hello") == "hello" + + +def test_injected_arg_with_complex_type() -> None: + """Test that an injected tool arg can be a complex type.""" + + class Foo: + def __init__(self) -> None: + self.value = "bar" + + @tool + def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: + """Tool that has an injected tool arg.""" + return foo.value + + assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" From 6049ea54c783ec4cafa915687c69b94469252cc3 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:29:05 -0400 Subject: [PATCH 2/5] x --- libs/core/langchain_core/utils/pydantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index ea92a167c2b25..f2a79d32bfb26 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -266,7 +266,7 @@ def _create_subset_model_v2( fn_description: Optional[str] = None, ) -> type[pydantic.BaseModel]: """Create a pydantic model with a subset of the model fields.""" - from pydantic import create_model, ConfigDict + from pydantic import ConfigDict, create_model from pydantic.fields import FieldInfo descriptions_ = descriptions or {} From 4a743620e65ddfb25de0193781bf2c2198d48811 Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:33:20 -0400 Subject: [PATCH 3/5] add type ignore --- libs/core/langchain_core/utils/pydantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index f2a79d32bfb26..e91a1a65c6c41 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -279,7 +279,7 @@ def _create_subset_model_v2( field_info.metadata = field.metadata fields[field_name] = (field.annotation, field_info) - rtn = create_model( + rtn = create_model( # type: ignore name, **fields, model_config=ConfigDict(arbitrary_types_allowed=True) ) From b0ad37694622bb8a7ca86dbc8ba839e0d6b14e0c Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:33:44 -0400 Subject: [PATCH 4/5] x --- libs/core/langchain_core/utils/pydantic.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/langchain_core/utils/pydantic.py b/libs/core/langchain_core/utils/pydantic.py index e91a1a65c6c41..93375e09f348b 100644 --- a/libs/core/langchain_core/utils/pydantic.py +++ b/libs/core/langchain_core/utils/pydantic.py @@ -280,7 +280,7 @@ def _create_subset_model_v2( fields[field_name] = (field.annotation, field_info) rtn = create_model( # type: ignore - name, **fields, model_config=ConfigDict(arbitrary_types_allowed=True) + name, **fields, __config__=ConfigDict(arbitrary_types_allowed=True) ) # TODO(0.3): Determine if there is a more "pydantic" way to preserve annotations. From 55c77b465c5b2b7f4d81e690dc639423708a067b Mon Sep 17 00:00:00 2001 From: Eugene Yurtsev Date: Wed, 2 Oct 2024 12:39:38 -0400 Subject: [PATCH 5/5] x --- libs/core/tests/unit_tests/test_tools.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/libs/core/tests/unit_tests/test_tools.py b/libs/core/tests/unit_tests/test_tools.py index 06eae1af4a366..3eae40ede1e82 100644 --- a/libs/core/tests/unit_tests/test_tools.py +++ b/libs/core/tests/unit_tests/test_tools.py @@ -2104,4 +2104,4 @@ def injected_tool(x: int, foo: Annotated[Foo, InjectedToolArg]) -> str: """Tool that has an injected tool arg.""" return foo.value - assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" + assert injected_tool.invoke({"x": 5, "foo": Foo()}) == "bar" # type: ignore