second commit

This commit is contained in:
2024-12-27 22:31:23 +09:00
parent 2353324570
commit 10a0f110ca
8819 changed files with 1307198 additions and 28 deletions

View File

@ -0,0 +1,12 @@
# flake8: noqa
from dataclasses_json.api import (DataClassJsonMixin,
dataclass_json)
from dataclasses_json.cfg import (config, global_config,
Exclude, LetterCase)
from dataclasses_json.undefined import CatchAll, Undefined
from dataclasses_json.__version__ import __version__
__all__ = ['DataClassJsonMixin', 'LetterCase', 'dataclass_json',
'config', 'global_config', 'Exclude',
'CatchAll', 'Undefined']

View File

@ -0,0 +1,6 @@
"""
Version file.
Allows common version lookup via from dataclasses_json import __version__
"""
__version__ = "0.6.7" # replaced by git tag on deploy

View File

@ -0,0 +1,153 @@
import abc
import json
from typing import Any, Callable, Dict, List, Optional, Tuple, Type, TypeVar, Union, overload
from dataclasses_json.cfg import config, LetterCase
from dataclasses_json.core import (Json, _ExtendedEncoder, _asdict,
_decode_dataclass)
from dataclasses_json.mm import (JsonData, SchemaType, build_schema)
from dataclasses_json.undefined import Undefined
from dataclasses_json.utils import (_handle_undefined_parameters_safe,
_undefined_parameter_action_safe)
A = TypeVar('A', bound="DataClassJsonMixin")
T = TypeVar('T')
Fields = List[Tuple[str, Any]]
class DataClassJsonMixin(abc.ABC):
"""
DataClassJsonMixin is an ABC that functions as a Mixin.
As with other ABCs, it should not be instantiated directly.
"""
dataclass_json_config: Optional[dict] = None
def to_json(self,
*,
skipkeys: bool = False,
ensure_ascii: bool = True,
check_circular: bool = True,
allow_nan: bool = True,
indent: Optional[Union[int, str]] = None,
separators: Optional[Tuple[str, str]] = None,
default: Optional[Callable] = None,
sort_keys: bool = False,
**kw) -> str:
return json.dumps(self.to_dict(encode_json=False),
cls=_ExtendedEncoder,
skipkeys=skipkeys,
ensure_ascii=ensure_ascii,
check_circular=check_circular,
allow_nan=allow_nan,
indent=indent,
separators=separators,
default=default,
sort_keys=sort_keys,
**kw)
@classmethod
def from_json(cls: Type[A],
s: JsonData,
*,
parse_float=None,
parse_int=None,
parse_constant=None,
infer_missing=False,
**kw) -> A:
kvs = json.loads(s,
parse_float=parse_float,
parse_int=parse_int,
parse_constant=parse_constant,
**kw)
return cls.from_dict(kvs, infer_missing=infer_missing)
@classmethod
def from_dict(cls: Type[A],
kvs: Json,
*,
infer_missing=False) -> A:
return _decode_dataclass(cls, kvs, infer_missing)
def to_dict(self, encode_json=False) -> Dict[str, Json]:
return _asdict(self, encode_json=encode_json)
@classmethod
def schema(cls: Type[A],
*,
infer_missing: bool = False,
only=None,
exclude=(),
many: bool = False,
context=None,
load_only=(),
dump_only=(),
partial: bool = False,
unknown=None) -> "SchemaType[A]":
Schema = build_schema(cls, DataClassJsonMixin, infer_missing, partial)
if unknown is None:
undefined_parameter_action = _undefined_parameter_action_safe(cls)
if undefined_parameter_action is not None:
# We can just make use of the same-named mm keywords
unknown = undefined_parameter_action.name.lower()
return Schema(only=only,
exclude=exclude,
many=many,
context=context,
load_only=load_only,
dump_only=dump_only,
partial=partial,
unknown=unknown)
@overload
def dataclass_json(_cls: None = ..., *, letter_case: Optional[LetterCase] = ...,
undefined: Optional[Union[str, Undefined]] = ...) -> Callable[[Type[T]], Type[T]]: ...
@overload
def dataclass_json(_cls: Type[T], *, letter_case: Optional[LetterCase] = ...,
undefined: Optional[Union[str, Undefined]] = ...) -> Type[T]: ...
def dataclass_json(_cls: Optional[Type[T]] = None, *, letter_case: Optional[LetterCase] = None,
undefined: Optional[Union[str, Undefined]] = None) -> Union[Callable[[Type[T]], Type[T]], Type[T]]:
"""
Based on the code in the `dataclasses` module to handle optional-parens
decorators. See example below:
@dataclass_json
@dataclass_json(letter_case=LetterCase.CAMEL)
class Example:
...
"""
def wrap(cls: Type[T]) -> Type[T]:
return _process_class(cls, letter_case, undefined)
if _cls is None:
return wrap
return wrap(_cls)
def _process_class(cls: Type[T], letter_case: Optional[LetterCase],
undefined: Optional[Union[str, Undefined]]) -> Type[T]:
if letter_case is not None or undefined is not None:
cls.dataclass_json_config = config(letter_case=letter_case, # type: ignore[attr-defined]
undefined=undefined)['dataclasses_json']
cls.to_json = DataClassJsonMixin.to_json # type: ignore[attr-defined]
# unwrap and rewrap classmethod to tag it to cls rather than the literal
# DataClassJsonMixin ABC
cls.from_json = classmethod(DataClassJsonMixin.from_json.__func__) # type: ignore[attr-defined]
cls.to_dict = DataClassJsonMixin.to_dict # type: ignore[attr-defined]
cls.from_dict = classmethod(DataClassJsonMixin.from_dict.__func__) # type: ignore[attr-defined]
cls.schema = classmethod(DataClassJsonMixin.schema.__func__) # type: ignore[attr-defined]
cls.__init__ = _handle_undefined_parameters_safe(cls, kvs=(), # type: ignore[attr-defined,method-assign]
usage="init")
# register cls as a virtual subclass of DataClassJsonMixin
DataClassJsonMixin.register(cls)
return cls

View File

@ -0,0 +1,110 @@
import functools
from enum import Enum
from typing import Callable, Dict, Optional, TypeVar, Union
from marshmallow.fields import Field as MarshmallowField # type: ignore
from dataclasses_json.stringcase import (camelcase, pascalcase, snakecase,
spinalcase) # type: ignore
from dataclasses_json.undefined import Undefined, UndefinedParameterError
T = TypeVar("T")
class Exclude:
"""
Pre-defined constants for exclusion. By default, fields are configured to
be included.
"""
ALWAYS: Callable[[object], bool] = lambda _: True
NEVER: Callable[[object], bool] = lambda _: False
# TODO: add warnings?
class _GlobalConfig:
def __init__(self):
self.encoders: Dict[Union[type, Optional[type]], Callable] = {}
self.decoders: Dict[Union[type, Optional[type]], Callable] = {}
self.mm_fields: Dict[
Union[type, Optional[type]],
MarshmallowField
] = {}
# self._json_module = json
# TODO: #180
# @property
# def json_module(self):
# return self._json_module
#
# @json_module.setter
# def json_module(self, value):
# warnings.warn(f"Now using {value.__name__} module to handle JSON. "
# f"{self._disable_msg}")
# self._json_module = value
global_config = _GlobalConfig()
class LetterCase(Enum):
CAMEL = camelcase
KEBAB = spinalcase
SNAKE = snakecase
PASCAL = pascalcase
def config(metadata: Optional[dict] = None, *,
# TODO: these can be typed more precisely
# Specifically, a Callable[A, B], where `B` is bound as a JSON type
encoder: Optional[Callable] = None,
decoder: Optional[Callable] = None,
mm_field: Optional[MarshmallowField] = None,
letter_case: Union[Callable[[str], str], LetterCase, None] = None,
undefined: Optional[Union[str, Undefined]] = None,
field_name: Optional[str] = None,
exclude: Optional[Callable[[T], bool]] = None,
) -> Dict[str, dict]:
if metadata is None:
metadata = {}
lib_metadata = metadata.setdefault('dataclasses_json', {})
if encoder is not None:
lib_metadata['encoder'] = encoder
if decoder is not None:
lib_metadata['decoder'] = decoder
if mm_field is not None:
lib_metadata['mm_field'] = mm_field
if field_name is not None:
if letter_case is not None:
@functools.wraps(letter_case) # type:ignore
def override(_, _letter_case=letter_case, _field_name=field_name):
return _letter_case(_field_name)
else:
def override(_, _field_name=field_name): # type:ignore
return _field_name
letter_case = override
if letter_case is not None:
lib_metadata['letter_case'] = letter_case
if undefined is not None:
# Get the corresponding action for undefined parameters
if isinstance(undefined, str):
if not hasattr(Undefined, undefined.upper()):
valid_actions = list(action.name for action in Undefined)
raise UndefinedParameterError(
f"Invalid undefined parameter action, "
f"must be one of {valid_actions}")
undefined = Undefined[undefined.upper()]
lib_metadata['undefined'] = undefined
if exclude is not None:
lib_metadata['exclude'] = exclude
return metadata

View File

@ -0,0 +1,475 @@
import copy
import json
import sys
import warnings
from collections import defaultdict, namedtuple
from collections.abc import (Collection as ABCCollection, Mapping as ABCMapping, MutableMapping, MutableSequence,
MutableSet, Sequence, Set)
from dataclasses import (MISSING,
fields,
is_dataclass # type: ignore
)
from datetime import datetime, timezone
from decimal import Decimal
from enum import Enum
from types import MappingProxyType
from typing import (Any, Collection, Mapping, Union, get_type_hints,
Tuple, TypeVar, Type)
from uuid import UUID
from typing_inspect import is_union_type # type: ignore
from dataclasses_json import cfg
from dataclasses_json.utils import (_get_type_cons, _get_type_origin,
_handle_undefined_parameters_safe,
_is_collection, _is_mapping, _is_new_type,
_is_optional, _isinstance_safe,
_get_type_arg_param,
_get_type_args, _is_counter,
_NO_ARGS,
_issubclass_safe, _is_tuple,
_is_generic_dataclass)
Json = Union[dict, list, str, int, float, bool, None]
confs = ['encoder', 'decoder', 'mm_field', 'letter_case', 'exclude']
FieldOverride = namedtuple('FieldOverride', confs) # type: ignore
collections_abc_type_to_implementation_type = MappingProxyType({
ABCCollection: tuple,
ABCMapping: dict,
MutableMapping: dict,
MutableSequence: list,
MutableSet: set,
Sequence: tuple,
Set: frozenset,
})
class _ExtendedEncoder(json.JSONEncoder):
def default(self, o) -> Json:
result: Json
if _isinstance_safe(o, Collection):
if _isinstance_safe(o, Mapping):
result = dict(o)
else:
result = list(o)
elif _isinstance_safe(o, datetime):
result = o.timestamp()
elif _isinstance_safe(o, UUID):
result = str(o)
elif _isinstance_safe(o, Enum):
result = o.value
elif _isinstance_safe(o, Decimal):
result = str(o)
else:
result = json.JSONEncoder.default(self, o)
return result
def _user_overrides_or_exts(cls):
global_metadata = defaultdict(dict)
encoders = cfg.global_config.encoders
decoders = cfg.global_config.decoders
mm_fields = cfg.global_config.mm_fields
for field in fields(cls):
if field.type in encoders:
global_metadata[field.name]['encoder'] = encoders[field.type]
if field.type in decoders:
global_metadata[field.name]['decoder'] = decoders[field.type]
if field.type in mm_fields:
global_metadata[field.name]['mm_field'] = mm_fields[field.type]
try:
cls_config = (cls.dataclass_json_config
if cls.dataclass_json_config is not None else {})
except AttributeError:
cls_config = {}
overrides = {}
for field in fields(cls):
field_config = {}
# first apply global overrides or extensions
field_metadata = global_metadata[field.name]
if 'encoder' in field_metadata:
field_config['encoder'] = field_metadata['encoder']
if 'decoder' in field_metadata:
field_config['decoder'] = field_metadata['decoder']
if 'mm_field' in field_metadata:
field_config['mm_field'] = field_metadata['mm_field']
# then apply class-level overrides or extensions
field_config.update(cls_config)
# last apply field-level overrides or extensions
field_config.update(field.metadata.get('dataclasses_json', {}))
overrides[field.name] = FieldOverride(*map(field_config.get, confs))
return overrides
def _encode_json_type(value, default=_ExtendedEncoder().default):
if isinstance(value, Json.__args__): # type: ignore
if isinstance(value, list):
return [_encode_json_type(i) for i in value]
elif isinstance(value, dict):
return {k: _encode_json_type(v) for k, v in value.items()}
else:
return value
return default(value)
def _encode_overrides(kvs, overrides, encode_json=False):
override_kvs = {}
for k, v in kvs.items():
if k in overrides:
exclude = overrides[k].exclude
# If the exclude predicate returns true, the key should be
# excluded from encoding, so skip the rest of the loop
if exclude and exclude(v):
continue
letter_case = overrides[k].letter_case
original_key = k
k = letter_case(k) if letter_case is not None else k
if k in override_kvs:
raise ValueError(
f"Multiple fields map to the same JSON "
f"key after letter case encoding: {k}"
)
encoder = overrides[original_key].encoder
v = encoder(v) if encoder is not None else v
if encode_json:
v = _encode_json_type(v)
override_kvs[k] = v
return override_kvs
def _decode_letter_case_overrides(field_names, overrides):
"""Override letter case of field names for encode/decode"""
names = {}
for field_name in field_names:
field_override = overrides.get(field_name)
if field_override is not None:
letter_case = field_override.letter_case
if letter_case is not None:
names[letter_case(field_name)] = field_name
return names
def _decode_dataclass(cls, kvs, infer_missing):
if _isinstance_safe(kvs, cls):
return kvs
overrides = _user_overrides_or_exts(cls)
kvs = {} if kvs is None and infer_missing else kvs
field_names = [field.name for field in fields(cls)]
decode_names = _decode_letter_case_overrides(field_names, overrides)
kvs = {decode_names.get(k, k): v for k, v in kvs.items()}
missing_fields = {field for field in fields(cls) if field.name not in kvs}
for field in missing_fields:
if field.default is not MISSING:
kvs[field.name] = field.default
elif field.default_factory is not MISSING:
kvs[field.name] = field.default_factory()
elif infer_missing:
kvs[field.name] = None
# Perform undefined parameter action
kvs = _handle_undefined_parameters_safe(cls, kvs, usage="from")
init_kwargs = {}
types = get_type_hints(cls)
for field in fields(cls):
# The field should be skipped from being added
# to init_kwargs as it's not intended as a constructor argument.
if not field.init:
continue
field_value = kvs[field.name]
field_type = types[field.name]
if field_value is None:
if not _is_optional(field_type):
warning = (
f"value of non-optional type {field.name} detected "
f"when decoding {cls.__name__}"
)
if infer_missing:
warnings.warn(
f"Missing {warning} and was defaulted to None by "
f"infer_missing=True. "
f"Set infer_missing=False (the default) to prevent "
f"this behavior.", RuntimeWarning
)
else:
warnings.warn(
f"'NoneType' object {warning}.", RuntimeWarning
)
init_kwargs[field.name] = field_value
continue
while True:
if not _is_new_type(field_type):
break
field_type = field_type.__supertype__
if (field.name in overrides
and overrides[field.name].decoder is not None):
# FIXME hack
if field_type is type(field_value):
init_kwargs[field.name] = field_value
else:
init_kwargs[field.name] = overrides[field.name].decoder(
field_value)
elif is_dataclass(field_type):
# FIXME this is a band-aid to deal with the value already being
# serialized when handling nested marshmallow schema
# proper fix is to investigate the marshmallow schema generation
# code
if is_dataclass(field_value):
value = field_value
else:
value = _decode_dataclass(field_type, field_value,
infer_missing)
init_kwargs[field.name] = value
elif _is_supported_generic(field_type) and field_type != str:
init_kwargs[field.name] = _decode_generic(field_type,
field_value,
infer_missing)
else:
init_kwargs[field.name] = _support_extended_types(field_type,
field_value)
return cls(**init_kwargs)
def _decode_type(type_, value, infer_missing):
if _has_decoder_in_global_config(type_):
return _get_decoder_in_global_config(type_)(value)
if _is_supported_generic(type_):
return _decode_generic(type_, value, infer_missing)
if is_dataclass(type_) or is_dataclass(value):
return _decode_dataclass(type_, value, infer_missing)
return _support_extended_types(type_, value)
def _support_extended_types(field_type, field_value):
if _issubclass_safe(field_type, datetime):
# FIXME this is a hack to deal with mm already decoding
# the issue is we want to leverage mm fields' missing argument
# but need this for the object creation hook
if isinstance(field_value, datetime):
res = field_value
else:
tz = datetime.now(timezone.utc).astimezone().tzinfo
res = datetime.fromtimestamp(field_value, tz=tz)
elif _issubclass_safe(field_type, Decimal):
res = (field_value
if isinstance(field_value, Decimal)
else Decimal(field_value))
elif _issubclass_safe(field_type, UUID):
res = (field_value
if isinstance(field_value, UUID)
else UUID(field_value))
elif _issubclass_safe(field_type, (int, float, str, bool)):
res = (field_value
if isinstance(field_value, field_type)
else field_type(field_value))
else:
res = field_value
return res
def _is_supported_generic(type_):
if type_ is _NO_ARGS:
return False
not_str = not _issubclass_safe(type_, str)
is_enum = _issubclass_safe(type_, Enum)
is_generic_dataclass = _is_generic_dataclass(type_)
return (not_str and _is_collection(type_)) or _is_optional(
type_) or is_union_type(type_) or is_enum or is_generic_dataclass
def _decode_generic(type_, value, infer_missing):
if value is None:
res = value
elif _issubclass_safe(type_, Enum):
# Convert to an Enum using the type as a constructor.
# Assumes a direct match is found.
res = type_(value)
# FIXME this is a hack to fix a deeper underlying issue. A refactor is due.
elif _is_collection(type_):
if _is_mapping(type_) and not _is_counter(type_):
k_type, v_type = _get_type_args(type_, (Any, Any))
# a mapping type has `.keys()` and `.values()`
# (see collections.abc)
ks = _decode_dict_keys(k_type, value.keys(), infer_missing)
vs = _decode_items(v_type, value.values(), infer_missing)
xs = zip(ks, vs)
elif _is_tuple(type_):
types = _get_type_args(type_)
if Ellipsis in types:
xs = _decode_items(types[0], value, infer_missing)
else:
xs = _decode_items(_get_type_args(type_) or _NO_ARGS, value, infer_missing)
elif _is_counter(type_):
xs = dict(zip(_decode_items(_get_type_arg_param(type_, 0), value.keys(), infer_missing), value.values()))
else:
xs = _decode_items(_get_type_arg_param(type_, 0), value, infer_missing)
collection_type = _resolve_collection_type_to_decode_to(type_)
res = collection_type(xs)
elif _is_generic_dataclass(type_):
origin = _get_type_origin(type_)
res = _decode_dataclass(origin, value, infer_missing)
else: # Optional or Union
_args = _get_type_args(type_)
if _args is _NO_ARGS:
# Any, just accept
res = value
elif _is_optional(type_) and len(_args) == 2: # Optional
type_arg = _get_type_arg_param(type_, 0)
res = _decode_type(type_arg, value, infer_missing)
else: # Union (already decoded or try to decode a dataclass)
type_options = _get_type_args(type_)
res = value # assume already decoded
if type(value) is dict and dict not in type_options:
for type_option in type_options:
if is_dataclass(type_option):
try:
res = _decode_dataclass(type_option, value, infer_missing)
break
except (KeyError, ValueError, AttributeError):
continue
if res == value:
warnings.warn(
f"Failed to decode {value} Union dataclasses."
f"Expected Union to include a matching dataclass and it didn't."
)
return res
def _decode_dict_keys(key_type, xs, infer_missing):
"""
Because JSON object keys must be strs, we need the extra step of decoding
them back into the user's chosen python type
"""
decode_function = key_type
# handle NoneType keys... it's weird to type a Dict as NoneType keys
# but it's valid...
# Issue #341 and PR #346:
# This is a special case for Python 3.7 and Python 3.8.
# By some reason, "unbound" dicts are counted
# as having key type parameter to be TypeVar('KT')
if key_type is None or key_type == Any or isinstance(key_type, TypeVar):
decode_function = key_type = (lambda x: x)
# handle a nested python dict that has tuples for keys. E.g. for
# Dict[Tuple[int], int], key_type will be typing.Tuple[int], but
# decode_function should be tuple, so map() doesn't break.
#
# Note: _get_type_origin() will return typing.Tuple for python
# 3.6 and tuple for 3.7 and higher.
elif _get_type_origin(key_type) in {tuple, Tuple}:
decode_function = tuple
key_type = key_type
return map(decode_function, _decode_items(key_type, xs, infer_missing))
def _decode_items(type_args, xs, infer_missing):
"""
This is a tricky situation where we need to check both the annotated
type info (which is usually a type from `typing`) and check the
value's type directly using `type()`.
If the type_arg is a generic we can use the annotated type, but if the
type_arg is a typevar we need to extract the reified type information
hence the check of `is_dataclass(vs)`
"""
def handle_pep0673(pre_0673_hint: str) -> Union[Type, str]:
for module in sys.modules.values():
if hasattr(module, type_args):
maybe_resolved = getattr(module, type_args)
warnings.warn(f"Assuming hint {pre_0673_hint} resolves to {maybe_resolved} "
"This is not necessarily the value that is in-scope.")
return maybe_resolved
warnings.warn(f"Could not resolve self-reference for type {pre_0673_hint}, "
f"decoded type might be incorrect or decode might fail altogether.")
return pre_0673_hint
# Before https://peps.python.org/pep-0673 (3.11+) self-type hints are simply strings
if sys.version_info.minor < 11 and type_args is not type and type(type_args) is str:
type_args = handle_pep0673(type_args)
if _isinstance_safe(type_args, Collection) and not _issubclass_safe(type_args, Enum):
if len(type_args) == len(xs):
return list(_decode_type(type_arg, x, infer_missing) for type_arg, x in zip(type_args, xs))
else:
raise TypeError(f"Number of types specified in the collection type {str(type_args)} "
f"does not match number of elements in the collection. In case you are working with tuples"
f"take a look at this document "
f"docs.python.org/3/library/typing.html#annotating-tuples.")
return list(_decode_type(type_args, x, infer_missing) for x in xs)
def _resolve_collection_type_to_decode_to(type_):
# get the constructor if using corresponding generic type in `typing`
# otherwise fallback on constructing using type_ itself
try:
collection_type = _get_type_cons(type_)
except (TypeError, AttributeError):
collection_type = type_
# map abstract collection to concrete implementation
return collections_abc_type_to_implementation_type.get(collection_type, collection_type)
def _asdict(obj, encode_json=False):
"""
A re-implementation of `asdict` (based on the original in the `dataclasses`
source) to support arbitrary Collection and Mapping types.
"""
if is_dataclass(obj):
result = []
overrides = _user_overrides_or_exts(obj)
for field in fields(obj):
if overrides[field.name].encoder:
value = getattr(obj, field.name)
else:
value = _asdict(
getattr(obj, field.name),
encode_json=encode_json
)
result.append((field.name, value))
result = _handle_undefined_parameters_safe(cls=obj, kvs=dict(result),
usage="to")
return _encode_overrides(dict(result), _user_overrides_or_exts(obj),
encode_json=encode_json)
elif isinstance(obj, Mapping):
return dict((_asdict(k, encode_json=encode_json),
_asdict(v, encode_json=encode_json)) for k, v in
obj.items())
# enum.IntFlag and enum.Flag are regarded as collections in Python 3.11, thus a check against Enum is needed
elif isinstance(obj, Collection) and not isinstance(obj, (str, bytes, Enum)):
return list(_asdict(v, encode_json=encode_json) for v in obj)
# encoding of generics primarily relies on concrete types while decoding relies on type annotations. This makes
# applying encoders/decoders from global configuration inconsistent.
elif _has_encoder_in_global_config(type(obj)):
return _get_encoder_in_global_config(type(obj))(obj)
else:
return copy.deepcopy(obj)
def _has_decoder_in_global_config(type_):
return type_ in cfg.global_config.decoders
def _get_decoder_in_global_config(type_):
return cfg.global_config.decoders[type_]
def _has_encoder_in_global_config(type_):
return type_ in cfg.global_config.encoders
def _get_encoder_in_global_config(type_):
return cfg.global_config.encoders[type_]

View File

@ -0,0 +1,399 @@
# flake8: noqa
import typing
import warnings
import sys
from copy import deepcopy
from dataclasses import MISSING, is_dataclass, fields as dc_fields
from datetime import datetime
from decimal import Decimal
from uuid import UUID
from enum import Enum
from typing_inspect import is_union_type # type: ignore
from marshmallow import fields, Schema, post_load # type: ignore
from marshmallow.exceptions import ValidationError # type: ignore
from dataclasses_json.core import (_is_supported_generic, _decode_dataclass,
_ExtendedEncoder, _user_overrides_or_exts)
from dataclasses_json.utils import (_is_collection, _is_optional,
_issubclass_safe, _timestamp_to_dt_aware,
_is_new_type, _get_type_origin,
_handle_undefined_parameters_safe,
CatchAllVar)
class _TimestampField(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
if value is not None:
return value.timestamp()
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
def _deserialize(self, value, attr, data, **kwargs):
if value is not None:
return _timestamp_to_dt_aware(value)
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
class _IsoField(fields.Field):
def _serialize(self, value, attr, obj, **kwargs):
if value is not None:
return value.isoformat()
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
def _deserialize(self, value, attr, data, **kwargs):
if value is not None:
return datetime.fromisoformat(value)
else:
if not self.required:
return None
else:
raise ValidationError(self.default_error_messages["required"])
class _UnionField(fields.Field):
def __init__(self, desc, cls, field, *args, **kwargs):
self.desc = desc
self.cls = cls
self.field = field
super().__init__(*args, **kwargs)
def _serialize(self, value, attr, obj, **kwargs):
if self.allow_none and value is None:
return None
for type_, schema_ in self.desc.items():
if _issubclass_safe(type(value), type_):
if is_dataclass(value):
res = schema_._serialize(value, attr, obj, **kwargs)
res['__type'] = str(type_.__name__)
return res
break
elif isinstance(value, _get_type_origin(type_)):
return schema_._serialize(value, attr, obj, **kwargs)
else:
warnings.warn(
f'The type "{type(value).__name__}" (value: "{value}") '
f'is not in the list of possible types of typing.Union '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
f'Value cannot be serialized properly.')
return super()._serialize(value, attr, obj, **kwargs)
def _deserialize(self, value, attr, data, **kwargs):
tmp_value = deepcopy(value)
if isinstance(tmp_value, dict) and '__type' in tmp_value:
dc_name = tmp_value['__type']
for type_, schema_ in self.desc.items():
if is_dataclass(type_) and type_.__name__ == dc_name:
del tmp_value['__type']
return schema_._deserialize(tmp_value, attr, data, **kwargs)
elif isinstance(tmp_value, dict):
warnings.warn(
f'Attempting to deserialize "dict" (value: "{tmp_value}) '
f'that does not have a "__type" type specifier field into'
f'(dataclass: {self.cls.__name__}, field: {self.field.name}).'
f'Deserialization may fail, or deserialization to wrong type may occur.'
)
return super()._deserialize(tmp_value, attr, data, **kwargs)
else:
for type_, schema_ in self.desc.items():
if isinstance(tmp_value, _get_type_origin(type_)):
return schema_._deserialize(tmp_value, attr, data, **kwargs)
else:
warnings.warn(
f'The type "{type(tmp_value).__name__}" (value: "{tmp_value}") '
f'is not in the list of possible types of typing.Union '
f'(dataclass: {self.cls.__name__}, field: {self.field.name}). '
f'Value cannot be deserialized properly.')
return super()._deserialize(tmp_value, attr, data, **kwargs)
class _TupleVarLen(fields.List):
"""
variable-length homogeneous tuples
"""
def _deserialize(self, value, attr, data, **kwargs):
optional_list = super()._deserialize(value, attr, data, **kwargs)
return None if optional_list is None else tuple(optional_list)
TYPES = {
typing.Mapping: fields.Mapping,
typing.MutableMapping: fields.Mapping,
typing.List: fields.List,
typing.Dict: fields.Dict,
typing.Tuple: fields.Tuple,
typing.Callable: fields.Function,
typing.Any: fields.Raw,
dict: fields.Dict,
list: fields.List,
tuple: fields.Tuple,
str: fields.Str,
int: fields.Int,
float: fields.Float,
bool: fields.Bool,
datetime: _TimestampField,
UUID: fields.UUID,
Decimal: fields.Decimal,
CatchAllVar: fields.Dict,
}
A = typing.TypeVar('A')
JsonData = typing.Union[str, bytes, bytearray]
TEncoded = typing.Dict[str, typing.Any]
TOneOrMulti = typing.Union[typing.List[A], A]
TOneOrMultiEncoded = typing.Union[typing.List[TEncoded], TEncoded]
if sys.version_info >= (3, 7) or typing.TYPE_CHECKING:
class SchemaF(Schema, typing.Generic[A]):
"""Lift Schema into a type constructor"""
def __init__(self, *args, **kwargs):
"""
Raises exception because this class should not be inherited.
This class is helper only.
"""
super().__init__(*args, **kwargs)
raise NotImplementedError()
@typing.overload
def dump(self, obj: typing.List[A], many: typing.Optional[bool] = None) -> typing.List[TEncoded]: # type: ignore
# mm has the wrong return type annotation (dict) so we can ignore the mypy error
pass
@typing.overload
def dump(self, obj: A, many: typing.Optional[bool] = None) -> TEncoded:
pass
def dump(self, obj: TOneOrMulti, # type: ignore
many: typing.Optional[bool] = None) -> TOneOrMultiEncoded:
pass
@typing.overload
def dumps(self, obj: typing.List[A], many: typing.Optional[bool] = None, *args,
**kwargs) -> str:
pass
@typing.overload
def dumps(self, obj: A, many: typing.Optional[bool] = None, *args, **kwargs) -> str:
pass
def dumps(self, obj: TOneOrMulti, many: typing.Optional[bool] = None, *args, # type: ignore
**kwargs) -> str:
pass
@typing.overload # type: ignore
def load(self, data: typing.List[TEncoded],
many: bool = True, partial: typing.Optional[bool] = None,
unknown: typing.Optional[str] = None) -> \
typing.List[A]:
# ignore the mypy error of the decorator because mm does not define lists as an allowed input type
pass
@typing.overload
def load(self, data: TEncoded,
many: None = None, partial: typing.Optional[bool] = None,
unknown: typing.Optional[str] = None) -> A:
pass
def load(self, data: TOneOrMultiEncoded,
many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None,
unknown: typing.Optional[str] = None) -> TOneOrMulti:
pass
@typing.overload # type: ignore
def loads(self, json_data: JsonData, # type: ignore
many: typing.Optional[bool] = True, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
**kwargs) -> typing.List[A]:
# ignore the mypy error of the decorator because mm does not define bytes as correct input data
# mm has the wrong return type annotation (dict) so we can ignore the mypy error
# for the return type overlap
pass
def loads(self, json_data: JsonData,
many: typing.Optional[bool] = None, partial: typing.Optional[bool] = None, unknown: typing.Optional[str] = None,
**kwargs) -> TOneOrMulti:
pass
SchemaType = SchemaF[A]
else:
SchemaType = Schema
def build_type(type_, options, mixin, field, cls):
def inner(type_, options):
while True:
if not _is_new_type(type_):
break
type_ = type_.__supertype__
if is_dataclass(type_):
if _issubclass_safe(type_, mixin):
options['field_many'] = bool(
_is_supported_generic(field.type) and _is_collection(
field.type))
return fields.Nested(type_.schema(), **options)
else:
warnings.warn(f"Nested dataclass field {field.name} of type "
f"{field.type} detected in "
f"{cls.__name__} that is not an instance of "
f"dataclass_json. Did you mean to recursively "
f"serialize this field? If so, make sure to "
f"augment {type_} with either the "
f"`dataclass_json` decorator or mixin.")
return fields.Field(**options)
origin = getattr(type_, '__origin__', type_)
args = [inner(a, {}) for a in getattr(type_, '__args__', []) if
a is not type(None)]
if type_ == Ellipsis:
return type_
if _is_optional(type_):
options["allow_none"] = True
if origin is tuple:
if len(args) == 2 and args[1] == Ellipsis:
return _TupleVarLen(args[0], **options)
else:
return fields.Tuple(args, **options)
if origin in TYPES:
return TYPES[origin](*args, **options)
if _issubclass_safe(origin, Enum):
return fields.Enum(enum=origin, by_value=True, *args, **options)
if is_union_type(type_):
union_types = [a for a in getattr(type_, '__args__', []) if
a is not type(None)]
union_desc = dict(zip(union_types, args))
return _UnionField(union_desc, cls, field, **options)
warnings.warn(
f"Unknown type {type_} at {cls.__name__}.{field.name}: {field.type} "
f"It's advised to pass the correct marshmallow type to `mm_field`.")
return fields.Field(**options)
return inner(type_, options)
def schema(cls, mixin, infer_missing):
schema = {}
overrides = _user_overrides_or_exts(cls)
# TODO check the undefined parameters and add the proper schema action
# https://marshmallow.readthedocs.io/en/stable/quickstart.html
for field in dc_fields(cls):
metadata = overrides[field.name]
if metadata.mm_field is not None:
schema[field.name] = metadata.mm_field
else:
type_ = field.type
options: typing.Dict[str, typing.Any] = {}
missing_key = 'missing' if infer_missing else 'default'
if field.default is not MISSING:
options[missing_key] = field.default
elif field.default_factory is not MISSING:
options[missing_key] = field.default_factory()
else:
options['required'] = True
if options.get(missing_key, ...) is None:
options['allow_none'] = True
if _is_optional(type_):
options.setdefault(missing_key, None)
options['allow_none'] = True
if len(type_.__args__) == 2:
# Union[str, int, None] is optional too, but it has more than 1 typed field.
type_ = [tp for tp in type_.__args__ if tp is not type(None)][0]
if metadata.letter_case is not None:
options['data_key'] = metadata.letter_case(field.name)
t = build_type(type_, options, mixin, field, cls)
if field.metadata.get('dataclasses_json', {}).get('decoder'):
# If the field defines a custom decoder, it should completely replace the Marshmallow field's conversion
# logic.
# From Marshmallow's documentation for the _deserialize method:
# "Deserialize value. Concrete :class:`Field` classes should implement this method. "
# This is the method that Field implementations override to perform the actual deserialization logic.
# In this case we specifically override this method instead of `deserialize` to minimize potential
# side effects, and only cancel the actual value deserialization.
t._deserialize = lambda v, *_a, **_kw: v
# if type(t) is not fields.Field: # If we use `isinstance` we would return nothing.
if field.type != typing.Optional[CatchAllVar]:
schema[field.name] = t
return schema
def build_schema(cls: typing.Type[A],
mixin,
infer_missing,
partial) -> typing.Type["SchemaType[A]"]:
Meta = type('Meta',
(),
{'fields': tuple(field.name for field in dc_fields(cls) # type: ignore
if
field.name != 'dataclass_json_config' and field.type !=
typing.Optional[CatchAllVar]),
# TODO #180
# 'render_module': global_config.json_module
})
@post_load
def make_instance(self, kvs, **kwargs):
return _decode_dataclass(cls, kvs, partial)
def dumps(self, *args, **kwargs):
if 'cls' not in kwargs:
kwargs['cls'] = _ExtendedEncoder
return Schema.dumps(self, *args, **kwargs)
def dump(self, obj, *, many=None):
many = self.many if many is None else bool(many)
dumped = Schema.dump(self, obj, many=many)
# TODO This is hacky, but the other option I can think of is to generate a different schema
# depending on dump and load, which is even more hacky
# The only problem is the catch-all field, we can't statically create a schema for it,
# so we just update the dumped dict
if many:
for i, _obj in enumerate(obj):
dumped[i].update(
_handle_undefined_parameters_safe(cls=_obj, kvs={},
usage="dump"))
else:
dumped.update(_handle_undefined_parameters_safe(cls=obj, kvs={},
usage="dump"))
return dumped
schema_ = schema(cls, mixin, infer_missing)
DataClassSchema: typing.Type["SchemaType[A]"] = type(
f'{cls.__name__.capitalize()}Schema',
(Schema,),
{'Meta': Meta,
f'make_{cls.__name__.lower()}': make_instance,
'dumps': dumps,
'dump': dump,
**schema_})
return DataClassSchema

View File

@ -0,0 +1,130 @@
# The MIT License (MIT)
#
# Copyright (c) 2015 Taka Okunishi
#
# Permission is hereby granted, free of charge, to any person obtaining a copy
# of this software and associated documentation files (the "Software"), to deal
# in the Software without restriction, including without limitation the rights
# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
# copies of the Software, and to permit persons to whom the Software is
# furnished to do so, subject to the following conditions:
#
# The above copyright notice and this permission notice shall be included in
# all copies or substantial portions of the Software.
#
# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
# SOFTWARE.
#
# Copyright © 2015-2018 Taka Okunishi <okunishinishi@gmail.com>.
# Copyright © 2020 Louis-Philippe Véronneau <pollo@debian.org>
import re
def uplowcase(string, case):
"""Convert string into upper or lower case.
Args:
string: String to convert.
Returns:
string: Uppercase or lowercase case string.
"""
if case == 'up':
return str(string).upper()
elif case == 'low':
return str(string).lower()
def capitalcase(string):
"""Convert string into capital case.
First letters will be uppercase.
Args:
string: String to convert.
Returns:
string: Capital case string.
"""
string = str(string)
if not string:
return string
return uplowcase(string[0], 'up') + string[1:]
def camelcase(string):
""" Convert string into camel case.
Args:
string: String to convert.
Returns:
string: Camel case string.
"""
string = re.sub(r"^[\-_\.]", '', str(string))
if not string:
return string
return (uplowcase(string[0], 'low')
+ re.sub(r"[\-_\.\s]([a-z0-9])",
lambda matched: uplowcase(matched.group(1), 'up'),
string[1:]))
def snakecase(string):
"""Convert string into snake case.
Join punctuation with underscore
Args:
string: String to convert.
Returns:
string: Snake cased string.
"""
string = re.sub(r"[\-\.\s]", '_', str(string))
if not string:
return string
return (uplowcase(string[0], 'low')
+ re.sub(r"[A-Z0-9]",
lambda matched: '_' + uplowcase(matched.group(0), 'low'),
string[1:]))
def spinalcase(string):
"""Convert string into spinal case.
Join punctuation with hyphen.
Args:
string: String to convert.
Returns:
string: Spinal cased string.
"""
return re.sub(r"_", "-", snakecase(string))
def pascalcase(string):
"""Convert string into pascal case.
Args:
string: String to convert.
Returns:
string: Pascal case string.
"""
return capitalcase(camelcase(string))

View File

@ -0,0 +1,280 @@
import abc
import dataclasses
import functools
import inspect
import sys
from dataclasses import Field, fields
from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints
from enum import Enum
from marshmallow.exceptions import ValidationError # type: ignore
from dataclasses_json.utils import CatchAllVar
KnownParameters = Dict[str, Any]
UnknownParameters = Dict[str, Any]
class _UndefinedParameterAction(abc.ABC):
@staticmethod
@abc.abstractmethod
def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
"""
Return the parameters to initialize the class with.
"""
pass
@staticmethod
def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
"""
Return the parameters that will be written to the output dict
"""
return kvs
@staticmethod
def handle_dump(obj) -> Dict[Any, Any]:
"""
Return the parameters that will be added to the schema dump.
"""
return {}
@staticmethod
def create_init(obj) -> Callable:
return obj.__init__
@staticmethod
def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
Tuple[KnownParameters, UnknownParameters]:
"""
Returns a 2 dictionaries: defined and undefined parameters
"""
class_fields = fields(cls)
field_names = [field.name for field in class_fields]
unknown_given_parameters = {k: v for k, v in kvs.items() if
k not in field_names}
known_given_parameters = {k: v for k, v in kvs.items() if
k in field_names}
return known_given_parameters, unknown_given_parameters
class _RaiseUndefinedParameters(_UndefinedParameterAction):
"""
This action raises UndefinedParameterError if it encounters an undefined
parameter during initialization.
"""
@staticmethod
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
known, unknown = \
_UndefinedParameterAction._separate_defined_undefined_kvs(
cls=cls, kvs=kvs)
if len(unknown) > 0:
raise UndefinedParameterError(
f"Received undefined initialization arguments {unknown}")
return known
CatchAll = Optional[CatchAllVar]
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
"""
This action does nothing when it encounters undefined parameters.
The undefined parameters can not be retrieved after the class has been
created.
"""
@staticmethod
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
known_given_parameters, _ = \
_UndefinedParameterAction._separate_defined_undefined_kvs(
cls=cls, kvs=kvs)
return known_given_parameters
@staticmethod
def create_init(obj) -> Callable:
original_init = obj.__init__
init_signature = inspect.signature(original_init)
@functools.wraps(obj.__init__)
def _ignore_init(self, *args, **kwargs):
known_kwargs, _ = \
_CatchAllUndefinedParameters._separate_defined_undefined_kvs(
obj, kwargs)
num_params_takeable = len(
init_signature.parameters) - 1 # don't count self
num_args_takeable = num_params_takeable - len(known_kwargs)
args = args[:num_args_takeable]
bound_parameters = init_signature.bind_partial(self, *args,
**known_kwargs)
bound_parameters.apply_defaults()
arguments = bound_parameters.arguments
arguments.pop("self", None)
final_parameters = \
_IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
original_init(self, **final_parameters)
return _ignore_init
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
"""
This class allows to add a field of type utils.CatchAll which acts as a
dictionary into which all
undefined parameters will be written.
These parameters are not affected by LetterCase.
If no undefined parameters are given, this dictionary will be empty.
"""
class _SentinelNoDefault:
pass
@staticmethod
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
known, unknown = _UndefinedParameterAction \
._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
cls=cls)
if catch_all_field.name in known:
already_parsed = isinstance(known[catch_all_field.name], dict)
default_value = _CatchAllUndefinedParameters._get_default(
catch_all_field=catch_all_field)
received_default = default_value == known[catch_all_field.name]
value_to_write: Any
if received_default and len(unknown) == 0:
value_to_write = default_value
elif received_default and len(unknown) > 0:
value_to_write = unknown
elif already_parsed:
# Did not receive default
value_to_write = known[catch_all_field.name]
if len(unknown) > 0:
value_to_write.update(unknown)
else:
error_message = f"Received input field with " \
f"same name as catch-all field: " \
f"'{catch_all_field.name}': " \
f"'{known[catch_all_field.name]}'"
raise UndefinedParameterError(error_message)
else:
value_to_write = unknown
known[catch_all_field.name] = value_to_write
return known
@staticmethod
def _get_default(catch_all_field: Field) -> Any:
# access to the default factory currently causes
# a false-positive mypy error (16. Dec 2019):
# https://github.com/python/mypy/issues/6910
# noinspection PyProtectedMember
has_default = not isinstance(catch_all_field.default,
dataclasses._MISSING_TYPE)
# noinspection PyProtectedMember
has_default_factory = not isinstance(catch_all_field.default_factory,
# type: ignore
dataclasses._MISSING_TYPE)
# TODO: black this for proper formatting
default_value: Union[
Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = _CatchAllUndefinedParameters\
._SentinelNoDefault
if has_default:
default_value = catch_all_field.default
elif has_default_factory:
# This might be unwanted if the default factory constructs
# something expensive,
# because we have to construct it again just for this test
default_value = catch_all_field.default_factory() # type: ignore
return default_value
@staticmethod
def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
catch_all_field = \
_CatchAllUndefinedParameters._get_catch_all_field(obj.__class__)
undefined_parameters = kvs.pop(catch_all_field.name)
if isinstance(undefined_parameters, dict):
kvs.update(
undefined_parameters) # If desired handle letter case here
return kvs
@staticmethod
def handle_dump(obj) -> Dict[Any, Any]:
catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
cls=obj)
return getattr(obj, catch_all_field.name)
@staticmethod
def create_init(obj) -> Callable:
original_init = obj.__init__
init_signature = inspect.signature(original_init)
@functools.wraps(obj.__init__)
def _catch_all_init(self, *args, **kwargs):
known_kwargs, unknown_kwargs = \
_CatchAllUndefinedParameters._separate_defined_undefined_kvs(
obj, kwargs)
num_params_takeable = len(
init_signature.parameters) - 1 # don't count self
if _CatchAllUndefinedParameters._get_catch_all_field(
obj).name not in known_kwargs:
num_params_takeable -= 1
num_args_takeable = num_params_takeable - len(known_kwargs)
args, unknown_args = args[:num_args_takeable], args[
num_args_takeable:]
bound_parameters = init_signature.bind_partial(self, *args,
**known_kwargs)
unknown_args = {f"_UNKNOWN{i}": v for i, v in
enumerate(unknown_args)}
arguments = bound_parameters.arguments
arguments.update(unknown_args)
arguments.update(unknown_kwargs)
arguments.pop("self", None)
final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
obj, arguments)
original_init(self, **final_parameters)
return _catch_all_init
@staticmethod
def _get_catch_all_field(cls) -> Field:
cls_globals = vars(sys.modules[cls.__module__])
types = get_type_hints(cls, globalns=cls_globals)
catch_all_fields = list(
filter(lambda f: types[f.name] == Optional[CatchAllVar], fields(cls)))
number_of_catch_all_fields = len(catch_all_fields)
if number_of_catch_all_fields == 0:
raise UndefinedParameterError(
"No field of type dataclasses_json.CatchAll defined")
elif number_of_catch_all_fields > 1:
raise UndefinedParameterError(
f"Multiple catch-all fields supplied: "
f"{number_of_catch_all_fields}.")
else:
return catch_all_fields[0]
class Undefined(Enum):
"""
Choose the behavior what happens when an undefined parameter is encountered
during class initialization.
"""
INCLUDE = _CatchAllUndefinedParameters
RAISE = _RaiseUndefinedParameters
EXCLUDE = _IgnoreUndefinedParameters
class UndefinedParameterError(ValidationError):
"""
Raised when something has gone wrong handling undefined parameters.
"""
pass

View File

@ -0,0 +1,219 @@
import inspect
import sys
from datetime import datetime, timezone
from collections import Counter
from dataclasses import is_dataclass # type: ignore
from typing import (Collection, Mapping, Optional, TypeVar, Any, Type, Tuple,
Union, cast)
def _get_type_cons(type_):
"""More spaghetti logic for 3.6 vs. 3.7"""
if sys.version_info.minor == 6:
try:
cons = type_.__extra__
except AttributeError:
try:
cons = type_.__origin__
except AttributeError:
cons = type_
else:
cons = type_ if cons is None else cons
else:
try:
cons = type_.__origin__ if cons is None else cons
except AttributeError:
cons = type_
else:
cons = type_.__origin__
return cons
_NO_TYPE_ORIGIN = object()
def _get_type_origin(type_):
"""Some spaghetti logic to accommodate differences between 3.6 and 3.7 in
the typing api"""
try:
origin = type_.__origin__
except AttributeError:
# Issue #341 and PR #346:
# For some cases, the type_.__origin__ exists but is set to None
origin = _NO_TYPE_ORIGIN
if sys.version_info.minor == 6:
try:
origin = type_.__extra__
except AttributeError:
origin = type_
else:
origin = type_ if origin in (None, _NO_TYPE_ORIGIN) else origin
elif origin is _NO_TYPE_ORIGIN:
origin = type_
return origin
def _hasargs(type_, *args):
try:
res = all(arg in type_.__args__ for arg in args)
except AttributeError:
return False
except TypeError:
if (type_.__args__ is None):
return False
else:
raise
else:
return res
class _NoArgs(object):
def __bool__(self):
return False
def __len__(self):
return 0
def __iter__(self):
return self
def __next__(self):
raise StopIteration
_NO_ARGS = _NoArgs()
def _get_type_args(tp: Type, default: Union[Tuple[Type, ...], _NoArgs] = _NO_ARGS) -> \
Union[Tuple[Type, ...], _NoArgs]:
if hasattr(tp, '__args__'):
if tp.__args__ is not None:
return tp.__args__
return default
def _get_type_arg_param(tp: Type, index: int) -> Union[Type, _NoArgs]:
_args = _get_type_args(tp)
if _args is not _NO_ARGS:
try:
return cast(Tuple[Type, ...], _args)[index]
except (TypeError, IndexError, NotImplementedError):
pass
return _NO_ARGS
def _isinstance_safe(o, t):
try:
result = isinstance(o, t)
except Exception:
return False
else:
return result
def _issubclass_safe(cls, classinfo):
try:
return issubclass(cls, classinfo)
except Exception:
return (_is_new_type_subclass_safe(cls, classinfo)
if _is_new_type(cls)
else False)
def _is_new_type_subclass_safe(cls, classinfo):
super_type = getattr(cls, "__supertype__", None)
if super_type:
return _is_new_type_subclass_safe(super_type, classinfo)
try:
return issubclass(cls, classinfo)
except Exception:
return False
def _is_new_type(type_):
return inspect.isfunction(type_) and hasattr(type_, "__supertype__")
def _is_optional(type_):
return (_issubclass_safe(type_, Optional) or
_hasargs(type_, type(None)) or
type_ is Any)
def _is_counter(type_):
return _issubclass_safe(_get_type_origin(type_), Counter)
def _is_mapping(type_):
return _issubclass_safe(_get_type_origin(type_), Mapping)
def _is_collection(type_):
return _issubclass_safe(_get_type_origin(type_), Collection)
def _is_tuple(type_):
return _issubclass_safe(_get_type_origin(type_), Tuple)
def _is_nonstr_collection(type_):
return (_issubclass_safe(_get_type_origin(type_), Collection)
and not _issubclass_safe(type_, str))
def _is_generic_dataclass(type_):
return is_dataclass(_get_type_origin(type_))
def _timestamp_to_dt_aware(timestamp: float):
tz = datetime.now(timezone.utc).astimezone().tzinfo
dt = datetime.fromtimestamp(timestamp, tz=tz)
return dt
def _undefined_parameter_action_safe(cls):
try:
if cls.dataclass_json_config is None:
return
action_enum = cls.dataclass_json_config['undefined']
except (AttributeError, KeyError):
return
if action_enum is None or action_enum.value is None:
return
return action_enum
def _handle_undefined_parameters_safe(cls, kvs, usage: str):
"""
Checks if an undefined parameters action is defined and performs the
according action.
"""
undefined_parameter_action = _undefined_parameter_action_safe(cls)
usage = usage.lower()
if undefined_parameter_action is None:
return kvs if usage != "init" else cls.__init__
if usage == "from":
return undefined_parameter_action.value.handle_from_dict(cls=cls,
kvs=kvs)
elif usage == "to":
return undefined_parameter_action.value.handle_to_dict(obj=cls,
kvs=kvs)
elif usage == "dump":
return undefined_parameter_action.value.handle_dump(obj=cls)
elif usage == "init":
return undefined_parameter_action.value.create_init(obj=cls)
else:
raise ValueError(
f"usage must be one of ['to', 'from', 'dump', 'init'], "
f"but is '{usage}'")
# Define a type for the CatchAll field
# https://stackoverflow.com/questions/59360567/define-a-custom-type-that-behaves-like-typing-any
CatchAllVar = TypeVar("CatchAllVar", bound=Mapping)