Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add validated data args in default factory #1475

Open
wants to merge 12 commits into
base: main
Choose a base branch
from
2 changes: 1 addition & 1 deletion src/serializers/type_serializers/with_default.rs
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,6 @@ impl TypeSerializer for WithDefaultSerializer {
}

fn get_default(&self, py: Python) -> PyResult<Option<PyObject>> {
self.default.default_value(py)
self.default.default_value(py, &None)
}
}
31 changes: 27 additions & 4 deletions src/validators/with_default.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
use pyo3::intern;
use pyo3::prelude::*;
use pyo3::sync::GILOnceCell;
use pyo3::types::PyDict;
use pyo3::types::PyString;
use pyo3::types::{PyDict, PyTuple};
use pyo3::PyTraverseError;
use pyo3::PyVisit;

Expand Down Expand Up @@ -42,10 +42,33 @@ impl DefaultType {
}
}

pub fn default_value(&self, py: Python) -> PyResult<Option<PyObject>> {
pub fn default_value(&self, py: Python, validated_data: &Option<Bound<PyDict>>) -> PyResult<Option<PyObject>> {
match self {
Self::Default(ref default) => Ok(Some(default.clone_ref(py))),
Self::DefaultFactory(ref default_factory) => Ok(Some(default_factory.call0(py)?)),
Self::DefaultFactory(ref default_factory) => {
// MASSIVE HACK! PyFunction doesn't exist for PyPy,
// ref from https:/pydantic/pydantic-core/pull/161#discussion_r917257635
let is_func = default_factory.getattr(py, "__class__")?.to_string() == "<class 'function'>";
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we add both branches here so that we can hopefully remove the hack in the long run?

See

// the PyFunction::is_type_of(attr) catches `staticmethod`, but also any other function,
// I think that's better than including static methods in the yielded attributes,
// if someone really wants fields, they can use an explicit field, or a function to modify input
#[cfg(not(PyPy))]
if !is_bound && !attr.is_instance_of::<PyFunction>() {
return Some(Ok((name, attr)));
}
// MASSIVE HACK! PyFunction doesn't exist for PyPy,
// is_instance_of::<PyFunction> crashes with a null pointer, hence this hack, see
// https:/pydantic/pydantic-core/pull/161#discussion_r917257635
#[cfg(PyPy)]
if !is_bound && attr.get_type().to_string() != "<class 'function'>" {
return Some(Ok((name, attr)));
}
}
for an example.

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we even need the function check here? I assume if we got to this point we know this is a callable bc we've matched on DefaultFactory...

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes, because there's a case ex: default_factory=str in python then the default_factory here will not have __code__ and raise error.


let result = if is_func {
let co_vars = default_factory.getattr(py, "__code__")?.getattr(py, "co_varnames")?;
let default_factory_args: &Bound<PyTuple> = co_vars.downcast_bound::<PyTuple>(py)?;

if default_factory_args.len() >= 1 {
if validated_data.is_none() {
default_factory.call1(py, ({},))
} else {
default_factory.call1(py, (validated_data.as_deref().unwrap(),))
}
} else {
default_factory.call0(py)
}
} else {
default_factory.call0(py)
}?;

Ok(Some(result))
}
Self::None => Ok(None),
}
}
Expand Down Expand Up @@ -163,7 +186,7 @@ impl Validator for WithDefaultValidator {
outer_loc: Option<impl Into<LocItem>>,
state: &mut ValidationState<'_, 'py>,
) -> ValResult<Option<PyObject>> {
match self.default.default_value(py)? {
match self.default.default_value(py, &state.extra().data)? {
Some(stored_dft) => {
let dft: Py<PyAny> = if self.copy_default {
let deepcopy_func = COPY_DEEPCOPY.get_or_init(py, || get_deepcopy(py).unwrap());
Expand Down
9 changes: 8 additions & 1 deletion tests/validators/test_model_fields.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import math
import platform
import re
import sys
from dataclasses import dataclass
Expand Down Expand Up @@ -1471,7 +1472,13 @@ def test_with_default_factory():
'default_factory,error_message',
[
(lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"),
(lambda x: 'a' + x, "<lambda>() missing 1 required positional argument: 'x'"),
(
lambda x: 'a' + x,
"unsupported operand type(s) for +: 'str' and 'dict'"
# fix pypy3.10 CI failed
if sys.version_info < (3, 11) and (platform.python_implementation() == 'PyPy')
else 'can only concatenate str (not "dict") to str',
),
],
)
def test_bad_default_factory(default_factory, error_message):
Expand Down
9 changes: 8 additions & 1 deletion tests/validators/test_typed_dict.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import math
import platform
import re
import sys
import weakref
from typing import Any, Dict, Mapping, Union

Expand Down Expand Up @@ -879,7 +880,13 @@ def test_field_required_and_default_factory():
'default_factory,error_message',
[
(lambda: 1 + 'a', "unsupported operand type(s) for +: 'int' and 'str'"),
(lambda x: 'a' + x, "<lambda>() missing 1 required positional argument: 'x'"),
(
lambda x: 'a' + x,
"unsupported operand type(s) for +: 'str' and 'dict'"
# fix pypy3.10 CI failed
if sys.version_info < (3, 11) and (platform.python_implementation() == 'PyPy')
else 'can only concatenate str (not "dict") to str',
),
],
)
def test_bad_default_factory(default_factory, error_message):
Expand Down
43 changes: 37 additions & 6 deletions tests/validators/test_with_default.py
Original file line number Diff line number Diff line change
Expand Up @@ -236,26 +236,28 @@ def broken(x):
)
assert v.validate_python(42) == 42
assert v.validate_python('42') == 42
with pytest.raises(TypeError, match=r"broken\(\) missing 1 required positional argument: 'x'"):
v.validate_python('wrong')
assert v.validate_python('wrong') == 7


def test_typed_dict_error():
def test_typed_dict():
v = SchemaValidator(
{
'type': 'typed-dict',
'fields': {
'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}},
'y': {
'type': 'typed-dict-field',
'schema': {'type': 'default', 'schema': {'type': 'str'}, 'default_factory': lambda y: y * 2},
'schema': {
'type': 'default',
'schema': {'type': 'str'},
'default_factory': lambda v_data: v_data['x'] + ' and y',
},
},
},
}
)
assert v.validate_python({'x': 'x', 'y': 'y'}) == {'x': 'x', 'y': 'y'}
with pytest.raises(TypeError, match=r"<lambda>\(\) missing 1 required positional argument: 'y'"):
v.validate_python({'x': 'x'})
assert v.validate_python({'x': 'x value'}) == {'x': 'x value', 'y': 'x value and y'}


def test_on_error_default_not_int():
Expand Down Expand Up @@ -812,3 +814,32 @@ def _raise(ex: Exception) -> None:
v.validate_python(input_value)

assert exc_info.value.errors(include_url=False, include_context=False) == expected


def test_typed_dict_default_factory():
v = SchemaValidator(
{
'type': 'typed-dict',
'fields': {
'x': {'type': 'typed-dict-field', 'schema': {'type': 'str'}},
'y': {
'type': 'typed-dict-field',
'schema': {
'type': 'default',
'schema': {'type': 'str'},
'default_factory': lambda v_data: v_data['x'] + ' and y',
},
},
'z': {
'type': 'typed-dict-field',
'schema': {
'type': 'default',
'schema': {'type': 'str'},
'default_factory': lambda v_data: v_data['y'] + ' and z',
},
},
},
}
)
assert v.validate_python({'x': 'x', 'y': 'y', 'z': 'z'}) == {'x': 'x', 'y': 'y', 'z': 'z'}
assert v.validate_python({'x': 'x'}) == {'x': 'x', 'y': 'x and y', 'z': 'x and y and z'}
Loading