From 9d83ce687ba2ab1ab9f887de39725591a3d6ea90 Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 11:27:04 +0700 Subject: [PATCH 1/9] Add validated data args in default factory --- .../type_serializers/with_default.rs | 2 +- src/validators/with_default.rs | 18 +++++++-- tests/validators/test_model_fields.py | 2 +- tests/validators/test_typed_dict.py | 2 +- tests/validators/test_with_default.py | 40 +++++++++++++++++-- 5 files changed, 53 insertions(+), 11 deletions(-) diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index 6d510ce55..d153f24f0 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -72,6 +72,6 @@ impl TypeSerializer for WithDefaultSerializer { } fn get_default(&self, py: Python) -> PyResult> { - self.default.default_value(py) + self.default.default_value(py, &None) } } diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index e03830e77..da3551a18 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -1,8 +1,8 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; -use pyo3::types::PyDict; use pyo3::types::PyString; +use pyo3::types::{PyDict, PyTuple}; use pyo3::PyTraverseError; use pyo3::PyVisit; @@ -42,10 +42,20 @@ impl DefaultType { } } - pub fn default_value(&self, py: Python) -> PyResult> { + pub fn default_value(&self, py: Python, validated_data: &Option>) -> PyResult> { match self { Self::Default(ref default) => Ok(Some(default.clone_ref(py))), - Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)), + Self::DefaultFactory(ref default_factory) => { + let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?; + let default_factory_args: &Bound = co_vars.downcast_bound::(py)?; + + let result = if default_factory_args.len() >= 1 && !validated_data.is_none() { + default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + } else { + default_factory.call0(py) + }?; + Ok(Some(result)) + } Self::None => Ok(None), } } @@ -163,7 +173,7 @@ impl Validator for WithDefaultValidator { outer_loc: Option>, state: &mut ValidationState<'_, 'py>, ) -> ValResult> { - match self.default.default_value(py)? { + match self.default.default_value(py, &state.extra().data)? { Some(stored_dft) => { let dft: Py = if self.copy_default { let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap()); diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 8a22d96c3..ebae5932b 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1471,7 +1471,7 @@ def test_with_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, "() missing 1 required positional argument: 'x'"), + (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index dc18cd86e..53541ed2d 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -879,7 +879,7 @@ def test_field_required_and_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, "() missing 1 required positional argument: 'x'"), + (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 189a555bb..cc6f28f42 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -240,7 +240,7 @@ def broken(x): v.validate_python('wrong') -def test_typed_dict_error(): +def test_typed_dict(): v = SchemaValidator( { 'type': 'typed-dict', @@ -248,14 +248,17 @@ def test_typed_dict_error(): 'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, 'y': { 'type': 'typed-dict-field', - 'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default_factory': lambda y: y * 2}, + 'schema': { + 'type': 'default', + 'schema': {'type': 'str'}, + 'default_factory': lambda v_data: v_data['x'] + ' and y', + }, }, }, } ) assert v.validate_python({'x': 'x', 'y': 'y'}) == {'x': 'x', 'y': 'y'} - with pytest.raises(TypeError, match=r"\(\) missing 1 required positional argument: 'y'"): - v.validate_python({'x': 'x'}) + assert v.validate_python({'x': 'x value'}) == {'x': 'x value', 'y': 'x value and y'} def test_on_error_default_not_int(): @@ -812,3 +815,32 @@ def _raise(ex: Exception) -> None: v.validate_python(input_value) assert exc_info.value.errors(include_url=False, include_context=False) == expected + + +def test_typed_dict_default_factory(): + v = SchemaValidator( + { + 'type': 'typed-dict', + 'fields': { + 'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, + 'y': { + 'type': 'typed-dict-field', + 'schema': { + 'type': 'default', + 'schema': {'type': 'str'}, + 'default_factory': lambda v_data: v_data['x'] + ' and y', + }, + }, + 'z': { + 'type': 'typed-dict-field', + 'schema': { + 'type': 'default', + 'schema': {'type': 'str'}, + 'default_factory': lambda v_data: v_data['y'] + ' and z', + }, + }, + }, + } + ) + assert v.validate_python({'x': 'x', 'y': 'y', 'z': 'z'}) == {'x': 'x', 'y': 'y', 'z': 'z'} + assert v.validate_python({'x': 'x'}) == {'x': 'x', 'y': 'x and y', 'z': 'x and y and z'} From 72fe9fbe17d25594b438595ab5d58b8b130421cf Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 12:15:27 +0700 Subject: [PATCH 2/9] fix error message depend on python version --- tests/validators/test_model_fields.py | 7 ++++++- tests/validators/test_typed_dict.py | 8 +++++++- 2 files changed, 13 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index ebae5932b..c8f022795 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1471,7 +1471,12 @@ def test_with_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), + ( + lambda x: 'a' + x, + "unsupported operand type(s) for +: 'str' and 'dict'" + if sys.version_info <= (3, 10) + else 'can only concatenate str (not "dict") to str', + ), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 53541ed2d..1789c5d6a 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -2,6 +2,7 @@ import math import platform import re +import sys import weakref from typing import Any, Dict, Mapping, Union @@ -879,7 +880,12 @@ def test_field_required_and_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), + ( + lambda x: 'a' + x, + "unsupported operand type(s) for +: 'str' and 'dict'" + if sys.version_info <= (3, 10) + else 'can only concatenate str (not "dict") to str', + ), ], ) def test_bad_default_factory(default_factory, error_message): From d4bf43ce1a7c6445f3ca6844287f3517538fa75b Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 12:32:02 +0700 Subject: [PATCH 3/9] fix error message depend on python version --- tests/validators/test_model_fields.py | 7 +------ tests/validators/test_typed_dict.py | 8 +------- 2 files changed, 2 insertions(+), 13 deletions(-) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index c8f022795..412a5fbdc 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1471,12 +1471,7 @@ def test_with_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - ( - lambda x: 'a' + x, - "unsupported operand type(s) for +: 'str' and 'dict'" - if sys.version_info <= (3, 10) - else 'can only concatenate str (not "dict") to str', - ), + (lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'"), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 1789c5d6a..7437064b8 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -2,7 +2,6 @@ import math import platform import re -import sys import weakref from typing import Any, Dict, Mapping, Union @@ -880,12 +879,7 @@ def test_field_required_and_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - ( - lambda x: 'a' + x, - "unsupported operand type(s) for +: 'str' and 'dict'" - if sys.version_info <= (3, 10) - else 'can only concatenate str (not "dict") to str', - ), + (lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'"), ], ) def test_bad_default_factory(default_factory, error_message): From 3245cc7a2635fb8b4a2273bd8e39a8a40eb9e4dc Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 14:56:23 +0700 Subject: [PATCH 4/9] fix error message depend on python version --- src/validators/with_default.rs | 17 ++++++++++++----- tests/validators/test_model_fields.py | 2 +- tests/validators/test_typed_dict.py | 2 +- 3 files changed, 14 insertions(+), 7 deletions(-) diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index da3551a18..7aeb96842 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -1,8 +1,8 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; -use pyo3::types::PyString; use pyo3::types::{PyDict, PyTuple}; +use pyo3::types::{PyFunction, PyString}; use pyo3::PyTraverseError; use pyo3::PyVisit; @@ -46,14 +46,21 @@ impl DefaultType { match self { Self::Default(ref default) => Ok(Some(default.clone_ref(py))), Self::DefaultFactory(ref default_factory) => { - let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?; - let default_factory_args: &Bound = co_vars.downcast_bound::(py)?; + let is_func = default_factory.downcast_bound::(py).is_ok(); - let result = if default_factory_args.len() >= 1 && !validated_data.is_none() { - default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + let result = if is_func { + let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?; + let default_factory_args: &Bound = co_vars.downcast_bound::(py)?; + + if default_factory_args.len() >= 1 && !validated_data.is_none() { + default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + } else { + default_factory.call0(py) + } } else { default_factory.call0(py) }?; + Ok(Some(result)) } Self::None => Ok(None), diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 412a5fbdc..ebae5932b 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1471,7 +1471,7 @@ def test_with_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'"), + (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 7437064b8..53541ed2d 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -879,7 +879,7 @@ def test_field_required_and_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'"), + (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), ], ) def test_bad_default_factory(default_factory, error_message): From 8be009674cd4462a29b431fc17fc4c7acffc66f2 Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 15:37:06 +0700 Subject: [PATCH 5/9] fix error message depend on python version --- src/validators/with_default.rs | 14 ++++++++++---- tests/validators/test_with_default.py | 3 +-- 2 files changed, 11 insertions(+), 6 deletions(-) diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index 7aeb96842..ba9712afa 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -1,8 +1,8 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; +use pyo3::types::PyString; use pyo3::types::{PyDict, PyTuple}; -use pyo3::types::{PyFunction, PyString}; use pyo3::PyTraverseError; use pyo3::PyVisit; @@ -46,14 +46,20 @@ impl DefaultType { match self { Self::Default(ref default) => Ok(Some(default.clone_ref(py))), Self::DefaultFactory(ref default_factory) => { - let is_func = default_factory.downcast_bound::(py).is_ok(); + // MASSIVE HACK! PyFunction doesn't exist for PyPy, + // ref from https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635 + let is_func = default_factory.getattr(py, "__class__")?.to_string() == ""; let result = if is_func { let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?; let default_factory_args: &Bound = co_vars.downcast_bound::(py)?; - if default_factory_args.len() >= 1 && !validated_data.is_none() { - default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + if default_factory_args.len() >= 1 { + if validated_data.is_none() { + default_factory.call1(py, ({},)) + } else { + default_factory.call1(py, (validated_data.as_deref().unwrap(),)) + } } else { default_factory.call0(py) } diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index cc6f28f42..8c4fba23e 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -236,8 +236,7 @@ def broken(x): ) assert v.validate_python(42) == 42 assert v.validate_python('42') == 42 - with pytest.raises(TypeError, match=r"broken\(\) missing 1 required positional argument: 'x'"): - v.validate_python('wrong') + assert v.validate_python('wrong') == 7 def test_typed_dict(): From 0e0dbc363b08920bf3a893569f554e6c12f2844f Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 15:52:19 +0700 Subject: [PATCH 6/9] Fix test fail in pypy and python<3.10 --- tests/validators/test_model_fields.py | 8 +++++++- tests/validators/test_typed_dict.py | 8 +++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index ebae5932b..8888b9ee2 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1,4 +1,5 @@ import math +import platform import re import sys from dataclasses import dataclass @@ -1471,7 +1472,12 @@ def test_with_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), + ( + lambda x: 'a' + x, + "unsupported operand type(s) for +: 'str' and 'dict'" + if sys.version_info <= (3, 10) and (platform.python_implementation() == 'PyPy') + else 'can only concatenate str (not "dict") to str', + ), ], ) def test_bad_default_factory(default_factory, error_message): diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 53541ed2d..af164dc4d 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -2,6 +2,7 @@ import math import platform import re +import sys import weakref from typing import Any, Dict, Mapping, Union @@ -879,7 +880,12 @@ def test_field_required_and_default_factory(): 'default_factory,error_message', [ (lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"), - (lambda x: 'a' + x, 'can only concatenate str (not "dict") to str'), + ( + lambda x: 'a' + x, + "unsupported operand type(s) for +: 'str' and 'dict'" + if sys.version_info <= (3, 10) and (platform.python_implementation() == 'PyPy') + else 'can only concatenate str (not "dict") to str', + ), ], ) def test_bad_default_factory(default_factory, error_message): From 5e5c249484c145e42984d3d9a636ab3aba3433c8 Mon Sep 17 00:00:00 2001 From: nix Date: Mon, 7 Oct 2024 19:19:46 +0700 Subject: [PATCH 7/9] Fix test fail in pypy and python<3.10 --- tests/validators/test_model_fields.py | 3 ++- tests/validators/test_typed_dict.py | 3 ++- 2 files changed, 4 insertions(+), 2 deletions(-) diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index 8888b9ee2..a6f400199 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1475,7 +1475,8 @@ def test_with_default_factory(): ( lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'" - if sys.version_info <= (3, 10) and (platform.python_implementation() == 'PyPy') + # fix pypy3.10 CI failed + if sys.version_info < (3, 11) and (platform.python_implementation() == 'PyPy') else 'can only concatenate str (not "dict") to str', ), ], diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index af164dc4d..0544ee41b 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -883,7 +883,8 @@ def test_field_required_and_default_factory(): ( lambda x: 'a' + x, "unsupported operand type(s) for +: 'str' and 'dict'" - if sys.version_info <= (3, 10) and (platform.python_implementation() == 'PyPy') + # fix pypy3.10 CI failed + if sys.version_info < (3, 11) and (platform.python_implementation() == 'PyPy') else 'can only concatenate str (not "dict") to str', ), ], From 08fcdacc16a0ab7fca4d3afe41b5b233e41f827f Mon Sep 17 00:00:00 2001 From: nix Date: Sat, 12 Oct 2024 18:19:59 +0700 Subject: [PATCH 8/9] rework using bool flag for passing arg --- python/pydantic_core/core_schema.py | 10 ++- .../type_serializers/with_default.rs | 2 +- src/validators/with_default.rs | 38 +++++----- tests/test_schema_functions.py | 20 +++++- tests/validators/test_model_fields.py | 16 ++--- tests/validators/test_typed_dict.py | 16 ++--- tests/validators/test_with_default.py | 69 +++++++++---------- 7 files changed, 91 insertions(+), 80 deletions(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index b1c4c13b5..e78436592 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -10,6 +10,7 @@ from collections.abc import Mapping from datetime import date, datetime, time, timedelta from decimal import Decimal +from inspect import isfunction, signature from typing import TYPE_CHECKING, Any, Callable, Dict, Hashable, List, Pattern, Set, Tuple, Type, Union from typing_extensions import deprecated @@ -94,6 +95,8 @@ class CoreConfig(TypedDict, total=False): revalidate_instances: Literal['always', 'never', 'subclass-instances'] # whether to validate default values during validation, default False validate_default: bool + # + default_factory_has_args: bool # used on typed-dicts and arguments populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1 # fields related to string fields only @@ -2375,6 +2378,7 @@ class WithDefaultSchema(TypedDict, total=False): default_factory: Callable[[], Any] on_error: Literal['raise', 'omit', 'default'] # default: 'raise' validate_default: bool # default: False + default_factory_has_args: bool strict: bool ref: str metadata: Dict[str, Any] @@ -2385,7 +2389,7 @@ def with_default_schema( schema: CoreSchema, *, default: Any = PydanticUndefined, - default_factory: Callable[[], Any] | None = None, + default_factory: Callable[[Dict[str, Any]], Any] | Callable[[], Any] | None = None, on_error: Literal['raise', 'omit', 'default'] | None = None, validate_default: bool | None = None, strict: bool | None = None, @@ -2418,10 +2422,14 @@ def with_default_schema( metadata: Any other information you want to include with the schema, not used by pydantic-core serialization: Custom serialization schema """ + + has_arg = isfunction(default_factory) and len(signature(default_factory).parameters) > 0 + s = _dict_not_none( type='default', schema=schema, default_factory=default_factory, + default_factory_has_args=has_arg, on_error=on_error, validate_default=validate_default, strict=strict, diff --git a/src/serializers/type_serializers/with_default.rs b/src/serializers/type_serializers/with_default.rs index d153f24f0..c97837762 100644 --- a/src/serializers/type_serializers/with_default.rs +++ b/src/serializers/type_serializers/with_default.rs @@ -72,6 +72,6 @@ impl TypeSerializer for WithDefaultSerializer { } fn get_default(&self, py: Python) -> PyResult> { - self.default.default_value(py, &None) + self.default.default_value(py, &None, false) } } diff --git a/src/validators/with_default.rs b/src/validators/with_default.rs index ba9712afa..f466ce6d6 100644 --- a/src/validators/with_default.rs +++ b/src/validators/with_default.rs @@ -1,8 +1,8 @@ use pyo3::intern; use pyo3::prelude::*; use pyo3::sync::GILOnceCell; +use pyo3::types::PyDict; use pyo3::types::PyString; -use pyo3::types::{PyDict, PyTuple}; use pyo3::PyTraverseError; use pyo3::PyVisit; @@ -42,31 +42,24 @@ impl DefaultType { } } - pub fn default_value(&self, py: Python, validated_data: &Option>) -> PyResult> { + pub fn default_value( + &self, + py: Python, + validated_data: &Option>, + pass_arg: bool, + ) -> PyResult> { match self { Self::Default(ref default) => Ok(Some(default.clone_ref(py))), Self::DefaultFactory(ref default_factory) => { - // MASSIVE HACK! PyFunction doesn't exist for PyPy, - // ref from https://github.com/pydantic/pydantic-core/pull/161#discussion_r917257635 - let is_func = default_factory.getattr(py, "__class__")?.to_string() == ""; - - let result = if is_func { - let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?; - let default_factory_args: &Bound = co_vars.downcast_bound::(py)?; - - if default_factory_args.len() >= 1 { - if validated_data.is_none() { - default_factory.call1(py, ({},)) - } else { - default_factory.call1(py, (validated_data.as_deref().unwrap(),)) - } + let result = if pass_arg { + if validated_data.is_none() { + default_factory.call1(py, ({},)) } else { - default_factory.call0(py) + default_factory.call1(py, (validated_data.as_deref().unwrap(),)) } } else { default_factory.call0(py) }?; - Ok(Some(result)) } Self::None => Ok(None), @@ -96,6 +89,7 @@ pub struct WithDefaultValidator { on_error: OnError, validator: Box, validate_default: bool, + default_factory_has_args: bool, copy_default: bool, name: String, undefined: PyObject, @@ -140,12 +134,15 @@ impl BuildValidator for WithDefaultValidator { }; let name = format!("{}[{}]", Self::EXPECTED_TYPE, validator.get_name()); + let default_factory_has_args = + schema_or_config_same(schema, config, intern!(py, "default_factory_has_args"))?.unwrap_or(false); Ok(Self { default, on_error, validator, validate_default: schema_or_config_same(schema, config, intern!(py, "validate_default"))?.unwrap_or(false), + default_factory_has_args, copy_default, name, undefined: PydanticUndefinedType::new(py).to_object(py), @@ -186,7 +183,10 @@ impl Validator for WithDefaultValidator { outer_loc: Option>, state: &mut ValidationState<'_, 'py>, ) -> ValResult> { - match self.default.default_value(py, &state.extra().data)? { + match self + .default + .default_value(py, &state.extra().data, self.default_factory_has_args)? + { Some(stored_dft) => { let dft: Py = if self.copy_default { let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap()); diff --git a/tests/test_schema_functions.py b/tests/test_schema_functions.py index a15adfca5..7817d1d3e 100644 --- a/tests/test_schema_functions.py +++ b/tests/test_schema_functions.py @@ -17,6 +17,10 @@ def make_5(): return 5 +def make_5_with_arg(d): + return 5 + + class MyModel: __slots__ = '__dict__', '__pydantic_fields_set__', '__pydantic_extra__', '__pydantic_private__' @@ -143,17 +147,27 @@ def args(*args, **kwargs): ( core_schema.with_default_schema, args({'type': 'int'}, default=5), - {'type': 'default', 'schema': {'type': 'int'}, 'default': 5}, + {'type': 'default', 'schema': {'type': 'int'}, 'default': 5, 'default_factory_has_args': False}, ), ( core_schema.with_default_schema, args({'type': 'int'}, default=None), - {'type': 'default', 'schema': {'type': 'int'}, 'default': None}, + {'type': 'default', 'schema': {'type': 'int'}, 'default': None, 'default_factory_has_args': False}, ), ( core_schema.with_default_schema, args({'type': 'int'}, default_factory=make_5), - {'type': 'default', 'schema': {'type': 'int'}, 'default_factory': make_5}, + {'type': 'default', 'schema': {'type': 'int'}, 'default_factory': make_5, 'default_factory_has_args': False}, + ), + ( + core_schema.with_default_schema, + args({'type': 'int'}, default_factory=make_5_with_arg), + { + 'type': 'default', + 'schema': {'type': 'int'}, + 'default_factory': make_5_with_arg, + 'default_factory_has_args': True, + }, ), (core_schema.nullable_schema, args({'type': 'int'}), {'type': 'nullable', 'schema': {'type': 'int'}}), ( diff --git a/tests/validators/test_model_fields.py b/tests/validators/test_model_fields.py index a6f400199..3c7518cc9 100644 --- a/tests/validators/test_model_fields.py +++ b/tests/validators/test_model_fields.py @@ -1483,15 +1483,13 @@ def test_with_default_factory(): ) def test_bad_default_factory(default_factory, error_message): v = SchemaValidator( - { - 'type': 'model-fields', - 'fields': { - 'x': { - 'type': 'model-field', - 'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default_factory': default_factory}, - } - }, - } + core_schema.model_fields_schema( + fields={ + 'x': core_schema.model_field( + core_schema.with_default_schema(core_schema.str_schema(), default_factory=default_factory) + ) + } + ), ) with pytest.raises(TypeError, match=re.escape(error_message)): v.validate_python({}) diff --git a/tests/validators/test_typed_dict.py b/tests/validators/test_typed_dict.py index 0544ee41b..5a59b565f 100644 --- a/tests/validators/test_typed_dict.py +++ b/tests/validators/test_typed_dict.py @@ -891,15 +891,13 @@ def test_field_required_and_default_factory(): ) def test_bad_default_factory(default_factory, error_message): v = SchemaValidator( - { - 'type': 'typed-dict', - 'fields': { - 'x': { - 'type': 'typed-dict-field', - 'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default_factory': default_factory}, - } - }, - } + core_schema.typed_dict_schema( + fields={ + 'x': core_schema.typed_dict_field( + core_schema.with_default_schema(core_schema.str_schema(), default_factory=default_factory) + ) + } + ), ) with pytest.raises(TypeError, match=re.escape(error_message)): v.validate_python({}) diff --git a/tests/validators/test_with_default.py b/tests/validators/test_with_default.py index 8c4fba23e..259763a84 100644 --- a/tests/validators/test_with_default.py +++ b/tests/validators/test_with_default.py @@ -228,11 +228,11 @@ def broken(): def test_factory_type_error(): - def broken(x): + def broken(): return 7 v = SchemaValidator( - {'type': 'default', 'schema': {'type': 'int'}, 'on_error': 'default', 'default_factory': broken} + core_schema.with_default_schema(core_schema.int_schema(), default_factory=broken, on_error='default'), ) assert v.validate_python(42) == 42 assert v.validate_python('42') == 42 @@ -241,20 +241,18 @@ def broken(x): def test_typed_dict(): v = SchemaValidator( - { - 'type': 'typed-dict', - 'fields': { - 'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'y': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'str'}, - 'default_factory': lambda v_data: v_data['x'] + ' and y', - }, - }, - }, - } + core_schema.typed_dict_schema( + fields={ + 'x': core_schema.typed_dict_field( + core_schema.str_schema(), + ), + 'y': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.str_schema(), default_factory=lambda v_data: v_data['x'] + ' and y' + ) + ), + } + ) ) assert v.validate_python({'x': 'x', 'y': 'y'}) == {'x': 'x', 'y': 'y'} assert v.validate_python({'x': 'x value'}) == {'x': 'x value', 'y': 'x value and y'} @@ -818,28 +816,23 @@ def _raise(ex: Exception) -> None: def test_typed_dict_default_factory(): v = SchemaValidator( - { - 'type': 'typed-dict', - 'fields': { - 'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}}, - 'y': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'str'}, - 'default_factory': lambda v_data: v_data['x'] + ' and y', - }, - }, - 'z': { - 'type': 'typed-dict-field', - 'schema': { - 'type': 'default', - 'schema': {'type': 'str'}, - 'default_factory': lambda v_data: v_data['y'] + ' and z', - }, - }, - }, - } + core_schema.typed_dict_schema( + fields={ + 'x': core_schema.typed_dict_field( + core_schema.str_schema(), + ), + 'y': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.str_schema(), default_factory=lambda v_data: v_data['x'] + ' and y' + ) + ), + 'z': core_schema.typed_dict_field( + core_schema.with_default_schema( + core_schema.str_schema(), default_factory=lambda v_data: v_data['y'] + ' and z' + ) + ), + } + ) ) assert v.validate_python({'x': 'x', 'y': 'y', 'z': 'z'}) == {'x': 'x', 'y': 'y', 'z': 'z'} assert v.validate_python({'x': 'x'}) == {'x': 'x', 'y': 'x and y', 'z': 'x and y and z'} From 8dd5ec2a71e0d705594aa8dfea812a3176e635bd Mon Sep 17 00:00:00 2001 From: nix Date: Sat, 12 Oct 2024 18:24:07 +0700 Subject: [PATCH 9/9] Add comments --- python/pydantic_core/core_schema.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/python/pydantic_core/core_schema.py b/python/pydantic_core/core_schema.py index e78436592..d78849c3a 100644 --- a/python/pydantic_core/core_schema.py +++ b/python/pydantic_core/core_schema.py @@ -95,7 +95,7 @@ class CoreConfig(TypedDict, total=False): revalidate_instances: Literal['always', 'never', 'subclass-instances'] # whether to validate default values during validation, default False validate_default: bool - # + # whether to pass the validated data to the default_factory, computed base on signature of default_factory. default_factory_has_args: bool # used on typed-dicts and arguments populate_by_name: bool # replaces `allow_population_by_field_name` in pydantic v1