import dataclasses
import enum
import inspect
import json
import logging
import os
import textwrap
import typing
from collections.abc import Awaitable, Callable, Mapping
from typing import Any, TypeVar, cast
import anyio
import cattrs
import cattrs.gen
from cattrs.preconf import is_primitive_enum
from cattrs.preconf.json import JsonConverter
from typing_extensions import dataclass_transform, overload
import dagger
from dagger import dag
from dagger.client._core import configure_converter_enum
from dagger.mod._converter import make_converter, to_typedef
from dagger.mod._exceptions import (
BadUsageError,
FunctionError,
InvalidInputError,
InvalidResultError,
ObjectNotFoundError,
RegistrationError,
log_exception_only,
transform_error,
)
from dagger.mod._resolver import (
Constructor,
Field,
Func,
Function,
ObjectType,
P,
R,
)
from dagger.mod._types import APIName, FieldDefinition, FunctionDefinition, PythonName
from dagger.mod._utils import (
asyncify,
extract_enum_member_doc,
get_doc,
get_parent_module_doc,
is_annotated,
)
# Logger named after the enclosing package.
logger = logging.getLogger(__package__)
# Attribute/metadata keys used to tag decorated classes, fields and functions.
# NOTE(review): OBJECT_DEF_KEY is not referenced in the visible part of this file.
OBJECT_DEF_KEY: typing.Final[str] = "__dagger_object__"
# Key under which `Module.field` stores a FieldDefinition in the dataclass
# field metadata (read back in `Module._process_type`).
FIELD_DEF_KEY: typing.Final[str] = "__dagger_field__"
# Attribute set on functions by `Module.function`, holding a FunctionDefinition.
FUNCTION_DEF_KEY: typing.Final[str] = "__dagger_function__"
# Boolean marker attributes set by `Module.check`, `Module.generate` and
# `Module.up` respectively; `Module.function` reads them with a False default.
CHECK_DEF_KEY: typing.Final[str] = "__dagger_check__"
GENERATOR_DEF_KEY: typing.Final[str] = "__dagger_generate__"
UP_DEF_KEY: typing.Final[str] = "__dagger_up__"
# Settings read from the environment once, at import time.
# MODULE_NAME is only used in an error message; empty when unset.
MODULE_NAME: typing.Final[str] = os.getenv("DAGGER_MODULE", "")
# Default for `Module(main_name=...)`; empty when unset.
MAIN_OBJECT: typing.Final[str] = os.getenv("DAGGER_MAIN_OBJECT", "")
# File that `Module.register` writes its output to.
TYPE_DEF_FILE: typing.Final[str] = os.getenv("DAGGER_MODULE_FILE", "/module.json")
# Type variable for the class decorators; bound to `type` so they accept classes.
T = TypeVar("T", bound=type)
class Module:
"""Builder for a :py:class:`dagger.Module`."""
def __init__(self, main_name: str = MAIN_OBJECT):
    """Create an empty module builder.

    Parameters
    ----------
    main_name:
        Name of the class to treat as the module's main object.
        Defaults to the ``DAGGER_MAIN_OBJECT`` environment value.
    """
    # Registries filled in by the `object_type`, `interface` and
    # `enum_type` decorators, keyed by class name.
    self._objects: dict[str, ObjectType] = {}
    self._enums: dict[str, type[enum.Enum]] = {}

    # The main object is only known once a class with a matching name
    # has been registered.
    self._main_name = main_name
    self._main: ObjectType | None = None

    self._converter: JsonConverter = make_converter()

    # Escape hatch if there's too much noise from showing stack traces
    # from exceptions raised in functions by default. Not documented
    # intentionally for now.
    self.log_exceptions = True
@property
def main_cls(self) -> type[ObjectType]:
    """Class registered as the module's main object.

    Asserts that a main object has already been registered.
    """
    main = self._main
    assert main is not None
    return main.cls
def is_main(self, other: ObjectType) -> bool:
    """Tell whether ``other`` wraps the very class of the main object."""
    main_class = self.main_cls
    return other.cls is main_class
async def serve(self):
    """Serve the current function call.

    When the current call has a parent name, the requested function is
    invoked; otherwise the module's type definitions are produced. In both
    cases the result is encoded as JSON and handed back through
    ``return_value`` on the current function call.

    Raises
    ------
    RegistrationError
        If building the type definitions fails with a ``TypeError``.
    InvalidResultError
        If the final result can't be serialized as JSON.
    """
    if await dag.current_function_call().parent_name():
        result = await self.invoke()
    else:
        try:
            result = await self._typedefs()
        except TypeError as e:
            # Pass the original exception along, exactly as `register`
            # does for the same failure.
            raise RegistrationError(str(e), e) from e

    try:
        output = json.dumps(result)
    except (TypeError, ValueError) as e:
        # Not expected to happen because unstructuring should reduce
        # Python complex types to primitive values that are easily
        # serialized to JSON. If not, it's something that should be caught
        # earlier.
        msg = f"Failed to serialize final result as JSON: {e}"
        raise InvalidResultError(msg) from e

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug(
            "output => %s",
            textwrap.shorten(repr(output), 144),
        )

    await dag.current_function_call().return_value(dagger.JSON(output))
async def register(self):
    """Build the module's type definitions and write them to disk.

    The JSON-encoded result of ``_typedefs`` is written to ``TYPE_DEF_FILE``.

    Raises
    ------
    RegistrationError
        If producing or encoding the definitions fails with a ``TypeError``.
    """
    try:
        output = json.dumps(await self._typedefs())
    except TypeError as e:
        raise RegistrationError(str(e), e) from e

    target = anyio.Path(TYPE_DEF_FILE)
    await target.write_text(output)
async def _typedefs(self) -> str:  # noqa: C901, PLR0912, PLR0915
    """Describe every registered type to the API and return the module ID.

    Walks the registered objects/interfaces (with their fields and
    functions) and the registered enums, chaining them onto a
    ``dag.module()`` builder.

    Raises
    ------
    ValueError
        If no main object name was configured.
    ObjectNotFoundError
        If no registered object matches the main object name.
    """
    if not self._main_name:
        msg = "Main object name can't be empty"
        raise ValueError(msg)

    # Fail early, with a more helpful message, when the main object is missing.
    try:
        self.get_object(self._main_name)
    except ObjectNotFoundError as e:
        msg = (
            f"Main object with name '{self._main_name}' not found or class not "
            "decorated with '@dagger.object_type'\n"
            f"If you believe the module name '{MODULE_NAME}' is incorrectly "
            "being converted into PascalCase, please file a bug report."
        )
        raise ObjectNotFoundError(msg, extra=e.extra) from None

    mod = dag.module()

    # --- Objects and interfaces ---
    for obj_name, obj_type in self._objects.items():
        if self.is_main(obj_type):
            # Only the main object's constructor is needed.
            # It's the entrypoint to the module.
            obj_type.get_constructor(self._converter)

            # The module's description comes from the main object's
            # parent module.
            desc = get_parent_module_doc(obj_type.cls)
            if desc:
                mod = mod.with_description(desc)

        type_def = dag.type_def()
        if obj_type.interface:
            type_def = type_def.with_interface(
                obj_name,
                description=get_doc(obj_type.cls),
            )
        else:
            type_def = type_def.with_object(
                obj_name,
                description=get_doc(obj_type.cls),
                deprecated=obj_type.deprecated,
            )

        # Fields
        if obj_type.fields:
            hints = typing.get_type_hints(obj_type.cls)
            for api_name, fld in obj_type.fields.items():
                ctx = f"type for field '{fld.original_name}' in {obj_type}"
                field_type = to_typedef(hints[fld.original_name], ctx)
                type_def = type_def.with_field(
                    api_name,
                    field_type,
                    description=get_doc(fld.return_type),
                    deprecated=fld.meta.deprecated,
                )

        # Functions (the constructor is the function with an empty name)
        for func_name, func in obj_type.functions.items():
            what = "constructor" if not func_name else f"function '{func_name}'"
            return_def = to_typedef(
                func.return_type,
                f"return type for {what} in {obj_type}",
            )
            func_def = dag.function(func_name, return_def)

            doc = func.doc
            if doc:
                func_def = func_def.with_description(doc)

            policy = func.cache_policy
            if policy is not None:
                if policy == "never":
                    func_def = func_def.with_cache_policy(
                        dagger.FunctionCachePolicy.Never,
                    )
                elif policy == "session":
                    func_def = func_def.with_cache_policy(
                        dagger.FunctionCachePolicy.PerSession,
                    )
                elif policy != "":
                    # Any other non-empty value is forwarded as a TTL.
                    func_def = func_def.with_cache_policy(
                        dagger.FunctionCachePolicy.Default,
                        time_to_live=policy,
                    )

            reason = func.deprecated
            if reason:
                func_def = func_def.with_deprecated(reason=reason)

            if func.check:
                func_def = func_def.with_check()
            if func.generate:
                func_def = func_def.with_generator()
            if func.service:
                func_def = func_def.with_up()

            for param in func.parameters.values():
                arg_def = to_typedef(
                    param.resolved_type,
                    f"parameter type for '{param.name}' in {what} and {obj_type}",
                )
                if param.is_nullable:
                    arg_def = arg_def.with_optional(True)

                func_def = func_def.with_arg(
                    param.name,
                    arg_def,
                    description=param.doc,
                    default_value=param.default_value,
                    default_path=param.default_path,
                    default_address=param.default_address,
                    ignore=param.ignore,
                    deprecated=param.deprecated,
                )

            if func_name == "":
                type_def = type_def.with_constructor(func_def)
            else:
                type_def = type_def.with_function(func_def)

        # Attach the finished object/interface to the module.
        if obj_type.interface:
            mod = mod.with_interface(type_def)
        else:
            mod = mod.with_object(type_def)

    # --- Enums ---
    for enum_name, enum_cls in self._enums.items():
        enum_def = dag.type_def().with_enum(
            enum_name,
            description=get_doc(enum_cls),
        )
        docs_by_member = extract_enum_member_doc(enum_cls)

        for member in enum_cls:
            # A `description` attribute on the member wins over the
            # extracted member documentation.
            description = getattr(member, "description", None)
            meta = docs_by_member.get(member.name)
            if description is None and meta and meta.description is not None:
                description = meta.description

            enum_def = enum_def.with_enum_member(
                member.name,
                value=str(member.value),
                description=description,
                deprecated=meta.deprecated if meta else None,
            )

        mod = mod.with_enum(enum_def)

    return await mod.id()
async def invoke(self) -> str:
    """Run the function requested by the current function call.

    Reads the call context (parent name, function name, parent state and
    input arguments) from the API, decodes the JSON payloads and delegates
    to ``get_result``.

    Raises
    ------
    RegistrationError
        If the current call has no parent name.
    InvalidInputError
        If the parent state or an input argument isn't valid JSON.
    """
    fn_call = dag.current_function_call()

    parent_name = await fn_call.parent_name()
    if not parent_name:
        msg = (
            "Seems like the SDK module isn't registering the types correctly. "
            "This is a bug."
        )
        raise RegistrationError(msg)

    name = await fn_call.name()
    parent_json = await fn_call.parent()
    input_args = await fn_call.input_args()

    # Blank parent JSON, or JSON decoding to a falsy value, means no state.
    parent_state: dict[str, Any] = {}
    if parent_json.strip():
        try:
            parent_state = json.loads(parent_json) or {}
        except ValueError as e:
            logger.exception("Failed to decode JSON parent value")
            msg = "Unable to decode the parent object's state"
            raise InvalidInputError(
                msg,
                extra={"parent_json": parent_json},
            ) from e

    inputs = {}
    for arg in input_args:
        # NB: These are already loaded by `input_args`,
        # the await just returns the cached value.
        arg_name = await arg.name()
        arg_value = await arg.value()
        try:
            # Cattrs can decode JSON strings but use `json` directly
            # for more granular control over the error.
            decoded = json.loads(arg_value)
        except ValueError as e:
            logger.exception("Failed to decode JSON input value")
            msg = f"Unable to decode input argument '{arg_name}'"
            raise InvalidInputError(
                msg,
                extra={"json_value": arg_value},
            ) from e
        inputs[arg_name] = decoded

    if logger.isEnabledFor(logging.DEBUG):
        summary = {
            "parent_name": parent_name,
            "parent_json": textwrap.shorten(parent_json, 144),
            "name": name,
            "input_args": textwrap.shorten(repr(inputs), 144),
        }
        logger.debug("invoke => %s", summary)

    result = await self.get_result(parent_name, parent_state, name, inputs)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("result => %s", textwrap.shorten(repr(result), 144))

    return result
async def get_result(
    self,
    parent_name: str,
    parent_state: Mapping[str, Any],
    name: str,
    raw_inputs: Mapping[str, Any],
) -> Any:
    """Run a function and unstructure its result to Python primitives.

    Returns ``None`` without unstructuring when the resolved field or
    function has no return type.

    Raises
    ------
    InvalidResultError
        If the returned value can't be unstructured as the declared type.
    """
    result, fn = await self.get_structured_result(
        parent_name,
        parent_state,
        name,
        raw_inputs,
    )

    declared = fn.return_type
    if declared is None:
        return None

    try:
        return await self.unstructure(result, declared)
    except Exception as e:
        log_exception_only(e, "Invalid result from function")
        msg = transform_error(
            e,
            origin=getattr(fn, "wrapped", None),
            typ=fn.return_type,
        )
        msg += (
            "\n"
            "Please check if the returned value at runtime matches "
            "the function's declared return type."
        )
        raise InvalidResultError(msg) from e
async def get_structured_result(
    self,
    parent_name: str,
    parent_state: Mapping[str, Any],
    name: str,
    raw_inputs: Mapping[str, Any],
) -> tuple[Any, Field | Function]:
    """Run a function and return its raw result with its definition.

    An empty ``name`` selects the object's constructor. A ``name`` matching
    a field returns the attribute value read from the parent instance.
    Anything else is resolved as a function bound to the parent instance.

    Raises
    ------
    FunctionError
        If the call produced a coroutine that was never awaited.
    """
    obj_type = self.get_object(parent_name)

    if name != "":
        parent = await self._get_parent_instance(obj_type, parent_state)

        # NB: fields are not executed by the SDK, they're returned directly by
        # the engine, but this is still useful for testing.
        if name in obj_type.fields:
            f = obj_type.fields[name]
            return getattr(parent, f.original_name), f

        fn = obj_type.get_bound_function(parent, name)
    else:
        fn = obj_type.get_constructor(self._converter)

    inputs = await self._convert_inputs(fn, raw_inputs)
    bound = fn.bind_arguments(**inputs)

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("func => %s", repr(fn.signature))
        logger.debug("input args => %s", repr(raw_inputs))
        logger.debug("bound args => %s", repr(bound.arguments))

    result = await self.call(fn.wrapped, *bound.args, **bound.kwargs)

    if not inspect.iscoroutine(result):
        return result, fn

    # Provide better errors for missing async/await
    result.close()  # avoid RuntimeWarning
    if inspect.iscoroutinefunction(fn.wrapped):
        msg = (
            f"Async function '{fn}' was never awaited.\n"
            "Did you forget to add an 'await' to the return value?"
        )
    else:
        msg = (
            f"Function '{fn}' returned a coroutine.\n"
            "Did you forget to add 'async' to the function signature?"
        )
    raise FunctionError(msg) from None
async def call(self, func: Func[P, R], *args: P.args, **kwargs: P.kwargs) -> R:
    """Execute ``func`` with the given arguments, awaiting it if it's async.

    Raises
    ------
    FunctionError
        Re-raised untouched when raised by ``func``; also wraps any other
        unhandled exception.
    dagger.QueryError
        Re-raised after logging, so API errors propagate as they are.
    """
    try:
        # We could await based on the return value instead of checking function
        # color but that would silently allow incorrect code which is
        # especially bad if not intentional and we don't warn user about it.
        outcome = func(*args, **kwargs)
        if inspect.iscoroutinefunction(func):
            outcome = await cast(typing.Awaitable[R], outcome)
    except FunctionError:
        # Escape hatch to fully control logging from user code.
        raise
    except dagger.QueryError as e:
        # Exclude the line in "try" above
        tb = e.__traceback__
        trimmed_tb = tb.tb_next if tb else tb

        # Exclude the underlying TransportQueryError to reduce noise
        e.__cause__ = None

        logger.exception(
            "API error while executing function",
            exc_info=(type(e), e, trimmed_tb),
        )

        # Preserve API error so it's properly propagated.
        raise e from None
    except Exception as e:
        # Logging the exception will show the full stack trace on stderr;
        # `log_exceptions` is the escape hatch if that's too noisy.
        if self.log_exceptions:
            logger.exception("Unhandled exception while executing function")
        raise FunctionError(str(e)) from e
    else:
        return outcome
async def structure(self, obj: Any, cl: type[T]) -> T:
    """Structure a primitive value into ``cl`` using the module's converter.

    The converter call goes through ``asyncify``.
    """
    structured = await asyncify(self._converter.structure, obj, cl)
    return structured
async def unstructure(self, obj: Any, unstructure_as: Any) -> Awaitable[Any]:
    """Unstructure ``obj`` as ``unstructure_as`` using the module's converter.

    The converter call goes through ``asyncify``.
    """
    primitive = await asyncify(self._converter.unstructure, obj, unstructure_as)
    return primitive
def get_object(self, name: str) -> ObjectType:
    """Get the object type definition for the given name.

    Parameters
    ----------
    name:
        Class name under which the object was registered.

    Raises
    ------
    ObjectNotFoundError
        If no registered object has that name. The error's ``extra``
        lists the names that were registered at that moment.
    """
    try:
        return self._objects[name]
    except KeyError:
        # Not expected to happen during invoke because registration should
        # fail first.
        msg = f"No '@dagger.object_type' decorated class named '{name}' was found"
        # Snapshot the names: a `dict_keys` view would keep changing with
        # the registry and can't be serialized as JSON.
        extra = {"objects_found": list(self._objects)}
        raise ObjectNotFoundError(msg, extra=extra) from None
async def _get_parent_instance(
    self,
    obj_type: ObjectType[T],
    state: Mapping[str, Any],
) -> T:
    """Rebuild the parent object by structuring ``state`` into its class.

    Raises
    ------
    InvalidInputError
        If the state can't be structured into ``obj_type.cls``.
    """
    parent_cls = obj_type.cls
    try:
        instance = await self.structure(state, parent_cls)
    except Exception as e:
        log_exception_only(e, "Failed to instantiate parent object")
        msg = transform_error(
            e,
            f"Failed to instantiate parent object '{obj_type}'",
            origin=obj_type.cls,
            typ=obj_type.cls,
        )
        # If API is able to make the call this is likely a bug in the SDK.
        # For example, if the registration phase reports a type that isn't
        # compatible with cattrs' converter.
        msg += (
            "\n"
            "This could be an error in the Python SDK. "
            "If so, please file a bug report."
        )
        raise InvalidInputError(msg, extra={"object_state": state}) from e
    return instance
async def _convert_inputs(
    self,
    fn: Function,
    inputs: Mapping[APIName, Any],
) -> Mapping[PythonName, Any]:
    """Structure raw input values into the types declared by ``fn``.

    The result is keyed by Python parameter name. Omitted parameters that
    have a default are left out; omitted optional parameters without a
    default are structured from ``None``.

    Raises
    ------
    InvalidInputError
        If a required argument is missing or a value can't be structured.
    """
    structured = {}

    for python_name, param in fn.parameters.items():
        missing = param.name not in inputs
        if missing:
            if not param.is_optional:
                msg = f"Missing required function argument '{python_name}'"
                raise InvalidInputError(msg)

            if param.has_default:
                continue

        # If the argument is optional and has no default, it's a nullable type.
        # According to GraphQL spec, null is a valid value in case it's omitted.
        value = inputs.get(param.name)
        type_ = param.resolved_type

        try:
            structured[python_name] = await self.structure(value, type_)
        except Exception as e:
            log_exception_only(
                e,
                "Failed to convert from primitive input value for argument '%s'",
                param.name,
            )
            msg = transform_error(
                e,
                (
                    "Failed to convert from primitive input value for argument "
                    f"'{param.name}'"
                ),
                origin=fn.wrapped,
                typ=type_,
            )
            # Same as for the parent object, the API can't reasonably hold
            # a value that contradicts its type.
            msg += (
                "\n"
                "This could be an error in the Python SDK. "
                "If so, please file a bug report."
            )
            raise InvalidInputError(
                msg,
                extra={
                    "function_name": fn.original_name,
                    "parameter_name": python_name,
                    "expected_type": type_,
                    "actual_type": type(value),
                },
            ) from e

    if logger.isEnabledFor(logging.DEBUG):
        logger.debug("structured args => %s", repr(structured))

    return structured
def field(
    self,
    *,
    default: Callable[[], Any] | object = ...,
    name: APIName | None = None,
    init: bool = True,
    deprecated: str | None = None,
) -> Any:
    """Exposes an attribute as a :py:class:`dagger.FieldTypeDef`.

    Meant for attributes of a class decorated with :py:meth:`object_type`.

    Example usage::

        @object_type
        class Foo:
            bar: str = field(default="foobar")
            args: list[str] = field(default=list)

    Parameters
    ----------
    default:
        Either the field's default value, or a callable taking no
        arguments that produces the initial value.
    name:
        An alternative name for the API. Useful to avoid conflicts with
        reserved words.
    init:
        Whether the field should be included in the constructor.
        Defaults to `True`.
    deprecated:
        Optional deprecation message exposed to the engine.
    """
    # The field is optional as soon as any default was supplied.
    has_default = default is not ...

    options = {}
    if has_default:
        # Callables are treated as factories, everything else as a value.
        key = "default_factory" if callable(default) else "default"
        options[key] = default

    definition = FieldDefinition(name, has_default, deprecated)
    return dataclasses.field(
        metadata={FIELD_DEF_KEY: definition},
        kw_only=True,
        init=init,
        repr=init,  # default repr shows field as an __init__ argument
        **options,
    )
def check(
    self,
    func: Func[P, R] | None = None,
) -> Func[P, R] | Callable[[Func[P, R]], Func[P, R]]:
    """Mark a function as a check.

    Checks are functions that validate conditions and return void/error
    to indicate pass/fail. This decorator can be combined with
    :py:meth:`function`.

    Example usage::

        @object_type
        class MyModule:
            @function
            @check
            def lint(self) -> str:
                return "All checks passed"

    Parameters
    ----------
    func:
        The function to mark as a check. Should be an instance method in a
        class decorated with :py:meth:`object_type`.
    """

    def mark(fn: Func[P, R]) -> Func[P, R]:
        # Only tags the function; `function` reads the tag later.
        setattr(fn, CHECK_DEF_KEY, True)
        return fn

    # Works both as `@check` and as `@check()`.
    if func:
        return mark(func)
    return mark
def generate(
    self,
    func: Func[P, R] | None = None,
) -> Func[P, R] | Callable[[Func[P, R]], Func[P, R]]:
    """Mark a function as a generator.

    Generators are functions that return a Changeset representing
    changes to be applied. This decorator can be combined with
    :py:meth:`function`.

    Example usage::

        @object_type
        class MyModule:
            @function
            @generate
            def codegen(self) -> dagger.Changeset:
                # Generate code and return changeset
                ...

    Parameters
    ----------
    func:
        The function to mark as a generator. Should be an instance method in a
        class decorated with :py:meth:`object_type`.
    """

    def mark(fn: Func[P, R]) -> Func[P, R]:
        # Only tags the function; `function` reads the tag later.
        setattr(fn, GENERATOR_DEF_KEY, True)
        return fn

    # Works both as `@generate` and as `@generate()`.
    if func:
        return mark(func)
    return mark
def up(
    self,
    func: Func[P, R] | None = None,
) -> Func[P, R] | Callable[[Func[P, R]], Func[P, R]]:
    """Mark a function as a service for ``dagger up``.

    Parameters
    ----------
    func:
        The function to mark. When omitted, a decorator is returned
        instead.
    """

    def mark(fn: Func[P, R]) -> Func[P, R]:
        # Only tags the function; `function` reads the tag later.
        setattr(fn, UP_DEF_KEY, True)
        return fn

    # Works both as `@up` and as `@up()`.
    if func:
        return mark(func)
    return mark
@overload
def function(
    self,
    func: Func[P, R],
    *,
    name: APIName | None = None,
    doc: str | None = None,
    cache: str | None = None,
    deprecated: str | None = None,
) -> Func[P, R]: ...

@overload
def function(
    self,
    *,
    name: APIName | None = None,
    doc: str | None = None,
    cache: str | None = None,
    deprecated: str | None = None,
) -> Callable[[Func[P, R]], Func[P, R]]: ...

def function(
    self,
    func: Func[P, R] | None = None,
    *,
    name: APIName | None = None,
    doc: str | None = None,
    cache: str | None = None,
    deprecated: str | None = None,
) -> Func[P, R] | Callable[[Func[P, R]], Func[P, R]]:
    """Exposes a Python function as a :py:class:`dagger.Function`.

    Example usage::

        @object_type
        class Foo:
            @function
            def bar(self) -> str:
                return "foobar"

    Parameters
    ----------
    func:
        Should be an instance method in a class decorated with
        :py:meth:`object_type`. Can be an async function or a class,
        to use its constructor.
    name:
        An alternative name for the API. Useful to avoid conflicts with
        reserved words.
    doc:
        An alternative description for the API. Useful to use the
        docstring for other purposes.
    cache:
        Optional cache setting stored in the function's definition.
        Presumably surfaced as the function's cache policy at
        registration, where ``"never"`` and ``"session"`` select the
        matching policy and any other non-empty string is used as a
        time-to-live.
    deprecated:
        Optional deprecation message exposed to the engine.

    Returns
    -------
    The decorated function (or a ``Constructor`` wrapper when given a
    class), or a decorator when ``func`` is omitted.
    """

    # TODO: Wrap appropriately
    def wrapper(func: Func[P, R]) -> Func[P, R]:
        # TODO: Use beartype to validate
        assert callable(func), f"Expected a callable, got {type(func)}."

        # Pick up the markers left by `check`, `generate` and `up`, which
        # therefore need to be applied before (i.e. below) this decorator.
        check = getattr(func, CHECK_DEF_KEY, False)
        generator = getattr(func, GENERATOR_DEF_KEY, False)
        service = getattr(func, UP_DEF_KEY, False)

        meta = FunctionDefinition(
            name=name,
            doc=doc,
            cache=cache,
            deprecated=deprecated,
            check=check,
            generator=generator,
            service=service,
        )

        # Classes are exposed through their constructor.
        if inspect.isclass(func):
            return Constructor(func, meta)

        # Plain functions are only tagged here; `_process_type` collects
        # them when the enclosing class is decorated.
        setattr(func, FUNCTION_DEF_KEY, meta)
        return func

    return wrapper(func) if func else wrapper
@overload
@dataclass_transform(
    kw_only_default=True,
    field_specifiers=(function, dataclasses.field, dataclasses.Field),
)
def object_type(self, cls: T, /, *, deprecated: str | None = None) -> T: ...

@overload
@dataclass_transform(
    kw_only_default=True,
    field_specifiers=(function, dataclasses.field, dataclasses.Field),
)
def object_type(self, *, deprecated: str | None = None) -> Callable[[T], T]: ...

def object_type(
    self,
    cls: T | None = None,
    *,
    deprecated: str | None = None,
) -> T | Callable[[T], T]:
    """Exposes a Python class as a :py:class:`dagger.ObjectTypeDef`.

    The class is turned into a keyword-only dataclass and registered with
    this module. Use :py:meth:`field` and :py:meth:`function` to expose
    the object's members.

    Example usage::

        import dagger


        @dagger.object_type
        class Foo:
            @dagger.function
            def bar(self) -> str:
                return "foobar"

    Parameters
    ----------
    deprecated:
        Optional deprecation message visible when introspecting the module.

    Raises
    ------
    BadUsageError
        If the decorated value isn't a class, or if a field annotation
        wraps an ``InitVar`` inside ``Annotated``.
    """

    def wrapper(cls: T) -> T:
        if not inspect.isclass(cls):
            msg = f"Expected a class, got {type(cls)}"
            raise BadUsageError(msg)

        # Reject InitVar wrapped inside Annotated.
        for name, t in inspect.get_annotations(cls).items():
            if not is_annotated(t):
                continue
            if not isinstance(t.__origin__, dataclasses.InitVar):
                continue

            # Python 3.10 doesn't support the `*meta` syntax
            # in Annotated[init_t.type, *meta]
            t.__origin__ = t.__origin__.type
            msg = (
                f"Field '{name}' is an InitVar wrapped in Annotated. "
                f"The correct syntax is: InitVar[{t}]"
            )
            raise BadUsageError(msg)

        as_dataclass = dataclasses.dataclass(kw_only=True)
        return self._process_type(as_dataclass(cls), deprecated=deprecated)

    # Works both as `@object_type` and as `@object_type(...)`.
    if cls:
        return wrapper(cls)
    return wrapper
def _process_type(
    self,
    cls: T,
    *,
    interface: bool = False,
    deprecated: str | None = None,
) -> T:
    """Register ``cls`` with the module and collect its exposed members.

    Records the class under its name, gathers constructors and functions
    tagged through :py:meth:`function` and, for non-interfaces, the fields
    declared with :py:meth:`field`. It then registers converter hooks for
    structuring and unstructuring instances of the class.
    """
    obj_def = ObjectType(cls, interface=interface, deprecated=deprecated)

    cls.__dagger_module__ = self
    cls.__dagger_object_type__ = obj_def

    class_name = cls.__name__
    self._objects[class_name] = obj_def
    if class_name == self._main_name:
        self._main = obj_def

    # Find all constructors from other objects, decorated with `@mod.function`
    def _is_constructor(member) -> typing.TypeGuard[Constructor]:
        return isinstance(member, Constructor)

    # Find all methods decorated with `@mod.function`
    def _is_function(member) -> typing.TypeGuard[Func]:
        return hasattr(member, FUNCTION_DEF_KEY)

    for _, ctor in inspect.getmembers(cls, _is_constructor):
        obj_def.functions[ctor.name] = ctor

    for _, meth in inspect.getmembers(cls, _is_function):
        fn = Function(
            meth,
            meta=getattr(meth, FUNCTION_DEF_KEY),
            origin=cls,
            converter=self._converter,
        )
        obj_def.functions[fn.name] = fn

    # Fields and converter hooks don't apply to interfaces.
    if interface:
        return cls

    # Hook overrides for fields renamed through `mod.field(name=...)`.
    attr_overrides = {}

    # Find all fields exposed with `mod.field()`.
    for dc_field in dataclasses.fields(cls):
        field_def: FieldDefinition | None = dc_field.metadata.get(
            FIELD_DEF_KEY,
            None,
        )
        if not field_def:
            continue

        r = Field(
            meta=field_def,
            original_name=dc_field.name,
            return_type=dc_field.type,
        )
        if r.name != r.original_name:
            attr_overrides[r.original_name] = cattrs.gen.override(rename=r.name)
        obj_def.fields[r.name] = r

    # Include fields that are excluded from the constructor.
    unstructure_fn = cattrs.gen.make_dict_unstructure_fn(
        cls,
        self._converter,
        _cattrs_include_init_false=True,
        **attr_overrides,
    )
    self._converter.register_unstructure_hook(cls, unstructure_fn)

    structure_fn = cattrs.gen.make_dict_structure_fn(
        cls,
        self._converter,
        _cattrs_include_init_false=True,
        **attr_overrides,
    )
    self._converter.register_structure_hook(cls, structure_fn)

    return cls
@overload
def interface(self, cls: T) -> T: ...

@overload
def interface(self) -> Callable[[T], T]: ...

def interface(self, cls: T | None = None) -> T | Callable[[T], T]:
    """Exposes a Python class as a :py:class:`dagger.InterfaceTypeDef`.

    The class is made runtime-checkable and registered as an interface.
    Use :py:meth:`function` to expose the interface's functions.

    Example usage::

        import typing
        import dagger


        @dagger.interface
        class Foo(typing.Protocol):
            @dagger.function
            async def bar(self) -> str: ...
    """

    def wrapper(cls: T) -> T:
        checkable = typing.runtime_checkable(cls)
        return self._process_type(checkable, interface=True)

    # Works both as `@interface` and as `@interface()`.
    if cls:
        return wrapper(cls)
    return wrapper
@overload
def enum_type(self, cls: T) -> T: ...

@overload
def enum_type(self) -> Callable[[T], T]: ...

def enum_type(self, cls: T | None = None) -> T | Callable[[T], T]:
    '''Exposes a Python :py:class:`enum.Enum` as a :py:class:`dagger.EnumTypeDef`.

    The enum must have unique values (enforced with :py:func:`enum.unique`).

    Example usage::

        import enum
        import dagger


        @dagger.enum_type
        class Options(enum.Enum):
            """Enumeration description"""

            ONE = "ONE"
            """Description for the first value"""

            TWO = "TWO"
            """Description for the second value"""

    Raises
    ------
    BadUsageError
        If the decorated value isn't a class or isn't an
        :py:class:`enum.Enum` subclass.
    '''

    def wrapper(cls: T) -> T:
        if not inspect.isclass(cls):
            msg = f"Expected an enum.Enum subclass, got {type(cls)}"
            raise BadUsageError(msg)

        if not issubclass(cls, enum.Enum):
            msg = f"Class '{cls.__name__}' is not an enum.Enum subclass"
            raise BadUsageError(msg)

        cls = cast(T, enum.unique(cls))

        # Always record the class being decorated, like `_process_type`
        # does for objects. Keeping an older entry under the same name
        # would leave the registry out of sync with the converter hooks
        # configured below.
        self._enums[cls.__name__] = cls

        # Primitive enums get converted based on their primitive type rather
        # than the custom hook for converting based on member names so we
        # need to register the hooks for each specific class. Not necessary
        # to add hooks for non-primitive enums because those are already
        # handled by the general enum.Enum subclass check.
        if is_primitive_enum(cls):
            configure_converter_enum(self._converter, cls)

        return cls

    return wrapper(cls) if cls else wrapper