Skip to content

Commit

Permalink
implement suggestions
Browse files Browse the repository at this point in the history
  • Loading branch information
samuelcolvin committed Jan 16, 2024
1 parent 0797e59 commit 6740e73
Show file tree
Hide file tree
Showing 4 changed files with 52 additions and 140 deletions.
6 changes: 3 additions & 3 deletions src/serializers/fields.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ use super::extra::Extra;
use super::filter::SchemaFilter;
use super::infer::{infer_json_key, infer_serialize, infer_to_python, SerializeInfer};
use super::shared::PydanticSerializer;
use super::shared::{CombinedSerializer, DictIterator, TypeSerializer};
use super::shared::{CombinedSerializer, TypeSerializer};

/// representation of a field for serialization
#[derive(Debug, Clone)]
Expand Down Expand Up @@ -321,7 +321,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
return infer_to_python(value, include, exclude, &td_extra);
};

let output_dict = self.main_to_python(py, DictIterator::new(main_dict), include, exclude, td_extra)?;
let output_dict = self.main_to_python(py, main_dict.iter().map(Ok), include, exclude, td_extra)?;

// this is used to include `__pydantic_extra__` in serialization on models
if let Some(extra_dict) = extra_dict {
Expand Down Expand Up @@ -373,7 +373,7 @@ impl TypeSerializer for GeneralFieldsSerializer {
// NOTE! As above, we maintain the order of the input dict assuming that's right
// we don't bother with `used_fields` here because on unions, `to_python(..., mode='json')` is used
let mut map = self.main_serde_serialize(
DictIterator::new(main_dict),
main_dict.iter().map(Ok),
expected_len,
serializer,
include,
Expand Down
62 changes: 22 additions & 40 deletions src/serializers/infer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ use super::errors::{py_err_se_err, PydanticSerializationError};
use super::extra::{Extra, SerMode};
use super::filter::{AnyFilter, SchemaFilter};
use super::ob_type::ObType;
use super::shared::{AnyDataclassIterator, DictIterator, PydanticSerializer, TypeSerializer};
use super::shared::{any_dataclass_iter, PydanticSerializer, TypeSerializer};
use super::SchemaSerializer;

pub(crate) fn infer_to_python(
Expand Down Expand Up @@ -151,7 +151,10 @@ pub(crate) fn infer_to_python_known(
PyList::new(py, elements).into_py(py)
}
ObType::Dict => {
serialize_pairs_python_mode_json(py, DictIterator::new(value.downcast()?), include, exclude, extra)?
let dict: &PyDict = value.downcast()?;
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, |k| {
Ok(PyString::new(py, &infer_json_key(k, extra)?))
})?
}
ObType::Datetime => {
let py_dt: &PyDateTime = value.downcast()?;
Expand Down Expand Up @@ -190,7 +193,9 @@ pub(crate) fn infer_to_python_known(
}
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => {
serialize_pairs_python_mode_json(py, AnyDataclassIterator::new(value)?, include, exclude, extra)?
serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, |k| {
Ok(PyString::new(py, &infer_json_key(k, extra)?))
})?
}
ObType::Enum => {
let v = value.getattr(intern!(py, "value"))?;
Expand Down Expand Up @@ -241,11 +246,12 @@ pub(crate) fn infer_to_python_known(
let elements = serialize_seq!(PyFrozenSet);
PyFrozenSet::new(py, &elements)?.into_py(py)
}
ObType::Dict => serialize_pairs_python(py, DictIterator::new(value.downcast()?), include, exclude, extra)?,
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => {
serialize_pairs_python(py, AnyDataclassIterator::new(value)?, include, exclude, extra)?
ObType::Dict => {
let dict: &PyDict = value.downcast()?;
serialize_pairs_python(py, dict.iter().map(Ok), include, exclude, extra, Ok)?
}
ObType::PydanticSerializable => serialize_with_serializer()?,
ObType::Dataclass => serialize_pairs_python(py, any_dataclass_iter(value)?.0, include, exclude, extra, Ok)?,
ObType::Generator => {
let iter = super::type_serializers::generator::SerializationIterator::new(
value.downcast()?,
Expand Down Expand Up @@ -404,7 +410,7 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
}
ObType::Dict => {
let dict = value.downcast::<PyDict>().map_err(py_err_se_err)?;
serialize_pairs_json(DictIterator::new(dict), serializer, include, exclude, extra)
serialize_pairs_json(dict.iter().map(Ok), dict.len(), serializer, include, exclude, extra)
}
ObType::List => serialize_seq_filter!(PyList),
ObType::Tuple => serialize_seq_filter!(PyTuple),
Expand Down Expand Up @@ -463,13 +469,10 @@ pub(crate) fn infer_serialize_known<S: Serializer>(
PydanticSerializer::new(value, &extracted_serializer.serializer, include, exclude, &extra);
pydantic_serializer.serialize(serializer)
}
ObType::Dataclass => serialize_pairs_json(
AnyDataclassIterator::new(value).map_err(py_err_se_err)?,
serializer,
include,
exclude,
extra,
),
ObType::Dataclass => {
let (pairs_iter, fields_dict) = any_dataclass_iter(value).map_err(py_err_se_err)?;
serialize_pairs_json(pairs_iter, fields_dict.len(), serializer, include, exclude, extra)
}
ObType::Uuid => {
let py_uuid: &PyAny = value.downcast().map_err(py_err_se_err)?;
let uuid = super::type_serializers::uuid::uuid_to_string(py_uuid).map_err(py_err_se_err)?;
Expand Down Expand Up @@ -645,6 +648,7 @@ fn serialize_pairs_python<'py>(
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
key_transform: impl Fn(&'py PyAny) -> PyResult<&'py PyAny>,
) -> PyResult<PyObject> {
let new_dict = PyDict::new(py);
let filter = AnyFilter::new();
Expand All @@ -653,29 +657,7 @@ fn serialize_pairs_python<'py>(
let (k, v) = result?;
let op_next = filter.key_filter(k, include, exclude)?;
if let Some((next_include, next_exclude)) = op_next {
let v = infer_to_python(v, next_include, next_exclude, extra)?;
new_dict.set_item(k, v)?;
}
}
Ok(new_dict.into_py(py))
}

fn serialize_pairs_python_mode_json<'py>(
py: Python,
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> PyResult<PyObject> {
let new_dict = PyDict::new(py);
let filter = AnyFilter::new();

for result in pairs_iter {
let (k, v) = result?;
let op_next = filter.key_filter(k, include, exclude)?;
if let Some((next_include, next_exclude)) = op_next {
let k_str = infer_json_key(k, extra)?;
let k = PyString::new(py, &k_str);
let k = key_transform(k)?;
let v = infer_to_python(v, next_include, next_exclude, extra)?;
new_dict.set_item(k, v)?;
}
Expand All @@ -685,13 +667,13 @@ fn serialize_pairs_python_mode_json<'py>(

fn serialize_pairs_json<'py, S: Serializer>(
pairs_iter: impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>>,
iter_size: usize,
serializer: S,
include: Option<&PyAny>,
exclude: Option<&PyAny>,
extra: &Extra,
) -> Result<S::Ok, S::Error> {
let (_, expected) = pairs_iter.size_hint();
let mut map = serializer.serialize_map(expected)?;
let mut map = serializer.serialize_map(Some(iter_size))?;
let filter = AnyFilter::new();

for result in pairs_iter {
Expand Down
79 changes: 14 additions & 65 deletions src/serializers/shared.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,6 @@ use std::fmt::Debug;
use pyo3::exceptions::PyTypeError;
use pyo3::once_cell::GILOnceCell;
use pyo3::prelude::*;
use pyo3::types::iter::PyDictIterator;
use pyo3::types::{PyDict, PyString};
use pyo3::{intern, PyTraverseError, PyVisit};

Expand Down Expand Up @@ -365,75 +364,25 @@ pub(crate) fn to_json_bytes(
Ok(bytes)
}

pub(super) struct DictIterator<'py> {
dict_iter: PyDictIterator<'py>,
}

impl<'py> DictIterator<'py> {
pub fn new(dict: &'py PyDict) -> Self {
Self { dict_iter: dict.iter() }
}
}

impl<'py> Iterator for DictIterator<'py> {
type Item = PyResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
self.dict_iter.next().map(Ok)
}

fn size_hint(&self) -> (usize, Option<usize>) {
self.dict_iter.size_hint()
}
}

pub(super) struct AnyDataclassIterator<'py> {
pub(super) fn any_dataclass_iter<'py>(
dataclass: &'py PyAny,
fields_iter: PyDictIterator<'py>,
field_type_marker: &'py PyAny,
}

impl<'py> AnyDataclassIterator<'py> {
pub fn new(dc: &'py PyAny) -> PyResult<Self> {
let py = dc.py();
let fields: &PyDict = dc.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
Ok(Self {
dataclass: dc,
fields_iter: fields.iter(),
field_type_marker: get_field_marker(py)?,
})
}

fn _next(&mut self) -> PyResult<Option<(&'py PyAny, &'py PyAny)>> {
if let Some((field_name, field)) = self.fields_iter.next() {
let field_type = field.getattr(intern!(self.dataclass.py(), "_field_type"))?;
if field_type.is(self.field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
let value = self.dataclass.getattr(field_name)?;
Ok(Some((field_name, value)))
} else {
self._next()
}
) -> PyResult<(impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'py, &PyDict)> {
let py = dataclass.py();
let fields: &PyDict = dataclass.getattr(intern!(py, "__dataclass_fields__"))?.downcast()?;
let field_type_marker = get_field_marker(py)?;

let next = move |(field_name, field): (&'py PyAny, &'py PyAny)| -> PyResult<Option<(&'py PyAny, &'py PyAny)>> {
let field_type = field.getattr(intern!(py, "_field_type"))?;
if field_type.is(field_type_marker) {
let field_name: &PyString = field_name.downcast()?;
let value = dataclass.getattr(field_name)?;
Ok(Some((field_name, value)))
} else {
Ok(None)
}
}
}

impl<'py> Iterator for AnyDataclassIterator<'py> {
type Item = PyResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
match self._next() {
Ok(Some(v)) => Some(Ok(v)),
Ok(None) => None,
Err(e) => Some(Err(e)),
}
}
};

fn size_hint(&self) -> (usize, Option<usize>) {
(0, None)
}
Ok((fields.iter().filter_map(move |field| next(field).transpose()), fields))
}

static DC_FIELD_MARKER: GILOnceCell<PyObject> = GILOnceCell::new();
Expand Down
45 changes: 13 additions & 32 deletions src/serializers/type_serializers/dataclass.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ impl TypeSerializer for DataclassSerializer {
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let output_dict = fields_serializer.main_to_python(
py,
KnownDataclassIterator::new(&self.fields, value),
known_dataclass_iter(&self.fields, value),
include,
exclude,
dc_extra,
Expand Down Expand Up @@ -182,7 +182,7 @@ impl TypeSerializer for DataclassSerializer {
if let CombinedSerializer::Fields(ref fields_serializer) = *self.serializer {
let expected_len = self.fields.len() + fields_serializer.computed_field_count();
let mut map = fields_serializer.main_serde_serialize(
KnownDataclassIterator::new(&self.fields, value),
known_dataclass_iter(&self.fields, value),
expected_len,
serializer,
include,
Expand Down Expand Up @@ -211,36 +211,17 @@ impl TypeSerializer for DataclassSerializer {
}
}

pub struct KnownDataclassIterator<'a, 'py> {
index: usize,
fn known_dataclass_iter<'a, 'py>(
fields: &'a [Py<PyString>],
dataclass: &'py PyAny,
}

impl<'a, 'py> KnownDataclassIterator<'a, 'py> {
pub fn new(fields: &'a [Py<PyString>], dataclass: &'py PyAny) -> Self {
Self {
index: 0,
fields,
dataclass,
}
}
}

impl<'a, 'py> Iterator for KnownDataclassIterator<'a, 'py> {
type Item = PyResult<(&'py PyAny, &'py PyAny)>;

fn next(&mut self) -> Option<Self::Item> {
if let Some(field) = self.fields.get(self.index) {
self.index += 1;
let py = self.dataclass.py();
let field_ref = field.clone_ref(py).into_ref(py);
match self.dataclass.getattr(field_ref) {
Ok(value) => Some(Ok((field_ref, value))),
Err(e) => Some(Err(e)),
}
} else {
None
}
}
) -> impl Iterator<Item = PyResult<(&'py PyAny, &'py PyAny)>> + 'a
where
'py: 'a,
{
let py = dataclass.py();
fields.iter().map(move |field| {
let field_ref = field.clone_ref(py).into_ref(py);
let value = dataclass.getattr(field_ref)?;
Ok((field_ref as &PyAny, value))
})
}

0 comments on commit 6740e73

Please sign in to comment.