second commit
This commit is contained in:
280
env/lib/python3.11/site-packages/dataclasses_json/undefined.py
vendored
Normal file
280
env/lib/python3.11/site-packages/dataclasses_json/undefined.py
vendored
Normal file
@ -0,0 +1,280 @@
|
||||
import abc
|
||||
import dataclasses
|
||||
import functools
|
||||
import inspect
|
||||
import sys
|
||||
from dataclasses import Field, fields
|
||||
from typing import Any, Callable, Dict, Optional, Tuple, Union, Type, get_type_hints
|
||||
from enum import Enum
|
||||
|
||||
from marshmallow.exceptions import ValidationError # type: ignore
|
||||
|
||||
from dataclasses_json.utils import CatchAllVar
|
||||
|
||||
KnownParameters = Dict[str, Any]
|
||||
UnknownParameters = Dict[str, Any]
|
||||
|
||||
|
||||
class _UndefinedParameterAction(abc.ABC):
|
||||
@staticmethod
|
||||
@abc.abstractmethod
|
||||
def handle_from_dict(cls, kvs: Dict[Any, Any]) -> Dict[str, Any]:
|
||||
"""
|
||||
Return the parameters to initialize the class with.
|
||||
"""
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
"""
|
||||
Return the parameters that will be written to the output dict
|
||||
"""
|
||||
return kvs
|
||||
|
||||
@staticmethod
|
||||
def handle_dump(obj) -> Dict[Any, Any]:
|
||||
"""
|
||||
Return the parameters that will be added to the schema dump.
|
||||
"""
|
||||
return {}
|
||||
|
||||
@staticmethod
|
||||
def create_init(obj) -> Callable:
|
||||
return obj.__init__
|
||||
|
||||
@staticmethod
|
||||
def _separate_defined_undefined_kvs(cls, kvs: Dict) -> \
|
||||
Tuple[KnownParameters, UnknownParameters]:
|
||||
"""
|
||||
Returns a 2 dictionaries: defined and undefined parameters
|
||||
"""
|
||||
class_fields = fields(cls)
|
||||
field_names = [field.name for field in class_fields]
|
||||
unknown_given_parameters = {k: v for k, v in kvs.items() if
|
||||
k not in field_names}
|
||||
known_given_parameters = {k: v for k, v in kvs.items() if
|
||||
k in field_names}
|
||||
return known_given_parameters, unknown_given_parameters
|
||||
|
||||
|
||||
class _RaiseUndefinedParameters(_UndefinedParameterAction):
|
||||
"""
|
||||
This action raises UndefinedParameterError if it encounters an undefined
|
||||
parameter during initialization.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
|
||||
known, unknown = \
|
||||
_UndefinedParameterAction._separate_defined_undefined_kvs(
|
||||
cls=cls, kvs=kvs)
|
||||
if len(unknown) > 0:
|
||||
raise UndefinedParameterError(
|
||||
f"Received undefined initialization arguments {unknown}")
|
||||
return known
|
||||
|
||||
|
||||
CatchAll = Optional[CatchAllVar]
|
||||
|
||||
|
||||
class _IgnoreUndefinedParameters(_UndefinedParameterAction):
|
||||
"""
|
||||
This action does nothing when it encounters undefined parameters.
|
||||
The undefined parameters can not be retrieved after the class has been
|
||||
created.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
|
||||
known_given_parameters, _ = \
|
||||
_UndefinedParameterAction._separate_defined_undefined_kvs(
|
||||
cls=cls, kvs=kvs)
|
||||
return known_given_parameters
|
||||
|
||||
@staticmethod
|
||||
def create_init(obj) -> Callable:
|
||||
original_init = obj.__init__
|
||||
init_signature = inspect.signature(original_init)
|
||||
|
||||
@functools.wraps(obj.__init__)
|
||||
def _ignore_init(self, *args, **kwargs):
|
||||
known_kwargs, _ = \
|
||||
_CatchAllUndefinedParameters._separate_defined_undefined_kvs(
|
||||
obj, kwargs)
|
||||
num_params_takeable = len(
|
||||
init_signature.parameters) - 1 # don't count self
|
||||
num_args_takeable = num_params_takeable - len(known_kwargs)
|
||||
|
||||
args = args[:num_args_takeable]
|
||||
bound_parameters = init_signature.bind_partial(self, *args,
|
||||
**known_kwargs)
|
||||
bound_parameters.apply_defaults()
|
||||
|
||||
arguments = bound_parameters.arguments
|
||||
arguments.pop("self", None)
|
||||
final_parameters = \
|
||||
_IgnoreUndefinedParameters.handle_from_dict(obj, arguments)
|
||||
original_init(self, **final_parameters)
|
||||
|
||||
return _ignore_init
|
||||
|
||||
|
||||
class _CatchAllUndefinedParameters(_UndefinedParameterAction):
|
||||
"""
|
||||
This class allows to add a field of type utils.CatchAll which acts as a
|
||||
dictionary into which all
|
||||
undefined parameters will be written.
|
||||
These parameters are not affected by LetterCase.
|
||||
If no undefined parameters are given, this dictionary will be empty.
|
||||
"""
|
||||
|
||||
class _SentinelNoDefault:
|
||||
pass
|
||||
|
||||
@staticmethod
|
||||
def handle_from_dict(cls, kvs: Dict) -> Dict[str, Any]:
|
||||
known, unknown = _UndefinedParameterAction \
|
||||
._separate_defined_undefined_kvs(cls=cls, kvs=kvs)
|
||||
catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
|
||||
cls=cls)
|
||||
|
||||
if catch_all_field.name in known:
|
||||
|
||||
already_parsed = isinstance(known[catch_all_field.name], dict)
|
||||
default_value = _CatchAllUndefinedParameters._get_default(
|
||||
catch_all_field=catch_all_field)
|
||||
received_default = default_value == known[catch_all_field.name]
|
||||
|
||||
value_to_write: Any
|
||||
if received_default and len(unknown) == 0:
|
||||
value_to_write = default_value
|
||||
elif received_default and len(unknown) > 0:
|
||||
value_to_write = unknown
|
||||
elif already_parsed:
|
||||
# Did not receive default
|
||||
value_to_write = known[catch_all_field.name]
|
||||
if len(unknown) > 0:
|
||||
value_to_write.update(unknown)
|
||||
else:
|
||||
error_message = f"Received input field with " \
|
||||
f"same name as catch-all field: " \
|
||||
f"'{catch_all_field.name}': " \
|
||||
f"'{known[catch_all_field.name]}'"
|
||||
raise UndefinedParameterError(error_message)
|
||||
else:
|
||||
value_to_write = unknown
|
||||
|
||||
known[catch_all_field.name] = value_to_write
|
||||
return known
|
||||
|
||||
@staticmethod
|
||||
def _get_default(catch_all_field: Field) -> Any:
|
||||
# access to the default factory currently causes
|
||||
# a false-positive mypy error (16. Dec 2019):
|
||||
# https://github.com/python/mypy/issues/6910
|
||||
|
||||
# noinspection PyProtectedMember
|
||||
has_default = not isinstance(catch_all_field.default,
|
||||
dataclasses._MISSING_TYPE)
|
||||
# noinspection PyProtectedMember
|
||||
has_default_factory = not isinstance(catch_all_field.default_factory,
|
||||
# type: ignore
|
||||
dataclasses._MISSING_TYPE)
|
||||
# TODO: black this for proper formatting
|
||||
default_value: Union[
|
||||
Type[_CatchAllUndefinedParameters._SentinelNoDefault], Any] = _CatchAllUndefinedParameters\
|
||||
._SentinelNoDefault
|
||||
|
||||
if has_default:
|
||||
default_value = catch_all_field.default
|
||||
elif has_default_factory:
|
||||
# This might be unwanted if the default factory constructs
|
||||
# something expensive,
|
||||
# because we have to construct it again just for this test
|
||||
default_value = catch_all_field.default_factory() # type: ignore
|
||||
|
||||
return default_value
|
||||
|
||||
@staticmethod
|
||||
def handle_to_dict(obj, kvs: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||
catch_all_field = \
|
||||
_CatchAllUndefinedParameters._get_catch_all_field(obj.__class__)
|
||||
undefined_parameters = kvs.pop(catch_all_field.name)
|
||||
if isinstance(undefined_parameters, dict):
|
||||
kvs.update(
|
||||
undefined_parameters) # If desired handle letter case here
|
||||
return kvs
|
||||
|
||||
@staticmethod
|
||||
def handle_dump(obj) -> Dict[Any, Any]:
|
||||
catch_all_field = _CatchAllUndefinedParameters._get_catch_all_field(
|
||||
cls=obj)
|
||||
return getattr(obj, catch_all_field.name)
|
||||
|
||||
@staticmethod
|
||||
def create_init(obj) -> Callable:
|
||||
original_init = obj.__init__
|
||||
init_signature = inspect.signature(original_init)
|
||||
|
||||
@functools.wraps(obj.__init__)
|
||||
def _catch_all_init(self, *args, **kwargs):
|
||||
known_kwargs, unknown_kwargs = \
|
||||
_CatchAllUndefinedParameters._separate_defined_undefined_kvs(
|
||||
obj, kwargs)
|
||||
num_params_takeable = len(
|
||||
init_signature.parameters) - 1 # don't count self
|
||||
if _CatchAllUndefinedParameters._get_catch_all_field(
|
||||
obj).name not in known_kwargs:
|
||||
num_params_takeable -= 1
|
||||
num_args_takeable = num_params_takeable - len(known_kwargs)
|
||||
|
||||
args, unknown_args = args[:num_args_takeable], args[
|
||||
num_args_takeable:]
|
||||
bound_parameters = init_signature.bind_partial(self, *args,
|
||||
**known_kwargs)
|
||||
|
||||
unknown_args = {f"_UNKNOWN{i}": v for i, v in
|
||||
enumerate(unknown_args)}
|
||||
arguments = bound_parameters.arguments
|
||||
arguments.update(unknown_args)
|
||||
arguments.update(unknown_kwargs)
|
||||
arguments.pop("self", None)
|
||||
final_parameters = _CatchAllUndefinedParameters.handle_from_dict(
|
||||
obj, arguments)
|
||||
original_init(self, **final_parameters)
|
||||
|
||||
return _catch_all_init
|
||||
|
||||
@staticmethod
|
||||
def _get_catch_all_field(cls) -> Field:
|
||||
cls_globals = vars(sys.modules[cls.__module__])
|
||||
types = get_type_hints(cls, globalns=cls_globals)
|
||||
catch_all_fields = list(
|
||||
filter(lambda f: types[f.name] == Optional[CatchAllVar], fields(cls)))
|
||||
number_of_catch_all_fields = len(catch_all_fields)
|
||||
if number_of_catch_all_fields == 0:
|
||||
raise UndefinedParameterError(
|
||||
"No field of type dataclasses_json.CatchAll defined")
|
||||
elif number_of_catch_all_fields > 1:
|
||||
raise UndefinedParameterError(
|
||||
f"Multiple catch-all fields supplied: "
|
||||
f"{number_of_catch_all_fields}.")
|
||||
else:
|
||||
return catch_all_fields[0]
|
||||
|
||||
|
||||
class Undefined(Enum):
|
||||
"""
|
||||
Choose the behavior what happens when an undefined parameter is encountered
|
||||
during class initialization.
|
||||
"""
|
||||
INCLUDE = _CatchAllUndefinedParameters
|
||||
RAISE = _RaiseUndefinedParameters
|
||||
EXCLUDE = _IgnoreUndefinedParameters
|
||||
|
||||
|
||||
class UndefinedParameterError(ValidationError):
|
||||
"""
|
||||
Raised when something has gone wrong handling undefined parameters.
|
||||
"""
|
||||
pass
|
Reference in New Issue
Block a user