diff --git a/enochecker3/__init__.py b/enochecker3/__init__.py index fe7a25a..464efc6 100644 --- a/enochecker3/__init__.py +++ b/enochecker3/__init__.py @@ -2,6 +2,7 @@ from .enochecker import ( AsyncSocket, CircularDependencyException, + DependencyInjector, Enochecker, EnocheckerException, InvalidVariantIdsException, diff --git a/enochecker3/enochecker.py b/enochecker3/enochecker.py index 1974e59..5005b79 100644 --- a/enochecker3/enochecker.py +++ b/enochecker3/enochecker.py @@ -6,6 +6,7 @@ import traceback from contextlib import AsyncExitStack, asynccontextmanager from inspect import Parameter, isawaitable, signature +from types import TracebackType from typing import ( Any, AsyncContextManager, @@ -16,6 +17,7 @@ Optional, Set, Tuple, + Type, Union, cast, ) @@ -79,6 +81,42 @@ class InvalidVariantIdsException(EnocheckerException): pass +class DependencyInjector: + def __init__(self, checker: "Enochecker", task: BaseCheckerTaskMessage): + self.checker = checker + self.task = task + self._exit_stack: AsyncExitStack = AsyncExitStack() + + async def __aenter__(self) -> "DependencyInjector": + await self._exit_stack.__aenter__() + return self + + async def __aexit__( + self, + exc_type: Optional[Type[BaseException]], + exc_value: Optional[BaseException], + traceback: Optional[TracebackType], + ) -> Optional[bool]: + return await self._exit_stack.__aexit__(exc_type, exc_value, traceback) + + async def get(self, t: type) -> Any: + if t not in self.checker._dependency_injections: + raise ValueError(f"No registered dependency for type {t}") + + injector = self.checker._dependency_injections[t] + args = await self._exit_stack.enter_async_context( + self.checker._inject_dependencies(self.task, injector, None) + ) + if isawaitable(injector): + res = await injector(*args) + else: + res = injector(*args) + + if not hasattr(res, "__enter__") and not hasattr(res, "__aenter__"): + return res + return await self._exit_stack.enter_async_context(res) + + class Enochecker: def __init__(self, name: str, service_port: int): self.name: str = name @@ -105,6 +143,7 @@ def __init__(self, name: str, service_port: int): self.register_dependency(self._get_flag_searcher) self.register_dependency(self._get_logger_adapter) self.register_dependency(self._get_async_socket) + self.register_dependency(self._get_dependency_injector) self._method_variants: Dict[CheckerMethod, Dict[int, Callable[..., Any]]] = { CheckerMethod.PUTFLAG: {}, @@ -338,6 +377,11 @@ async def _get_async_socket(self, task: BaseCheckerTaskMessage) -> AsyncSocket: conn[1].close() await conn[1].wait_closed() + def _get_dependency_injector( + self, task: BaseCheckerTaskMessage + ) -> DependencyInjector: + return DependencyInjector(self, task) + ######################### # variant_id validation # ######################### diff --git a/setup.py b/setup.py index aba8ccc..471051a 100755 --- a/setup.py +++ b/setup.py @@ -9,7 +9,7 @@ setuptools.setup( name="enochecker3", - version="0.1.1", + version="0.2.0", author="ldruschk", author_email="ldruschk@posteo.de", description="FastAPI based library for building async python checkers for the EnoEngine A/D CTF Framework",