Skip to content

Commit

Permalink
Simplify with SerializerState
Browse files Browse the repository at this point in the history
  • Loading branch information
ijl committed Jan 18, 2024
1 parent 5205258 commit a40f58b
Show file tree
Hide file tree
Showing 15 changed files with 327 additions and 494 deletions.
2 changes: 1 addition & 1 deletion src/opt.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

pub type Opt = u16;
pub type Opt = u32;

pub const INDENT_2: Opt = 1;
pub const NAIVE_UTC: Opt = 1 << 1;
Expand Down
2 changes: 2 additions & 0 deletions src/serialize/mod.rs
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

mod error;
mod obtype;
mod per_type;
mod serializer;
mod state;
mod writer;

pub use serializer::serialize;
108 changes: 108 additions & 0 deletions src/serialize/obtype.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,108 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::opt::{
Opt, PASSTHROUGH_DATACLASS, PASSTHROUGH_DATETIME, PASSTHROUGH_SUBCLASS, SERIALIZE_NUMPY,
};
use crate::serialize::per_type::{is_numpy_array, is_numpy_scalar};
use crate::typeref::{
BOOL_TYPE, DATACLASS_FIELDS_STR, DATETIME_TYPE, DATE_TYPE, DICT_TYPE, ENUM_TYPE, FLOAT_TYPE,
FRAGMENT_TYPE, INT_TYPE, LIST_TYPE, NONE_TYPE, STR_TYPE, TIME_TYPE, TUPLE_TYPE, UUID_TYPE,
};

/// Classification of a Python object as produced by `pyobject_to_obtype` and
/// `pyobject_to_obtype_unlikely` in this module.
///
/// The enum is `#[repr(u32)]` and no variant has an explicit discriminant, so
/// the numeric values follow declaration order; reordering or inserting
/// variants changes them.
#[repr(u32)]
pub enum ObType {
// Returned by the first-tier checks in `pyobject_to_obtype`.
// `Int`, `List` and `Dict` are additionally returned for types carrying the
// matching `Py_TPFLAGS_*_SUBCLASS` flag when `PASSTHROUGH_SUBCLASS` is disabled.
Str,
Int,
Bool,
None,
Float,
List,
Dict,
// `Datetime`, `Date` and `Time` are only returned when
// `PASSTHROUGH_DATETIME` is disabled.
Datetime,
Date,
Time,
Tuple,
Uuid,
// Only returned when `PASSTHROUGH_DATACLASS` is disabled.
Dataclass,
// The two numpy variants are only returned when `SERIALIZE_NUMPY` is enabled.
NumpyScalar,
NumpyArray,
Enum,
// Type flagged `Py_TPFLAGS_UNICODE_SUBCLASS`; kept distinct from `Str`.
StrSubclass,
Fragment,
// Fallback when no check in either classifier function matched.
Unknown,
}

/// Maps a Python object to the [`ObType`] describing how it is classified.
///
/// The object's type is read once via `ob_type!` and compared, in order,
/// against `STR_TYPE`, `INT_TYPE`, `BOOL_TYPE`, `NONE_TYPE`, `FLOAT_TYPE`,
/// `LIST_TYPE` and `DICT_TYPE`. A `DATETIME_TYPE` match yields
/// [`ObType::Datetime`] only when `opt_disabled!(opts, PASSTHROUGH_DATETIME)`
/// holds. Anything not matched here is handed to
/// `pyobject_to_obtype_unlikely` together with the same `opts`.
pub fn pyobject_to_obtype(obj: *mut pyo3_ffi::PyObject, opts: Opt) -> ObType {
    let type_ptr = ob_type!(obj);

    // First tier: one guard per type reference, first match wins.
    if is_class_by_type!(type_ptr, STR_TYPE) {
        return ObType::Str;
    }
    if is_class_by_type!(type_ptr, INT_TYPE) {
        return ObType::Int;
    }
    if is_class_by_type!(type_ptr, BOOL_TYPE) {
        return ObType::Bool;
    }
    if is_class_by_type!(type_ptr, NONE_TYPE) {
        return ObType::None;
    }
    if is_class_by_type!(type_ptr, FLOAT_TYPE) {
        return ObType::Float;
    }
    if is_class_by_type!(type_ptr, LIST_TYPE) {
        return ObType::List;
    }
    if is_class_by_type!(type_ptr, DICT_TYPE) {
        return ObType::Dict;
    }
    // Datetime is only classified here while passthrough is off.
    if is_class_by_type!(type_ptr, DATETIME_TYPE) && opt_disabled!(opts, PASSTHROUGH_DATETIME) {
        return ObType::Datetime;
    }

    // Everything else takes the out-of-line second tier.
    pyobject_to_obtype_unlikely(type_ptr, opts)
}

/// Second-tier classification, called by `pyobject_to_obtype` when none of its
/// own checks matched.
///
/// Takes the already-extracted type pointer rather than the object. Checks run
/// in a fixed order and the first match returns; if nothing matches the result
/// is `ObType::Unknown`. The function is never inlined, and is additionally
/// marked `optimize(size)` when the `optimize` feature is enabled.
#[cfg_attr(feature = "optimize", optimize(size))]
#[inline(never)]
pub fn pyobject_to_obtype_unlikely(ob_type: *mut pyo3_ffi::PyTypeObject, opts: Opt) -> ObType {
// Checks that do not depend on any option.
if is_class_by_type!(ob_type, UUID_TYPE) {
return ObType::Uuid;
} else if is_class_by_type!(ob_type, TUPLE_TYPE) {
return ObType::Tuple;
} else if is_class_by_type!(ob_type, FRAGMENT_TYPE) {
return ObType::Fragment;
}

// Date and time are only classified while PASSTHROUGH_DATETIME is disabled.
// There is no DATETIME_TYPE check in this function; that one lives in
// `pyobject_to_obtype`.
if opt_disabled!(opts, PASSTHROUGH_DATETIME) {
if is_class_by_type!(ob_type, DATE_TYPE) {
return ObType::Date;
} else if is_class_by_type!(ob_type, TIME_TYPE) {
return ObType::Time;
}
}

// Subclasses, detected through the type's Py_TPFLAGS_*_SUBCLASS flags, are
// only classified while PASSTHROUGH_SUBCLASS is disabled. A str subclass
// gets its own variant; int/list/dict subclasses map to the base variant.
if opt_disabled!(opts, PASSTHROUGH_SUBCLASS) {
if is_subclass_by_flag!(ob_type, Py_TPFLAGS_UNICODE_SUBCLASS) {
return ObType::StrSubclass;
} else if is_subclass_by_flag!(ob_type, Py_TPFLAGS_LONG_SUBCLASS) {
return ObType::Int;
} else if is_subclass_by_flag!(ob_type, Py_TPFLAGS_LIST_SUBCLASS) {
return ObType::List;
} else if is_subclass_by_flag!(ob_type, Py_TPFLAGS_DICT_SUBCLASS) {
return ObType::Dict;
}
}

// Not gated by any option. Because it runs after the subclass-flag block,
// a type matching both is classified by its flag when PASSTHROUGH_SUBCLASS
// is disabled.
if is_subclass_by_type!(ob_type, ENUM_TYPE) {
return ObType::Enum;
}

// NOTE(review): `pydict_contains!` presumably looks up DATACLASS_FIELDS_STR
// in the type's dict; the macro is defined outside this file, confirm there.
if opt_disabled!(opts, PASSTHROUGH_DATACLASS) && pydict_contains!(ob_type, DATACLASS_FIELDS_STR)
{
return ObType::Dataclass;
}

// numpy detection is opt-in via SERIALIZE_NUMPY and wrapped in `unlikely!`.
if unlikely!(opt_enabled!(opts, SERIALIZE_NUMPY)) {
if is_numpy_scalar(ob_type) {
return ObType::NumpyScalar;
} else if is_numpy_array(ob_type) {
return ObType::NumpyArray;
}
}

// No check matched.
ObType::Unknown
}
110 changes: 32 additions & 78 deletions src/serialize/per_type/dataclass.rs
Original file line number Diff line number Diff line change
@@ -1,82 +1,60 @@
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

use crate::opt::*;
use crate::serialize::error::SerializeError;
use crate::serialize::serializer::{PyObjectSerializer, RECURSION_LIMIT};
use crate::serialize::serializer::PyObjectSerializer;
use crate::serialize::state::SerializerState;
use crate::str::unicode_to_str;
use crate::typeref::*;
use crate::typeref::{
DATACLASS_FIELDS_STR, DICT_STR, FIELD_TYPE, FIELD_TYPE_STR, SLOTS_STR, STR_TYPE,
};

use serde::ser::{Serialize, SerializeMap, Serializer};

use std::ptr::NonNull;

pub struct DataclassGenericSerializer {
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3_ffi::PyObject>>,
#[repr(transparent)]
pub struct DataclassGenericSerializer<'a> {
previous: &'a PyObjectSerializer,
}

impl DataclassGenericSerializer {
pub fn new(
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
default: Option<NonNull<pyo3_ffi::PyObject>>,
) -> Self {
DataclassGenericSerializer {
ptr: ptr,
opts: opts,
default_calls: default_calls,
recursion: recursion + 1,
default: default,
}
impl<'a> DataclassGenericSerializer<'a> {
pub fn new(previous: &'a PyObjectSerializer) -> Self {
Self { previous: previous }
}
}

impl Serialize for DataclassGenericSerializer {
impl<'a> Serialize for DataclassGenericSerializer<'a> {
#[inline(never)]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
if unlikely!(self.recursion == RECURSION_LIMIT) {
if unlikely!(self.previous.state.recursion_limit()) {
err!(SerializeError::RecursionLimit)
}
let dict = ffi!(PyObject_GetAttr(self.ptr, DICT_STR));
let ob_type = ob_type!(self.ptr);
let dict = ffi!(PyObject_GetAttr(self.previous.ptr, DICT_STR));
let ob_type = ob_type!(self.previous.ptr);
if unlikely!(dict.is_null()) {
ffi!(PyErr_Clear());
DataclassFallbackSerializer::new(
self.ptr,
self.opts,
self.default_calls,
self.recursion,
self.default,
self.previous.ptr,
self.previous.state,
self.previous.default,
)
.serialize(serializer)
} else if pydict_contains!(ob_type, SLOTS_STR) {
let ret = DataclassFallbackSerializer::new(
self.ptr,
self.opts,
self.default_calls,
self.recursion,
self.default,
self.previous.ptr,
self.previous.state,
self.previous.default,
)
.serialize(serializer);
ffi!(Py_DECREF(dict));
ret
} else {
let ret = DataclassFastSerializer::new(
dict,
self.opts,
self.default_calls,
self.recursion,
self.default,
)
.serialize(serializer);
let ret =
DataclassFastSerializer::new(dict, self.previous.state, self.previous.default)
.serialize(serializer);
ffi!(Py_DECREF(dict));
ret
}
Expand All @@ -85,25 +63,19 @@ impl Serialize for DataclassGenericSerializer {

pub struct DataclassFastSerializer {
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
state: SerializerState,
default: Option<NonNull<pyo3_ffi::PyObject>>,
}

impl DataclassFastSerializer {
pub fn new(
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
state: SerializerState,
default: Option<NonNull<pyo3_ffi::PyObject>>,
) -> Self {
DataclassFastSerializer {
ptr: ptr,
opts: opts,
default_calls: default_calls,
recursion: recursion,
state: state.copy_for_recursive_call(),
default: default,
}
}
Expand Down Expand Up @@ -142,13 +114,7 @@ impl Serialize for DataclassFastSerializer {
if unlikely!(key_as_str.as_bytes()[0] == b'_') {
continue;
}
let pyvalue = PyObjectSerializer::new(
value,
self.opts,
self.default_calls,
self.recursion,
self.default,
);
let pyvalue = PyObjectSerializer::new(value, self.state, self.default);
map.serialize_key(key_as_str).unwrap();
map.serialize_value(&pyvalue)?;
}
Expand All @@ -158,25 +124,19 @@ impl Serialize for DataclassFastSerializer {

pub struct DataclassFallbackSerializer {
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
state: SerializerState,
default: Option<NonNull<pyo3_ffi::PyObject>>,
}

impl DataclassFallbackSerializer {
pub fn new(
ptr: *mut pyo3_ffi::PyObject,
opts: Opt,
default_calls: u8,
recursion: u8,
state: SerializerState,
default: Option<NonNull<pyo3_ffi::PyObject>>,
) -> Self {
DataclassFallbackSerializer {
ptr: ptr,
opts: opts,
default_calls: default_calls,
recursion: recursion,
state: state.copy_for_recursive_call(),
default: default,
}
}
Expand Down Expand Up @@ -226,13 +186,7 @@ impl Serialize for DataclassFallbackSerializer {
let value = ffi!(PyObject_GetAttr(self.ptr, attr));
debug_assert!(ffi!(Py_REFCNT(value)) >= 2);
ffi!(Py_DECREF(value));
let pyvalue = PyObjectSerializer::new(
value,
self.opts,
self.default_calls,
self.recursion,
self.default,
);
let pyvalue = PyObjectSerializer::new(value, self.state, self.default);

map.serialize_key(key_as_str).unwrap();
map.serialize_value(&pyvalue)?
Expand Down
6 changes: 4 additions & 2 deletions src/serialize/per_type/datetime.rs
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ impl Date {
pub fn new(ptr: *mut pyo3_ffi::PyObject) -> Self {
Date { ptr: ptr }
}
#[cfg_attr(feature = "optimize", optimize(size))]

pub fn write_buf(&self, buf: &mut DateTimeBuffer) {
{
let year = ffi!(PyDateTime_GET_YEAR(self.ptr));
Expand All @@ -66,6 +66,7 @@ impl Date {
}
}
impl Serialize for Date {
#[cold]
#[inline(never)]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down Expand Up @@ -93,7 +94,7 @@ impl Time {
opts: opts,
}
}
#[cfg_attr(feature = "optimize", optimize(size))]

pub fn write_buf(&self, buf: &mut DateTimeBuffer) -> Result<(), TimeError> {
if unsafe { (*(self.ptr as *mut pyo3_ffi::PyDateTime_Time)).hastzinfo == 1 } {
return Err(TimeError::HasTimezone);
Expand All @@ -115,6 +116,7 @@ impl Time {
}

impl Serialize for Time {
#[cold]
#[inline(never)]
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
Expand Down
Loading

0 comments on commit a40f58b

Please sign in to comment.