# Source code for rodi

import contextvars
import inspect
import re
import sys
from collections import defaultdict
from enum import Enum
from inspect import Signature, _empty, isabstract, isclass, iscoroutinefunction
from typing import (
    Any,
    Callable,
    ClassVar,
    DefaultDict,
    Dict,
    Mapping,
    Optional,
    Set,
    Type,
    TypeVar,
    Union,
    cast,
    get_type_hints,
)

if sys.version_info >= (3, 8):  # pragma: no cover
    try:
        from typing import _no_init_or_replace_init as _no_init
    except ImportError:  # pragma: no cover
        from typing import _no_init

try:
    from typing import Protocol
except ImportError:  # pragma: no cover
    from typing_extensions import Protocol


# Generic type variable: lets signatures such as
# `resolve(obj_type: Type[T]) -> T` tie the requested type to the returned value.
T = TypeVar("T")


class ContainerProtocol(Protocol):
    """
    Structural interface of a DI container: something that can register
    services, resolve them, and report whether a type is configured.
    """

    def register(self, obj_type: Union[Type, str], *args, **kwargs):
        """Register `obj_type` in the container, with optional arguments."""
        ...

    def resolve(self, obj_type: Union[Type[T], str], *args, **kwargs) -> T:
        """Activate an instance of `obj_type`, with optional arguments."""
        ...

    def __contains__(self, item) -> bool:
        """Tell whether `item` is configured in this container."""
        ...


# Mapping of alias name (a parameter name) to the type it stands for;
# accepted by `Container.add_aliases` and `Container.set_aliases`.
AliasesTypeHint = Dict[str, Type]


def inject(globalsns=None, localns=None) -> Callable[..., Any]:
    """
    Return a decorator that binds a namespace pair to a class or function.

    The decorated object receives `_locals` and `_globals` attributes. Any
    namespace not passed explicitly is taken from the frame that called
    `inject()`.

    :param globalsns: globals to attach; defaults to the caller's globals
    :param localns: locals to attach; defaults to the caller's locals
    :return: a decorator returning the same object it receives
    """
    if globalsns is None or localns is None:
        # NB: the frame lookup must stay in this function body, so that
        # `f_back` is the caller of `inject`.
        current = inspect.currentframe()
        try:
            caller = current.f_back  # type: ignore
            localns = caller.f_locals if localns is None else localns
            globalsns = caller.f_globals if globalsns is None else globalsns
        finally:
            # drop the frame references to avoid reference cycles
            del current

    def decorator(f):
        f._locals, f._globals = localns, globalsns
        return f

    return decorator


def _get_obj_locals(obj) -> Optional[Dict[str, Any]]:
    return getattr(obj, "_locals", None)


def class_name(input_type):
    """
    Return a display name for a type, used in error messages and as a
    string key for registered services.

    - parametrized built-in generics (`list[int]`, `set[str]`, ...) are
      rendered in full, e.g. "list[int]", so that different parametrizations
      do not share one name;
    - anything exposing `__name__` is rendered with it;
    - everything else falls back to `str(input_type)`.

    :param input_type: a class, a generic alias, or any other object
    :return: the name as string
    """
    # Local import: the module is only needed here.
    import types

    # `types.GenericAlias` exists since Python 3.9 (list[T], set[T], ...).
    # It forwards `__name__` to its origin, so it must be detected explicitly,
    # otherwise `list[int]` would be reported as just "list".
    generic_alias = getattr(types, "GenericAlias", None)
    if generic_alias is not None and isinstance(input_type, generic_alias):
        return str(input_type)

    try:
        return input_type.__name__
    except AttributeError:
        # for example, this is the case for List[str], Tuple[str, ...], etc.
        return str(input_type)


class DIException(Exception):
    """Root of the exception hierarchy raised by the DI container."""
class FactoryMissingContextException(DIException):
    """
    Exception raised when a factory does not carry the locals and globals
    data attached by the `@inject()` decorator.
    """

    def __init__(self, function) -> None:
        message = (
            f"The factory '{function.__name__}' lacks locals and globals data. "
            "Decorate the function with the `@inject()` decorator defined in "
            "`rodi`. This is necessary since PEP 563."
        )
        super().__init__(message)
class CannotResolveTypeException(DIException):
    """Exception raised when it is not possible to resolve a type."""

    def __init__(self, desired_type):
        message = f"Unable to resolve the type '{desired_type}'."
        super().__init__(message)
class CannotResolveParameterException(DIException):
    """
    Exception raised when it is not possible to resolve a parameter that is
    necessary to instantiate a type.
    """

    def __init__(self, param_name, desired_type):
        type_name = class_name(desired_type)
        message = (
            f"Unable to resolve parameter '{param_name}' "
            f"when resolving '{type_name}'"
        )
        super().__init__(message)
class OverridingServiceException(DIException):
    """
    Exception raised when registering a service would override an existing
    one.
    """

    def __init__(self, key, value):
        # string keys are shown as they are, types are shown by name
        if isinstance(key, str):
            key_name = key
        else:
            key_name = class_name(key)
        message = (
            f"A service with key '{key_name}' is already "
            f"registered and would be overridden by value {value}."
        )
        super().__init__(message)
class CircularDependencyException(DIException):
    """
    Exception raised when a circular dependency between a type and one of
    its parameters is detected.
    """

    def __init__(self, expected_type, desired_type):
        expected_name = class_name(expected_type)
        desired_name = class_name(desired_type)
        message = (
            "A circular dependency was detected for the service "
            f"of type '{expected_name}' "
            f"for '{desired_name}'"
        )
        super().__init__(message)
class InvalidOperationInStrictMode(DIException):
    """Exception raised by operations that are not allowed in strict mode."""

    def __init__(self):
        message = (
            "The services are configured in strict mode, the operation is invalid."
        )
        super().__init__(message)


class AliasAlreadyDefined(DIException):
    """Exception raised when trying to add an alias that already exists."""

    def __init__(self, name):
        message = (
            f"Cannot define alias '{name}'. "
            f"An alias with given name is already defined."
        )
        super().__init__(message)


class AliasConfigurationError(DIException):
    """
    Exception raised when an alias points to a type that is not configured
    in the container.
    """

    def __init__(self, name, _type):
        type_name = class_name(_type)
        message = (
            f"An alias '{name}' for type '{type_name}' was defined, "
            f"but the type was not configured in the Container."
        )
        super().__init__(message)
class MissingTypeException(DIException):
    """Exception raised when a type must be specified to use a factory."""

    def __init__(self):
        message = (
            "Please specify the factory return type or "
            "annotate its return type; func() -> Foo:"
        )
        super().__init__(message)
class InvalidFactory(DIException):
    """Exception raised when a factory is not valid."""

    def __init__(self, _type):
        super().__init__(
            f"The factory specified for type {class_name(_type)} is not "
            f"valid, it must be a function with either these signatures: "
            f"def example_factory(context, type): "
            f"or,"
            f"def example_factory(context): "
            f"or,"
            f"def example_factory(): "
        )


class ServiceLifeStyle(Enum):
    """Lifetime of a registered service."""

    TRANSIENT = 1
    SCOPED = 2
    SINGLETON = 3


def _get_factory_annotations_or_throw(factory):
    """
    Return the type hints of a factory, evaluated against the namespaces
    attached to it as `_locals` / `_globals` (see `inject`).

    :raises FactoryMissingContextException: if the factory has no `_locals`.
    """
    factory_locals = getattr(factory, "_locals", None)
    factory_globals = getattr(factory, "_globals", None)

    if factory_locals is None:
        raise FactoryMissingContextException(factory)

    return get_type_hints(factory, globalns=factory_globals, localns=factory_locals)


class ActivationScope:
    """
    Scope of a service activation: holds the provider and the services that
    were already activated (or supplied) for this scope. It can be used as a
    context manager; leaving the context disposes the scope.
    """

    __slots__ = ("scoped_services", "provider")

    def __init__(
        self,
        provider: Optional["Services"] = None,
        scoped_services: Optional[Dict[Union[Type[T], str], T]] = None,
    ):
        self.provider = provider or Services()
        self.scoped_services = scoped_services or {}

    def __enter__(self):
        if self.scoped_services is None:
            self.scoped_services = {}
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.dispose()

    def get(
        self,
        desired_type: Union[Type[T], str],
        scope: Optional["ActivationScope"] = None,
        *,
        default: Optional[Any] = ...,
    ) -> T:
        """
        Get a service through the provider, using the given scope or this one.

        :raises TypeError: if this scope was disposed.
        """
        if self.provider is None:
            raise TypeError("This scope is disposed.")
        return self.provider.get(desired_type, scope or self, default=default)

    def dispose(self):
        """Release the provider and clear the scoped services."""
        if self.provider:
            self.provider = None

        if self.scoped_services:
            self.scoped_services.clear()
            self.scoped_services = None


class TrackingActivationScope(ActivationScope):
    """
    This is an experimental class to support nested scopes transparently.
    To use it, create a container including the `scope_cls` parameter:
    `Container(scope_cls=TrackingActivationScope)`.
    """

    _active_scopes = contextvars.ContextVar("active_scopes", default=[])

    # Only the slot added by this class is declared: "scoped_services" and
    # "provider" are inherited from ActivationScope, and redeclaring slots of
    # a base class is discouraged by the Python data model.
    __slots__ = ("parent_scope",)

    def __init__(self, provider=None, scoped_services=None):
        # Get the current stack of active scopes
        stack = self._active_scopes.get()

        # Detect the parent scope if it exists
        self.parent_scope = stack[-1] if stack else None

        # Initialize scoped services, inheriting those of the parent scope
        scoped_services = scoped_services or {}
        if self.parent_scope:
            scoped_services.update(self.parent_scope.scoped_services)

        super().__init__(provider, scoped_services)

    def __enter__(self):
        # Push this scope onto the stack (a new list is set, the stored list
        # is never mutated in place)
        stack = self._active_scopes.get()
        self._active_scopes.set(stack + [self])
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Pop this scope from the stack
        stack = self._active_scopes.get()
        self._active_scopes.set(stack[:-1])
        self.dispose()

    def dispose(self):
        # NB: scoped services are not cleared here
        if self.provider:
            self.provider = None


class ResolutionContext:
    """
    State shared while resolvers are turned into providers: the providers
    already created (`resolved`) and the chain of types being resolved
    dynamically (`dynamic_chain`).
    """

    __slots__ = ("resolved", "dynamic_chain")
    __deletable__ = ("resolved",)

    def __init__(self):
        self.resolved = {}
        self.dynamic_chain = []

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.dispose()

    def dispose(self):
        del self.resolved
        self.dynamic_chain.clear()


class InstanceProvider:
    """Provider returning always the same, given instance."""

    __slots__ = ("instance",)

    def __init__(self, instance):
        self.instance = instance

    def __call__(self, context, parent_type):
        return self.instance


class TypeProvider:
    """Provider creating a new instance of a type without arguments."""

    __slots__ = ("_type",)

    def __init__(self, _type):
        self._type = _type

    def __call__(self, context, parent_type):
        return self._type()


class ScopedTypeProvider:
    """Provider creating one instance per scope, of a type without arguments."""

    __slots__ = ("_type",)

    def __init__(self, _type):
        self._type = _type

    def __call__(self, context: ActivationScope, parent_type):
        if self._type in context.scoped_services:
            return context.scoped_services[self._type]

        service = self._type()
        context.scoped_services[self._type] = service
        return service


class ArgsTypeProvider:
    """Provider creating a new instance of a type, resolving its arguments."""

    __slots__ = ("_type", "_args_callbacks")

    def __init__(self, _type, args_callbacks):
        self._type = _type
        self._args_callbacks = args_callbacks

    def __call__(self, context, parent_type):
        return self._type(*[fn(context, self._type) for fn in self._args_callbacks])


class FactoryTypeProvider:
    """Provider delegating every activation to a factory."""

    __slots__ = ("_type", "factory")

    def __init__(self, _type, factory):
        self._type = _type
        self.factory = factory

    def __call__(self, context: ActivationScope, parent_type):
        assert isinstance(context, ActivationScope)
        return self.factory(context, parent_type)


class SingletonFactoryTypeProvider:
    """Provider calling a factory until it returns a non-None instance, which is then kept."""

    __slots__ = ("_type", "factory", "instance")

    def __init__(self, _type, factory):
        self._type = _type
        self.factory = factory
        self.instance = None

    def __call__(self, context: ActivationScope, parent_type):
        if self.instance is None:
            self.instance = self.factory(context, parent_type)
        return self.instance


class ScopedFactoryTypeProvider:
    """Provider calling a factory once per scope."""

    __slots__ = ("_type", "factory")

    def __init__(self, _type, factory):
        self._type = _type
        self.factory = factory

    def __call__(self, context: ActivationScope, parent_type):
        if self._type in context.scoped_services:
            return context.scoped_services[self._type]

        instance = self.factory(context, parent_type)
        context.scoped_services[self._type] = instance
        return instance


class ScopedArgsTypeProvider:
    """Provider creating one instance per scope, resolving its arguments."""

    __slots__ = ("_type", "_args_callbacks")

    def __init__(self, _type, args_callbacks):
        self._type = _type
        self._args_callbacks = args_callbacks

    def __call__(self, context: ActivationScope, parent_type):
        if self._type in context.scoped_services:
            return context.scoped_services[self._type]

        service = self._type(*[fn(context, self._type) for fn in self._args_callbacks])
        context.scoped_services[self._type] = service
        return service


class SingletonTypeProvider:
    """Provider creating a single instance, with or without arguments."""

    __slots__ = ("_type", "_instance", "_args_callbacks")

    def __init__(self, _type, _args_callbacks):
        self._type = _type
        self._args_callbacks = _args_callbacks
        self._instance = None

    def __call__(self, context, parent_type):
        if self._instance is None:
            self._instance = (
                self._type(*[fn(context, self._type) for fn in self._args_callbacks])
                if self._args_callbacks
                else self._type()
            )

        return self._instance


def get_annotations_type_provider(
    concrete_type: Type,
    resolvers: Mapping[str, Callable],
    life_style: ServiceLifeStyle,
    resolver_context: ResolutionContext,
):
    """
    Return a provider that instantiates `concrete_type` without arguments and
    then sets one attribute for each entry of `resolvers`.
    """

    def factory(context, parent_type):
        instance = concrete_type()
        for name, resolver in resolvers.items():
            setattr(instance, name, resolver(context, parent_type))
        return instance

    return FactoryResolver(concrete_type, factory, life_style)(resolver_context)


def _get_plain_class_factory(concrete_type: Type):
    """Return a factory that ignores its arguments and calls `concrete_type()`."""

    def factory(*args):
        return concrete_type()

    return factory


class InstanceResolver:
    """Resolver for an instance registered as it is."""

    __slots__ = ("instance",)

    def __init__(self, instance):
        self.instance = instance

    def __repr__(self):
        return f"<Singleton {class_name(self.instance.__class__)}>"

    def __call__(self, context: ResolutionContext):
        return InstanceProvider(self.instance)


class Dependency:
    """Name and annotation of a parameter or class attribute to be resolved."""

    __slots__ = ("name", "annotation")

    def __init__(self, name, annotation):
        self.name = name
        self.annotation = annotation


class DynamicResolver:
    """
    Resolver that inspects a concrete class (its `__init__` signature, or its
    class annotations when `__init__` is the default one) to build a provider.
    """

    __slots__ = ("_concrete_type", "services", "life_style")

    def __init__(self, concrete_type, services, life_style):
        assert isclass(concrete_type)
        assert not isabstract(concrete_type)

        self._concrete_type = concrete_type
        self.services = services
        self.life_style = life_style

    @property
    def concrete_type(self) -> Type:
        return self._concrete_type

    def _get_resolver(self, desired_type, context: ResolutionContext):
        # NB: the following two lines are important to ensure that singletons
        # are instantiated only once per service provider
        # to not repeat operations more than once
        if desired_type in context.resolved:
            return context.resolved[desired_type]

        reg = self.services._map.get(desired_type)
        assert (
            reg is not None
        ), f"A resolver for type {class_name(desired_type)} is not configured"
        resolver = reg(context)

        # add the resolver to the context, so we can find it
        # next time we need it
        context.resolved[desired_type] = resolver
        return resolver

    def _get_resolvers_for_parameters(
        self,
        concrete_type,
        context: ResolutionContext,
        params: Mapping[str, Dependency],
    ):
        """
        Return the list of providers for the given parameters, skipping the
        ones named "self", "args" and "kwargs".

        :raises CannotResolveParameterException: if a parameter cannot be
            associated to a registered type.
        """
        fns = []
        services = self.services

        for param_name, param in params.items():
            if param_name in ("self", "args", "kwargs"):
                continue

            param_type = param.annotation

            if param_type is _empty:
                if services.strict:
                    raise CannotResolveParameterException(param_name, concrete_type)

                # support for exact, user defined aliases, without ambiguity
                exact_alias = services._exact_aliases.get(param_name)

                if exact_alias:
                    param_type = exact_alias
                else:
                    # NB: `.get` is used on purpose; `_aliases` is a
                    # defaultdict and indexing it would insert an empty set
                    # for every unknown parameter name.
                    aliases = services._aliases.get(param_name)

                    if aliases:
                        assert (
                            len(aliases) == 1
                        ), "Configured aliases cannot be ambiguous"
                        param_type = next(iter(aliases))

            if param_type not in services._map:
                raise CannotResolveParameterException(param_name, concrete_type)

            param_resolver = self._get_resolver(param_type, context)
            fns.append(param_resolver)
        return fns

    def _resolve_by_init_method(self, context: ResolutionContext):
        """Build a provider from the signature of the `__init__` method."""
        sig = Signature.from_callable(self.concrete_type.__init__)
        params = {
            key: Dependency(key, value.annotation)
            for key, value in sig.parameters.items()
        }

        if sys.version_info >= (3, 10):  # pragma: no cover
            # Python 3.10
            annotations = get_type_hints(
                self.concrete_type.__init__,
                vars(sys.modules[self.concrete_type.__module__]),
                _get_obj_locals(self.concrete_type),
            )
            for key, value in params.items():
                if key in annotations:
                    value.annotation = annotations[key]

        concrete_type = self.concrete_type

        if len(params) == 1 and next(iter(params.keys())) == "self":
            # the class does not require any argument
            if self.life_style == ServiceLifeStyle.SINGLETON:
                return SingletonTypeProvider(concrete_type, None)

            if self.life_style == ServiceLifeStyle.SCOPED:
                return ScopedTypeProvider(concrete_type)

            return TypeProvider(concrete_type)

        fns = self._get_resolvers_for_parameters(concrete_type, context, params)

        if self.life_style == ServiceLifeStyle.SINGLETON:
            return SingletonTypeProvider(concrete_type, fns)

        if self.life_style == ServiceLifeStyle.SCOPED:
            return ScopedArgsTypeProvider(concrete_type, fns)

        return ArgsTypeProvider(concrete_type, fns)

    def _ignore_class_attribute(self, key: str, value) -> bool:
        """
        Returns a value indicating whether a class attribute should be ignored for
        dependency resolution, by name and value.
        It's ignored if it's a ClassVar or if it's already initialized explicitly.
        """
        is_classvar = getattr(value, "__origin__", None) is ClassVar
        is_initialized = getattr(self.concrete_type, key, None) is not None

        return is_classvar or is_initialized

    def _has_default_init(self):
        init = getattr(self.concrete_type, "__init__", None)

        if init is object.__init__:
            return True

        if sys.version_info >= (3, 8):  # pragma: no cover
            if init is _no_init:
                return True
        return False

    def _resolve_by_annotations(
        self, context: ResolutionContext, annotations: Dict[str, Type]
    ):
        """Build a provider from the class annotations."""
        params = {
            key: Dependency(key, value)
            for key, value in annotations.items()
            if not self._ignore_class_attribute(key, value)
        }
        concrete_type = self.concrete_type

        fns = self._get_resolvers_for_parameters(concrete_type, context, params)
        resolvers = {}

        for i, name in enumerate(params.keys()):
            resolvers[name] = fns[i]

        return get_annotations_type_provider(
            self.concrete_type, resolvers, self.life_style, context
        )

    def __call__(self, context: ResolutionContext):
        concrete_type = self.concrete_type

        chain = context.dynamic_chain
        chain.append(concrete_type)

        if self._has_default_init():
            annotations = get_type_hints(
                concrete_type,
                vars(sys.modules[concrete_type.__module__]),
                _get_obj_locals(concrete_type),
            )

            if annotations:
                try:
                    return self._resolve_by_annotations(context, annotations)
                except RecursionError:
                    raise CircularDependencyException(chain[0], concrete_type)

            return FactoryResolver(
                concrete_type, _get_plain_class_factory(concrete_type), self.life_style
            )(context)

        try:
            return self._resolve_by_init_method(context)
        except RecursionError:
            raise CircularDependencyException(chain[0], concrete_type)


class FactoryResolver:
    """Resolver that wraps a factory in the provider matching its life style."""

    __slots__ = ("concrete_type", "factory", "params", "life_style")

    def __init__(self, concrete_type, factory, life_style):
        self.factory = factory
        self.concrete_type = concrete_type
        self.life_style = life_style

    def __call__(self, context: ResolutionContext):
        if self.life_style == ServiceLifeStyle.SINGLETON:
            return SingletonFactoryTypeProvider(self.concrete_type, self.factory)

        if self.life_style == ServiceLifeStyle.SCOPED:
            return ScopedFactoryTypeProvider(self.concrete_type, self.factory)

        return FactoryTypeProvider(self.concrete_type, self.factory)


first_cap_re = re.compile("(.)([A-Z][a-z]+)")
all_cap_re = re.compile("([a-z0-9])([A-Z])")


def to_standard_param_name(name):
    """
    Convert a CamelCase name to snake_case; a leading "i_" is merged into
    the name (e.g. "ICatsRepository" -> "icats_repository").
    """
    value = all_cap_re.sub(r"\1_\2", first_cap_re.sub(r"\1_\2", name)).lower()
    if value.startswith("i_"):
        return "i" + value[2:]
    return value
class Services:
    """
    Provides methods to activate instances of classes, by cached activator
    functions.
    """

    __slots__ = ("_map", "_executors", "_scope_cls")

    def __init__(
        self,
        services_map=None,
        scope_cls: Optional[Type[ActivationScope]] = None,
    ):
        """
        :param services_map: mapping of keys (types or names) to providers
        :param scope_cls: class used to create scopes, ActivationScope by default
        """
        if services_map is None:
            services_map = {}
        self._map = services_map
        self._executors = {}
        self._scope_cls = scope_cls or ActivationScope

    def __contains__(self, item):
        return item in self._map

    def __getitem__(self, item):
        return self.get(item)

    def __setitem__(self, key, value):
        self.set(key, value)

    def create_scope(
        self, scoped: Optional[Dict[Union[Type, str], Any]] = None
    ) -> ActivationScope:
        """Create a new activation scope bound to this provider."""
        return self._scope_cls(self, scoped)

    def set(self, new_type: Union[Type, str], value: Any):
        """
        Sets a new service of desired type, as singleton.
        This method exists to increase interoperability of Services class (with dict).

        :param new_type: key of the service, a type or a name
        :param value: the instance returned for the given key
        :raises OverridingServiceException: if the key, or the name of the
            given type, is already registered.
        """
        type_name = class_name(new_type)
        if new_type in self._map or (
            not isinstance(new_type, str) and type_name in self._map
        ):
            # NB: the exception receives (key, value); `self._map[new_type]`
            # must not be read here, since only the type name might be
            # registered.
            raise OverridingServiceException(new_type, value)

        def resolver(context, desired_type):
            return value

        self._map[new_type] = resolver
        if not isinstance(new_type, str):
            self._map[type_name] = resolver

    def get(
        self,
        desired_type: Union[Type[T], str],
        scope: Optional[ActivationScope] = None,
        *,
        default: Optional[Any] = ...,
    ) -> T:
        """
        Gets a service of the desired type, returning an activated instance.

        :param desired_type: desired service type.
        :param scope: optional scope, used to handle scoped services; a new
            scope is created if none is given.
        :param default: value returned when the type cannot be resolved;
            when omitted, an exception is raised instead.
        :raises CannotResolveTypeException: if the type cannot be resolved
            and no default was given.
        :return: an instance of the desired type
        """
        if scope is None:
            scope = self.create_scope()

        # A service already present in the scope wins. The check is done by
        # identity with None, so that falsy instances (for example empty
        # collections) are not ignored.
        scoped_services = scope.scoped_services
        scoped_service = (
            scoped_services.get(desired_type) if scoped_services else None
        )
        if scoped_service is not None:
            return cast(T, scoped_service)

        resolver = self._map.get(desired_type)
        if resolver is None:
            if default is not ...:
                return cast(T, default)
            raise CannotResolveTypeException(desired_type)

        return cast(T, resolver(scope, desired_type))

    def _get_getter(self, key, param):
        """
        Return a function that obtains the value of a parameter from a scope:
        by annotation if the parameter is annotated, otherwise by name.
        """
        if param.annotation is _empty:

            def getter(context):
                return self.get(key, context)

        else:

            def getter(context):
                return self.get(param.annotation, context)

        getter.__name__ = f"<getter {key}>"
        return getter

    def get_executor(self, method: Callable) -> Callable:
        """
        Return a function that calls `method` resolving all its parameters
        inside a new scope. The returned function is a coroutine function if
        `method` is one.
        """
        sig = Signature.from_callable(method)
        params = {
            key: Dependency(key, value.annotation)
            for key, value in sig.parameters.items()
        }

        if sys.version_info >= (3, 10):  # pragma: no cover
            # Python 3.10
            annotations = _get_factory_annotations_or_throw(method)
            for key, value in params.items():
                if key in annotations:
                    value.annotation = annotations[key]

        fns = [self._get_getter(key, value) for key, value in params.items()]

        if iscoroutinefunction(method):

            async def async_executor(
                scoped: Optional[Dict[Union[Type, str], Any]] = None,
            ):
                with self.create_scope(scoped) as context:
                    return await method(*[fn(context) for fn in fns])

            return async_executor

        def executor(scoped: Optional[Dict[Union[Type, str], Any]] = None):
            with self.create_scope(scoped) as context:
                return method(*[fn(context) for fn in fns])

        return executor

    def exec(
        self,
        method: Callable,
        scoped: Optional[Dict[Type, Any]] = None,
    ) -> Any:
        """Call `method` resolving its parameters; executors are cached by method."""
        try:
            executor = self._executors[method]
        except KeyError:
            executor = self.get_executor(method)
            self._executors[method] = executor
        return executor(scoped)
FactoryCallableNoArguments = Callable[[], Any]
FactoryCallableSingleArgument = Callable[[ActivationScope], Any]
FactoryCallableTwoArguments = Callable[[ActivationScope, Type], Any]
FactoryCallableType = Union[
    FactoryCallableNoArguments,
    FactoryCallableSingleArgument,
    FactoryCallableTwoArguments,
]


class FactoryWrapperNoArgs:
    """Adapts a factory without parameters to the (context, type) signature."""

    __slots__ = ("factory",)

    def __init__(self, factory):
        self.factory = factory

    def __call__(self, context, activating_type):
        return self.factory()


class FactoryWrapperContextArg:
    """Adapts a factory taking only the context to the (context, type) signature."""

    __slots__ = ("factory",)

    def __init__(self, factory):
        self.factory = factory

    def __call__(self, context, activating_type):
        return self.factory(context)


class Container(ContainerProtocol):
    """
    Configuration class for a collection of services.
    """

    # "_provider" is listed because `__init__` assigns it.
    __slots__ = (
        "_map",
        "_aliases",
        "_exact_aliases",
        "_provider",
        "_scope_cls",
        "strict",
    )

    def __init__(
        self,
        *,
        strict: bool = False,
        scope_cls: Optional[Type[ActivationScope]] = None,
    ):
        """
        :param strict: when true, aliases are disabled and parameters
            without annotation cannot be resolved
        :param scope_cls: class of the scopes created by the built provider
        """
        self._map: Dict[Type, Callable] = {}
        self._aliases: DefaultDict[str, Set[Type]] = defaultdict(set)
        self._exact_aliases: Dict[str, Type] = {}
        self._provider: Optional[Services] = None
        self._scope_cls = scope_cls
        self.strict = strict

    @property
    def provider(self) -> Services:
        """The service provider, built lazily and cached until a new binding."""
        if self._provider is None:
            self._provider = self.build_provider()
        return self._provider

    def __iter__(self):
        yield from self._map.items()

    def __contains__(self, key):
        return key in self._map

    def bind_types(
        self,
        obj_type: Any,
        concrete_type: Any = None,
        life_style: ServiceLifeStyle = ServiceLifeStyle.TRANSIENT,
    ):
        """
        Registers `concrete_type` as implementation of `obj_type`, with the
        given life style.

        :return: the container itself
        """
        try:
            assert issubclass(concrete_type, obj_type), (
                f"Cannot register {class_name(obj_type)} for abstract class "
                f"{class_name(concrete_type)}"
            )
        except TypeError:
            # ignore, this happens with generic types
            pass
        self._bind(obj_type, DynamicResolver(concrete_type, self, life_style))
        return self

    def register(
        self,
        obj_type: Any,
        sub_type: Any = None,
        instance: Any = None,
        *args,
        **kwargs,
    ) -> "Container":
        """
        Registers a type in this container: as the given instance if one is
        provided, otherwise as transient service (optionally implemented by
        `sub_type`). Extra arguments are accepted and ignored.
        """
        if instance is not None:
            self.add_instance(instance, declared_class=obj_type)
            return self

        if sub_type is None:
            self._add_exact_transient(obj_type)
        else:
            self.add_transient(obj_type, sub_type)
        return self

    def resolve(
        self,
        obj_type: Union[Type[T], str],
        scope: Any = None,
        *args,
        **kwargs,
    ) -> T:
        """
        Resolves a service by type, obtaining an instance of that type.
        """
        return self.provider.get(obj_type, scope=scope)

    def add_alias(self, name: str, desired_type: Type):
        """
        Adds an alias to the set of inferred aliases.

        :param name: parameter name
        :param desired_type: desired type by parameter name
        :raises InvalidOperationInStrictMode: if the container is strict
        :raises AliasAlreadyDefined: if the name is already in use
        :return: self
        """
        if self.strict:
            raise InvalidOperationInStrictMode()
        # NB: an empty set does not count as a defined alias
        if self._aliases.get(name) or name in self._exact_aliases:
            raise AliasAlreadyDefined(name)
        self._aliases[name].add(desired_type)
        return self

    def add_aliases(self, values: AliasesTypeHint):
        """
        Adds aliases to the set of inferred aliases.

        :param values: mapping object (parameter name: class)
        :return: self
        """
        for key, value in values.items():
            self.add_alias(key, value)
        return self

    def set_alias(self, name: str, desired_type: Type, override: bool = False):
        """
        Sets an exact alias for a desired type.

        :param name: parameter name
        :param desired_type: desired type by parameter name
        :param override: whether to override existing values, or throw exception
        :return: self
        """
        if self.strict:
            raise InvalidOperationInStrictMode()
        if not override and name in self._exact_aliases:
            raise AliasAlreadyDefined(name)
        self._exact_aliases[name] = desired_type
        return self

    def set_aliases(self, values: AliasesTypeHint, override: bool = False):
        """Sets many exact aliases for desired types.

        :param values: mapping object (parameter name: class)
        :param override: whether to override existing values, or throw exception
        :return: self
        """
        for key, value in values.items():
            self.set_alias(key, value, override)
        return self

    def _bind(self, key: Type, value: Any) -> None:
        """
        Stores a resolver for a key, invalidates the cached provider and,
        unless strict, records the aliases derived from the key name.

        :raises OverridingServiceException: if the key is already registered
        """
        if key in self._map:
            raise OverridingServiceException(key, value)
        self._map[key] = value

        if self._provider is not None:
            self._provider = None

        key_name = class_name(key)

        if self.strict or "." in key_name:
            return

        self._aliases[key_name].add(key)
        self._aliases[key_name.lower()].add(key)
        self._aliases[to_standard_param_name(key_name)].add(key)

    def add_instance(
        self, instance: Any, declared_class: Optional[Type] = None
    ) -> "Container":
        """
        Registers an exact instance, optionally by declared class.

        :param instance: singleton to be registered
        :param declared_class: optionally, lets define the class used as reference of
               the singleton
        :return: the service collection itself
        """
        self._bind(
            instance.__class__ if not declared_class else declared_class,
            InstanceResolver(instance),
        )
        return self

    def add_singleton(
        self, base_type: Type, concrete_type: Optional[Type] = None
    ) -> "Container":
        """
        Registers a type by base type, to be instantiated with singleton lifetime.
        If a single type is given, the method `_add_exact_singleton` is used.

        :param base_type: registered type. If a concrete type is provided, it must
               inherit the base type.
        :param concrete_type: concrete class
        :return: the service collection itself
        """
        if concrete_type is None:
            return self._add_exact_singleton(base_type)

        return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SINGLETON)

    def add_scoped(
        self, base_type: Type, concrete_type: Optional[Type] = None
    ) -> "Container":
        """
        Registers a type by base type, to be instantiated with scoped lifetime.
        If a single type is given, the method `_add_exact_scoped` is used.

        :param base_type: registered type. If a concrete type is provided, it must
               inherit the base type.
        :param concrete_type: concrete class
        :return: the service collection itself
        """
        if concrete_type is None:
            return self._add_exact_scoped(base_type)

        return self.bind_types(base_type, concrete_type, ServiceLifeStyle.SCOPED)

    def add_transient(
        self, base_type: Type, concrete_type: Optional[Type] = None
    ) -> "Container":
        """
        Registers a type by base type, to be instantiated with transient lifetime.
        If a single type is given, the method `_add_exact_transient` is used.

        :param base_type: registered type. If a concrete type is provided, it must
               inherit the base type.
        :param concrete_type: concrete class
        :return: the service collection itself
        """
        if concrete_type is None:
            return self._add_exact_transient(base_type)

        return self.bind_types(base_type, concrete_type, ServiceLifeStyle.TRANSIENT)

    def _add_exact_singleton(self, concrete_type: Type) -> "Container":
        """
        Registers an exact type, to be instantiated with singleton lifetime.

        :param concrete_type: concrete class
        :return: the service collection itself
        """
        assert not isabstract(concrete_type)
        self._bind(
            concrete_type,
            DynamicResolver(concrete_type, self, ServiceLifeStyle.SINGLETON),
        )
        return self

    def _add_exact_scoped(self, concrete_type: Type) -> "Container":
        """
        Registers an exact type, to be instantiated with scoped lifetime.

        :param concrete_type: concrete class
        :return: the service collection itself
        """
        assert not isabstract(concrete_type)
        self._bind(
            concrete_type, DynamicResolver(concrete_type, self, ServiceLifeStyle.SCOPED)
        )
        return self

    def _add_exact_transient(self, concrete_type: Type) -> "Container":
        """
        Registers an exact type, to be instantiated with transient lifetime.

        :param concrete_type: concrete class
        :return: the service collection itself
        """
        assert not isabstract(concrete_type)
        self._bind(
            concrete_type,
            DynamicResolver(concrete_type, self, ServiceLifeStyle.TRANSIENT),
        )
        return self

    def add_singleton_by_factory(
        self, factory: FactoryCallableType, return_type: Optional[Type] = None
    ) -> "Container":
        """Registers a factory whose result is used with singleton lifetime."""
        self.register_factory(factory, return_type, ServiceLifeStyle.SINGLETON)
        return self

    def add_transient_by_factory(
        self, factory: FactoryCallableType, return_type: Optional[Type] = None
    ) -> "Container":
        """Registers a factory whose result is used with transient lifetime."""
        self.register_factory(factory, return_type, ServiceLifeStyle.TRANSIENT)
        return self

    def add_scoped_by_factory(
        self, factory: FactoryCallableType, return_type: Optional[Type] = None
    ) -> "Container":
        """Registers a factory whose result is used with scoped lifetime."""
        self.register_factory(factory, return_type, ServiceLifeStyle.SCOPED)
        return self

    @staticmethod
    def _check_factory(factory, signature, handled_type) -> Callable:
        """
        Returns a callable with signature (context, type) for the factory,
        wrapping it when it accepts fewer parameters.

        :raises InvalidFactory: if the factory accepts more than two parameters
        """
        assert callable(factory), "The factory must be callable"

        params_len = len(signature.parameters)

        if params_len == 0:
            return FactoryWrapperNoArgs(factory)

        if params_len == 1:
            return FactoryWrapperContextArg(factory)

        if params_len == 2:
            return factory

        raise InvalidFactory(handled_type)

    def register_factory(
        self,
        factory: Callable,
        return_type: Optional[Type],
        life_style: ServiceLifeStyle,
    ) -> None:
        """
        Registers a factory for a type; when `return_type` is not given, the
        return annotation of the factory is used.

        :raises InvalidFactory: if the factory is not callable
        :raises MissingTypeException: if the type cannot be determined
        """
        if not callable(factory):
            raise InvalidFactory(return_type)

        sign = Signature.from_callable(factory)
        if return_type is None:
            if sign.return_annotation is _empty:
                raise MissingTypeException()
            return_type = sign.return_annotation

            if isinstance(return_type, str):  # pragma: no cover
                # Python 3.10
                annotations = _get_factory_annotations_or_throw(factory)
                return_type = annotations["return"]

        self._bind(
            return_type,  # type: ignore
            FactoryResolver(
                return_type, self._check_factory(factory, sign, return_type), life_style
            ),
        )

    def build_provider(self) -> Services:
        """
        Builds and returns a service provider that can be used to activate and obtain
        services.

        The configuration of services is validated at this point, if any service cannot
        be instantiated due to missing dependencies, an exception is thrown inside this
        operation.

        :return: Service provider that can be used to activate and obtain services.
        """
        with ResolutionContext() as context:
            _map: Dict[Union[str, Type], Type] = {}

            for _type, resolver in self._map.items():
                if isinstance(resolver, DynamicResolver):
                    context.dynamic_chain.clear()

                if _type in context.resolved:
                    # check if its in the map
                    if _type in _map:
                        # NB: do not call resolver if one was already prepared for the
                        # type
                        raise OverridingServiceException(_type, resolver)
                    else:
                        resolved = context.resolved[_type]
                else:
                    # add to context so that we don't repeat operations
                    resolved = resolver(context)
                    context.resolved[_type] = resolved

                _map[_type] = resolved

                type_name = class_name(_type)
                if "." not in type_name:
                    _map[type_name] = _map[_type]

            if not self.strict:
                assert self._aliases is not None
                assert self._exact_aliases is not None

                # include aliases in the map;
                for name, _types in self._aliases.items():
                    if not _types:
                        # an empty set carries no target type; skipping it
                        # avoids reusing the type of a previous iteration
                        continue
                    _map[name] = self._get_alias_target_type(
                        name, _map, next(iter(_types))
                    )

                for name, _type in self._exact_aliases.items():
                    _map[name] = self._get_alias_target_type(name, _map, _type)

        return Services(_map, scope_cls=self._scope_cls)

    @staticmethod
    def _get_alias_target_type(name, _map, _type):
        """
        Returns the provider registered for the type of an alias.

        :raises AliasConfigurationError: if the type is not configured
        """
        try:
            return _map[_type]
        except KeyError:
            raise AliasConfigurationError(name, _type)