Skip to content

Commit

Permalink
implement _serializable_encoders dictionary
Browse files Browse the repository at this point in the history
  • Loading branch information
gshank committed Feb 3, 2021
1 parent 461714d commit c25ed07
Show file tree
Hide file tree
Showing 3 changed files with 19 additions and 1 deletion.
5 changes: 4 additions & 1 deletion mashumaro/serializer/base/dict.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
from typing import Any, Dict, Mapping, Type, TypeVar, Optional
from typing import Any, Dict, Mapping, Type, TypeVar, Optional, ClassVar

from mashumaro.serializer.base.metaprogramming import CodeBuilder
from mashumaro.types import SerializableEncoder

T = TypeVar("T", bound="DataClassDictMixin")


class DataClassDictMixin:
_serializable_encoders: ClassVar[Dict[Type, SerializableEncoder]] = {}

def __init_subclass__(cls: Type[T], **kwargs):
builder = CodeBuilder(cls)
exc = None
Expand Down
4 changes: 4 additions & 0 deletions mashumaro/serializer/base/metaprogramming.py
Original file line number Diff line number Diff line change
Expand Up @@ -385,6 +385,8 @@ def _pack_value(
f"self.__dataclass_fields__['{fname}'].type"
f"._serialize({value_name})"
)
if type_name(ftype) in self.cls._serializable_encoders:
return f"self._serializable_encoders['{type_name(ftype)}']._serialize({value_name})"

origin_type = get_type_origin(ftype)
if is_special_typing_primitive(origin_type):
Expand Down Expand Up @@ -586,6 +588,8 @@ def _unpack_field_value(
f"cls.__dataclass_fields__['{fname}'].type"
f"._deserialize({value_name})"
)
if type_name(ftype) in self.cls._serializable_encoders:
return f"cls._serializable_encoders['{type_name(ftype)}']._deserialize({value_name})"

origin_type = get_type_origin(ftype)
if is_special_typing_primitive(origin_type):
Expand Down
11 changes: 11 additions & 0 deletions mashumaro/types.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,16 @@
import decimal
from typing import TypeVar, Generic

TV = TypeVar("TV")

class SerializableEncoder(Generic[TV]):
@classmethod
def _serialize(cls, value):
raise NotImplementedError

@classmethod
def _deserialize(cls, value):
raise NotImplementedError

class SerializableType:
def _serialize(self):
Expand Down

0 comments on commit c25ed07

Please sign in to comment.