95 lines
2.7 KiB
Python
95 lines
2.7 KiB
Python
"""A registry of :class:`Schema <marshmallow.Schema>` classes. This allows for string
|
|
lookup of schemas, which may be used with
|
|
class:`fields.Nested <marshmallow.fields.Nested>`.
|
|
|
|
.. warning::
|
|
|
|
This module is treated as private API.
|
|
Users should not need to use this module directly.
|
|
"""
|
|
|
|
from __future__ import annotations
|
|
|
|
import typing
|
|
|
|
from marshmallow.exceptions import RegistryError
|
|
|
|
if typing.TYPE_CHECKING:
|
|
from marshmallow import Schema
|
|
|
|
SchemaType = type[Schema]
|
|
|
|
# {
|
|
# <class_name>: <list of class objects>
|
|
# <module_path_to_class>: <list of class objects>
|
|
# }
|
|
_registry = {} # type: dict[str, list[SchemaType]]
|
|
|
|
|
|
def register(classname: str, cls: SchemaType) -> None:
|
|
"""Add a class to the registry of serializer classes. When a class is
|
|
registered, an entry for both its classname and its full, module-qualified
|
|
path are added to the registry.
|
|
|
|
Example: ::
|
|
|
|
class MyClass:
|
|
pass
|
|
|
|
|
|
register("MyClass", MyClass)
|
|
# Registry:
|
|
# {
|
|
# 'MyClass': [path.to.MyClass],
|
|
# 'path.to.MyClass': [path.to.MyClass],
|
|
# }
|
|
|
|
"""
|
|
# Module where the class is located
|
|
module = cls.__module__
|
|
# Full module path to the class
|
|
# e.g. user.schemas.UserSchema
|
|
fullpath = ".".join([module, classname])
|
|
# If the class is already registered; need to check if the entries are
|
|
# in the same module as cls to avoid having multiple instances of the same
|
|
# class in the registry
|
|
if classname in _registry and not any(
|
|
each.__module__ == module for each in _registry[classname]
|
|
):
|
|
_registry[classname].append(cls)
|
|
elif classname not in _registry:
|
|
_registry[classname] = [cls]
|
|
|
|
# Also register the full path
|
|
if fullpath not in _registry:
|
|
_registry.setdefault(fullpath, []).append(cls)
|
|
else:
|
|
# If fullpath does exist, replace existing entry
|
|
_registry[fullpath] = [cls]
|
|
return None
|
|
|
|
|
|
def get_class(classname: str, all: bool = False) -> list[SchemaType] | SchemaType:
|
|
"""Retrieve a class from the registry.
|
|
|
|
:raises: marshmallow.exceptions.RegistryError if the class cannot be found
|
|
or if there are multiple entries for the given class name.
|
|
"""
|
|
try:
|
|
classes = _registry[classname]
|
|
except KeyError as error:
|
|
raise RegistryError(
|
|
f"Class with name {classname!r} was not found. You may need "
|
|
"to import the class."
|
|
) from error
|
|
if len(classes) > 1:
|
|
if all:
|
|
return _registry[classname]
|
|
raise RegistryError(
|
|
f"Multiple classes with name {classname!r} "
|
|
"were found. Please use the full, "
|
|
"module-qualified path."
|
|
)
|
|
else:
|
|
return _registry[classname][0]
|