400 lines
16 KiB
Python
400 lines
16 KiB
Python
# flake8: noqa
|
|
|
|
import typing
|
|
import warnings
|
|
import sys
|
|
from copy import deepcopy
|
|
|
|
from dataclasses import MISSING, is_dataclass, fields as dc_fields
|
|
from datetime import datetime
|
|
from decimal import Decimal
|
|
from uuid import UUID
|
|
from enum import Enum
|
|
|
|
from typing_inspect import is_union_type # type: ignore
|
|
|
|
from marshmallow import fields, Schema, post_load # type: ignore
|
|
from marshmallow.exceptions import ValidationError # type: ignore
|
|
|
|
from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
|
|
_ExtendedEncoder, _user_overrides_or_exts)
|
|
from dataclasses_json.utils import (_is_collection, _is_optional,
|
|
_issubclass_safe, _timestamp_to_dt_aware,
|
|
_is_new_type, _get_type_origin,
|
|
_handle_undefined_parameters_safe,
|
|
CatchAllVar)
|
|
|
|
|
|
class _TimestampField(fields.Field):
    """Marshmallow field mapping ``datetime`` <-> POSIX timestamp (float)."""

    def _serialize(self, value, attr, obj, **kwargs):
        # A missing value is only acceptable for non-required fields.
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.timestamp()

    def _deserialize(self, value, attr, data, **kwargs):
        # Mirror of _serialize: epoch seconds back into an aware datetime.
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return _timestamp_to_dt_aware(value)
|
|
|
|
|
|
class _IsoField(fields.Field):
    """Marshmallow field mapping ``datetime`` <-> ISO-8601 string."""

    def _serialize(self, value, attr, obj, **kwargs):
        # A missing value is only acceptable for non-required fields.
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return value.isoformat()

    def _deserialize(self, value, attr, data, **kwargs):
        # Mirror of _serialize: ISO-8601 string back into a datetime.
        if value is None:
            if self.required:
                raise ValidationError(self.default_error_messages["required"])
            return None
        return datetime.fromisoformat(value)
|
|
|
|
|
|
class _UnionField(fields.Field):
    """Field for ``typing.Union`` members: tries each candidate type in turn.

    ``desc`` maps each possible member type to the marshmallow schema/field
    built for it; ``cls`` and ``field`` identify the owning dataclass and
    dataclass field (used only in warning messages). Dataclass members are
    tagged with a ``__type`` key on serialization so the matching variant
    can be selected again on deserialization.
    """

    def __init__(self, desc, cls, field, *args, **kwargs):
        self.desc = desc
        self.cls = cls
        self.field = field
        super().__init__(*args, **kwargs)

    def _serialize(self, value, attr, obj, **kwargs):
        if self.allow_none and value is None:
            return None
        for type_, schema_ in self.desc.items():
            if _issubclass_safe(type(value), type_):
                if is_dataclass(value):
                    res = schema_._serialize(value, attr, obj, **kwargs)
                    # Tag the payload so _deserialize can pick this variant.
                    res['__type'] = str(type_.__name__)
                    return res
                # Matched a non-dataclass member: fall through to the default
                # serialization below (break skips the for-else warning).
                break
            elif isinstance(value, _get_type_origin(type_)):
                return schema_._serialize(value, attr, obj, **kwargs)
        else:
            # for-else: no member type matched at all — warn, then fall back.
            warnings.warn(
                f'The type "{type(value).__name__}" (value: "{value}") '
                f'is not in the list of possible types of typing.Union '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Value cannot be serialized properly.')
        return super()._serialize(value, attr, obj, **kwargs)

    def _deserialize(self, value, attr, data, **kwargs):
        # deepcopy so deleting the '__type' tag does not mutate caller data.
        tmp_value = deepcopy(value)
        if isinstance(tmp_value, dict) and '__type' in tmp_value:
            # Tagged dict: dispatch to the dataclass variant named by __type.
            dc_name = tmp_value['__type']
            for type_, schema_ in self.desc.items():
                if is_dataclass(type_) and type_.__name__ == dc_name:
                    del tmp_value['__type']
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
        elif isinstance(tmp_value, dict):
            # Untagged dict: cannot tell which variant it is. (Message fixed:
            # closing quote after the value, and spaces between sentences.)
            warnings.warn(
                f'Attempting to deserialize "dict" (value: "{tmp_value}") '
                f'that does not have a "__type" type specifier field into '
                f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                f'Deserialization may fail, or deserialization to wrong type may occur.'
            )
            return super()._deserialize(tmp_value, attr, data, **kwargs)
        else:
            for type_, schema_ in self.desc.items():
                if isinstance(tmp_value, _get_type_origin(type_)):
                    return schema_._deserialize(tmp_value, attr, data, **kwargs)
            else:
                # for-else: no member type matched — warn, then fall back.
                warnings.warn(
                    f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
                    f'is not in the list of possible types of typing.Union '
                    f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
                    f'Value cannot be deserialized properly.')
        return super()._deserialize(tmp_value, attr, data, **kwargs)
|
|
|
|
|
|
class _TupleVarLen(fields.List):
    """
    variable-length homogeneous tuples
    """
    def _deserialize(self, value, attr, data, **kwargs):
        # Delegate to fields.List, then convert its result to a tuple,
        # preserving None for absent/null input.
        items = super()._deserialize(value, attr, data, **kwargs)
        if items is None:
            return None
        return tuple(items)
|
|
|
|
|
|
# Mapping from Python / typing types to the marshmallow field class used to
# (de)serialize them. build_type looks up a type's origin here after
# resolving its type arguments; datetime deliberately maps to the
# timestamp-based field (see _TimestampField above).
TYPES = {
    typing.Mapping: fields.Mapping,
    typing.MutableMapping: fields.Mapping,
    typing.List: fields.List,
    typing.Dict: fields.Dict,
    typing.Tuple: fields.Tuple,
    typing.Callable: fields.Function,
    typing.Any: fields.Raw,
    dict: fields.Dict,
    list: fields.List,
    tuple: fields.Tuple,
    str: fields.Str,
    int: fields.Int,
    float: fields.Float,
    bool: fields.Bool,
    datetime: _TimestampField,
    UUID: fields.UUID,
    Decimal: fields.Decimal,
    CatchAllVar: fields.Dict,
}

# Type variables / aliases used by the SchemaF typing stubs below.
A = typing.TypeVar('A')  # the dataclass type a schema is built for
JsonData = typing.Union[str, bytes, bytearray]  # accepted raw JSON payloads
TEncoded = typing.Dict[str, typing.Any]  # one encoded (dict-form) dataclass
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
|
|
|
|
# SchemaF exists purely so that `SchemaType[A]` is a valid generic annotation;
# it is never instantiated (see __init__) and its method bodies are typing
# stubs that are never executed. On very old Pythons we fall back to the
# plain (non-generic) marshmallow Schema.
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
    class SchemaF(Schema, typing.Generic[A]):
        """Lift Schema into a type constructor"""

        def __init__(self, *args, **kwargs):
            """
            Raises exception because this class should not be inherited.
            This class is helper only.
            """

            super().__init__(*args, **kwargs)
            raise NotImplementedError()

        @typing.overload
        def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]:  # type: ignore
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            pass

        @typing.overload
        def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
            pass

        def dump(self, obj: TOneOrMulti,  # type: ignore
                 many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
            pass

        @typing.overload
        def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
                  **kwargs) -> str:
            pass

        @typing.overload
        def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
            pass

        def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args,  # type: ignore
                  **kwargs) -> str:
            pass

        @typing.overload  # type: ignore
        def load(self, data: typing.List[TEncoded],
                 many: bool = True, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> \
                typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define lists as an allowed input type
            pass

        @typing.overload
        def load(self, data: TEncoded,
                 many: None = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> A:
            pass

        def load(self, data: TOneOrMultiEncoded,
                 many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
                 unknown: typing.Optional[str] = None) -> TOneOrMulti:
            pass

        @typing.overload  # type: ignore
        def loads(self, json_data: JsonData,  # type: ignore
                  many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> typing.List[A]:
            # ignore the mypy error of the decorator because mm does not define bytes as correct input data
            # mm has the wrong return type annotation (dict) so we can ignore the mypy error
            # for the return type overlap
            pass

        def loads(self, json_data: JsonData,
                  many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
                  **kwargs) -> TOneOrMulti:
            pass

    SchemaType = SchemaF[A]
else:
    SchemaType = Schema
|
|
|
|
|
|
def build_type(type_, options, mixin, field, cls):
    """Build the marshmallow field for dataclass field ``field`` of type ``type_``.

    ``options`` holds keyword arguments for the marshmallow field constructor
    (e.g. ``required``/``default`` set by ``schema()``); ``mixin`` is the
    dataclasses_json mixin used to detect nested dataclass_json dataclasses;
    ``cls`` is the owning dataclass (used only in warning messages).
    """
    def inner(type_, options):
        # Unwrap typing.NewType chains down to the underlying supertype.
        while True:
            if not _is_new_type(type_):
                break

            type_ = type_.__supertype__

        if is_dataclass(type_):
            if _issubclass_safe(type_, mixin):
                # Nested dataclass_json dataclass: delegate to its own schema.
                options['field_many'] = bool(
                    _is_supported_generic(field.type) and _is_collection(
                        field.type))
                return fields.Nested(type_.schema(), **options)
            else:
                # Plain dataclass without the mixin: cannot recurse into it.
                warnings.warn(f"Nested dataclass field {field.name} of type "
                              f"{field.type} detected in "
                              f"{cls.__name__} that is not an instance of "
                              f"dataclass_json. Did you mean to recursively "
                              f"serialize this field? If so, make sure to "
                              f"augment {type_} with either the "
                              f"`dataclass_json` decorator or mixin.")
                return fields.Field(**options)

        origin = getattr(type_, '__origin__', type_)
        # Recursively build fields for the generic's type arguments, skipping
        # NoneType (Optional-ness is expressed via allow_none instead).
        args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
                a is not type(None)]

        # Ellipsis appears as the second "argument" of Tuple[X, ...]; pass it
        # through unchanged so the tuple branch below can recognize it.
        if type_ == Ellipsis:
            return type_

        if _is_optional(type_):
            options["allow_none"] = True
        if origin is tuple:
            if len(args) == 2 and args[1] == Ellipsis:
                # Tuple[X, ...]: variable-length homogeneous tuple.
                return _TupleVarLen(args[0], **options)
            else:
                # Fixed-length, possibly heterogeneous tuple.
                return fields.Tuple(args, **options)
        if origin in TYPES:
            return TYPES[origin](*args, **options)

        if _issubclass_safe(origin, Enum):
            # by_value=True: (de)serialize enum members by their value.
            return fields.Enum(enum=origin, by_value=True, *args, **options)

        if is_union_type(type_):
            # Pair each non-None union member with the field built for it.
            union_types = [a for a in getattr(type_, '__args__', []) if
                           a is not type(None)]
            union_desc = dict(zip(union_types, args))
            return _UnionField(union_desc, cls, field, **options)

        warnings.warn(
            f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
            f"It's advised to pass the correct marshmallow type to `mm_field`.")
        return fields.Field(**options)

    return inner(type_, options)
|
|
|
|
|
|
def schema(cls, mixin, infer_missing):
    """Build the dict of marshmallow fields (name -> field) for dataclass ``cls``.

    A user-supplied ``mm_field`` override wins outright; otherwise the field
    is derived from the dataclass field's type and defaults. When
    ``infer_missing`` is true, defaults are attached as marshmallow
    ``missing`` (load-side) instead of ``default`` (dump-side).
    """
    schema = {}
    overrides = _user_overrides_or_exts(cls)
    # TODO check the undefined parameters and add the proper schema action
    # https://marshmallow.readthedocs.io/en/stable/quickstart.html
    for field in dc_fields(cls):
        metadata = overrides[field.name]
        if metadata.mm_field is not None:
            # Explicit marshmallow field provided by the user: use it as-is.
            schema[field.name] = metadata.mm_field
        else:
            type_ = field.type
            options: typing.Dict[str, typing.Any] = {}
            missing_key = 'missing' if infer_missing else 'default'
            if field.default is not MISSING:
                options[missing_key] = field.default
            elif field.default_factory is not MISSING:
                # NOTE(review): the factory is invoked once at schema-build
                # time, so all loads share this single default value.
                options[missing_key] = field.default_factory()
            else:
                # No default at all: the field is required.
                options['required'] = True

            # A literal None default implies the field may hold None.
            if options.get(missing_key, ...) is None:
                options['allow_none'] = True

            if _is_optional(type_):
                options.setdefault(missing_key, None)
                options['allow_none'] = True
                if len(type_.__args__) == 2:
                    # Union[str, int, None] is optional too, but it has more than 1 typed field.
                    type_ = [tp for tp in type_.__args__ if tp is not type(None)][0]

            if metadata.letter_case is not None:
                # Rename the serialized key (e.g. camelCase) via data_key.
                options['data_key'] = metadata.letter_case(field.name)

            t = build_type(type_, options, mixin, field, cls)
            if field.metadata.get('dataclasses_json', {}).get('decoder'):
                # If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
                # logic.
                # From Marshmallow's documentation for the _deserialize method:
                # "Deserialize value. Concrete :class:`Field` classes should implement this method. "
                # This is the method that Field implementations override to perform the actual deserialization logic.
                # In this case we specifically override this method instead of `deserialize` to minimize potential
                # side effects, and only cancel the actual value deserialization.
                t._deserialize = lambda v, *_a, **_kw: v

            # if type(t) is not fields.Field:  # If we use `isinstance` we would return nothing.
            # Catch-all fields are handled dynamically at dump time, not here.
            if field.type != typing.Optional[CatchAllVar]:
                schema[field.name] = t

    return schema
|
|
|
|
|
|
def build_schema(cls: typing.Type[A],
                 mixin,
                 infer_missing,
                 partial) -> typing.Type["SchemaType[A]"]:
    """Create and return a marshmallow Schema subclass for dataclass ``cls``.

    The generated schema loads straight into ``cls`` instances (via the
    ``post_load`` hook) and knows how to dump catch-all / undefined
    parameters. ``partial`` is forwarded to the dataclass decoder.
    """
    # Schema Meta: expose every dataclass field except internal config and
    # the catch-all field (which is merged in dynamically by dump below).
    Meta = type('Meta',
                (),
                {'fields': tuple(field.name for field in dc_fields(cls)  # type: ignore
                                 if
                                 field.name != 'dataclass_json_config' and field.type !=
                                 typing.Optional[CatchAllVar]),
                 # TODO #180
                 # 'render_module': global_config.json_module
                 })

    @post_load
    def make_instance(self, kvs, **kwargs):
        # Turn the loaded dict into an actual dataclass instance.
        return _decode_dataclass(cls, kvs, partial)

    def dumps(self, *args, **kwargs):
        # Default to the extended JSON encoder (datetime, UUID, Decimal, ...)
        # unless the caller supplied their own.
        if 'cls' not in kwargs:
            kwargs['cls'] = _ExtendedEncoder

        return Schema.dumps(self, *args, **kwargs)

    def dump(self, obj, *, many=None):
        many = self.many if many is None else bool(many)
        dumped = Schema.dump(self, obj, many=many)
        # TODO This is hacky, but the other option I can think of is to generate a different schema
        # depending on dump and load, which is even more hacky

        # The only problem is the catch-all field, we can't statically create a schema for it,
        # so we just update the dumped dict
        if many:
            for i, _obj in enumerate(obj):
                dumped[i].update(
                    _handle_undefined_parameters_safe(cls=_obj, kvs={},
                                                      usage="dump"))
        else:
            dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
                                                            usage="dump"))
        return dumped

    schema_ = schema(cls, mixin, infer_missing)
    # Assemble the schema class: Meta + post_load hook + overridden
    # dump/dumps + one marshmallow field per dataclass field.
    DataClassSchema: typing.Type["SchemaType[A]"] = type(
        f'{cls.__name__.capitalize()}Schema',
        (Schema,),
        {'Meta': Meta,
         f'make_{cls.__name__.lower()}': make_instance,
         'dumps': dumps,
         'dump': dump,
         **schema_})

    return DataClassSchema
|