From 76912a69e2040d89766885b202a99b5ec0a9f1d9 Mon Sep 17 00:00:00 2001
From: <>
Date: Mon, 18 Sep 2023 15:05:12 +0000
Subject: [PATCH] Deployed d8cb061 with MkDocs version: 1.3.0
For decoupling your app's tasks from their dependencies, you can use Services. These are functions that return a value.
+To create a service, declare a function that receives a Context
, an optional list of arguments, and returns a value.
For example, you can use a Service to create a temporary directory for a task:
+import tempfile
+from mognet import Context
+from pathlib import Path
+
+def temp_dir(context: Context) -> Path:
+ task_name = context.request.name
+ task_id = str(context.request.id)
+
+ return Path(tempfile.mkdtemp(prefix=task_name, suffix=task_id))
+
To use a service in a task, use get_service
as follows:
from mognet import Context, task
+from example.services import temp_dir
+
+@task(name="test.use_temp_dir")
+async def use_temp_dir(context: Context):
+
+ my_temp_dir = context.get_service(temp_dir)
+
+ # Use my_temp_dir
+
The result of the function call is not stored, meaning that every time you call get_service
, you will get a new temporary directory.
To create a Service that accepts parameters, add the parameters to the Service function, and pass the values via the get_service
call.
from mognet import Context, task
+
+class Counter:
+ def __init__(self, n: int):
+ self.n = n
+
+ def increment(self, n: int):
+ self.n += n
+
+def counter(context: Context, start: int):
+ return Counter(start)
+
+@task(name="example.use_counter")
+async def use_counter(context: Context):
+ my_counter = context.get_service(counter, 5)
+
+ # my_counter is a Counter that starts with 5
+    my_counter.increment(1)
+
+    assert my_counter.n == 6
+
Classes can be used as services, too, provided they extend the ClassService
class.
Class Services are different from their function counterparts, because:

- They have access to __enter__ and __exit__ methods for setup and teardown:
  - Initialization, unless done explicitly (see Overriding a Service, below), is lazy
  - Teardown is done at app shutdown
- They act as factories: the __call__ method must be overridden in order to return the value. The get_service method is called with the class itself; argument passing is still allowed.

Class Services are ideal for managed, long-lived resources, such as database connections.
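Here is a minimal sketch of a Class Service; the import path of ClassService and the exact __enter__/__exit__/__call__ signatures are assumptions for illustration, not documented API:
+from mognet import Context, task
+from mognet.service.class_service import ClassService  # import path is an assumption
+
+class DatabaseConnection:
+    """Stand-in for a real, long-lived connection type."""
+
+class DatabaseService(ClassService):
+    def __enter__(self):
+        # Lazy setup: runs when the service is first requested.
+        self._connection = DatabaseConnection()
+        return self
+
+    def __exit__(self, *exc_info):
+        # Teardown: runs at app shutdown.
+        self._connection = None
+
+    def __call__(self):
+        # Factory method: returns the value handed to tasks.
+        return self._connection
+
+@task(name="example.use_database")
+async def use_database(context: Context):
+    # get_service is called with the class itself.
+    connection = context.get_service(DatabaseService)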
Some services require asyncio-based setup, so their functions can be async def (coroutines). However, since get_service is synchronous, your app's code must await the returned coroutine itself.
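A minimal sketch, assuming the service function needs the event loop for its setup (the service body here is hypothetical):
+from mognet import Context, task
+
+async def http_session(context: Context):
+    # Hypothetical async setup; anything that must run on the event loop goes here.
+    return object()  # stand-in for the real session/client
+
+@task(name="example.use_http_session")
+async def use_http_session(context: Context):
+    # get_service returns the coroutine; we have to await it ourselves.
+    session = await context.get_service(http_session)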
Making get_service itself async is on the roadmap 😉
To override Services, you can use the services dictionary on the App class. The keys are the functions/classes that represent your services, and the values are callables (i.e., a function, an object with a __call__ method, or a ClassService instance).
Let's assume we want to override the counter
Service we created previously. To do it, we would do the following:
from mognet import App, Context
+
+# Get the reference to the original service
+from example.services import counter
+
+# Assume that this has the same interface
+# as Counter
+from counter_lib import NoDecrementCounter
+
+app = App(...)
+
+def different_counter(context: Context, n: int):
+ return NoDecrementCounter(n)
+
+app.services[counter] = different_counter
+
This effectively redirects the call to a different function, allowing you to decouple your app's components. You can use this technique in your unit tests to inject different objects into your tasks.
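For example, a unit test could swap the temp_dir Service from the beginning of this page for a deterministic one (a sketch; the fixed path is hypothetical):
+from pathlib import Path
+
+from mognet import Context
+from example.services import temp_dir
+
+def fake_temp_dir(context: Context) -> Path:
+    # A fixed location makes assertions straightforward.
+    return Path("/tmp/mognet-test")
+
+app.services[temp_dir] = fake_temp_dir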
+ + +You can use metadata to store additional information on a Request
object. Mognet itself makes no use of it, but it can be used both by the tasks and by Middleware.
Metadata can be useful to store information like the ID of the user who launched a request, as the examples below show.
+Metadata is defined as a dict[str, Any]
, and it is stored as JSON in the respective Result
.
To set metadata, you use the metadata field. As described before, this is a dict[str, Any], so you can set it in several ways:
from mognet import App, Request
+
+app = App(...)
+
+# Approach 1: Create a Request object manually and set the metadata
+# on the constructor
+req1 = Request(
+ ...,
+ metadata={
+ "user_id": "val",
+ },
+)
+
+# Approach 2: Create a Request object, for example,
+# with the create_request() method,
+# and set the metadata field afterwards
+req2 = app.create_request(...)
+req2.metadata["user_id"] = "cau"
+
Note that the metadata is stored on the associated Result, not on the Request itself.
In case you need to get metadata in a running task, you can do so through the context
, like this:
from mognet import App, Context, task
+
+app = App(...)
+
+@task(name="demo.get_metadata")
+def demo_get_metadata(context: Context):
+ user_id = context.request.metadata["user_id"]
+
+ print(f"This task was launched by @{user_id}")
+
+
+async def main():
+ req = app.create_request(demo_get_metadata)
+ req.metadata["user_id"] = "dol"
+
+ await app.run(req)
+
In this hypothetical example, the worker would print This task was launched by @dol
.
You can set the metadata in a running task by using the set_metadata()
method on the context
:
@task(name="demo.set_metadata")
+async def demo_set_metadata(context: Context):
+ await context.set_metadata(user_id="dkl")
+
Note that this method is asynchronous because it stores the values on the Result Backend (Redis). Also note that the associated Request object does not get updated accordingly (it is destroyed after the task completes anyway).
To get metadata outside of a task, you get its associated Result and then call get_metadata() on it.
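A minimal sketch, assuming the Result is fetched by request ID through the app's result backend (shown in the API reference below):
+result = await app.result_backend.get(req2.id)
+
+if result is not None:
+    # get_metadata() is assumed to be awaitable, mirroring the backend's API.
+    metadata = await result.get_metadata()
+    print(metadata.get("user_id"))  # "cau", from the earlier example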
Middleware can be used to hook into several points of the Mognet application: for example, app startup and shutdown, request submission, and changes to the number of running tasks. It is therefore a good candidate for cross-cutting concerns like logging and collecting metrics.
To create a Middleware, you first need to create a class that extends the Middleware base class. Its hook methods, by default, do nothing.
from mognet import Middleware
+
+class MyMiddleware(Middleware):
+ async def on_running_task_count_changed(self, running_task_count: int):
+ print(f"We currently have {running_task_count} tasks running!")
+
To add middleware to a Mognet Worker's app, you use the add_middleware() method, ideally right after creating the App object.
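A minimal sketch (the AppConfig construction is elided, following the examples on this page):
+from mognet import App
+
+app = App(name="example", config=config)  # config: an AppConfig instance
+app.add_middleware(MyMiddleware())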
See AutoShutdownMiddleware. This class implements a common pattern for long-running applications: periodic restarts to release memory, due to Python's memory model.
Sometimes you want to have different tasks routed to different queues. To do so, you can use the following properties on the AppConfig class:

- task_routes: This allows you to set the queue name for each task. Tasks that are not set here default to the default queue name (tasks)
- task_queues: This allows you to only listen on a specific set of queues

Combining the two allows you to have workers that respond to a specific set of task types. This is useful when you have tasks that require specific resources (such as GPUs, storage, CPU, etc.), or when you want these tasks to remain processable even if other queues are busy.
) by setting the field default_task_route
on the AppConfig
.
The RabbitMQ broker configures it's queues with priority support. This means that you can send Request
objects with priorities set.
You can do so by setting the priority
field, as an integer, from 0 to 10. The default is 5, and higher values mean higher priority.
Bear in mind that messing with priorities may starve other tasks; if there are too many high-priority tasks running, they may heavily delay lower-priority tasks from running.
+ + +You can have mechanisms similar to Azure's Durable Functions by making use of State
.
State acts like a dict[str, Any]
, whose values must be JSON-serializable, and you access it through your task's context
argument, and you can use it to store information that gets persisted across Worker reboots.
Some cases where it can be useful include:
+from typing import List
+
+from mognet import task, Context
+
+@task(name="demo.state")
+async def use_state(context: Context, files: List[str]):
+ # Let's assume that `files` is a list of file names, and each takes a long time to process.
+ # We can use state to store what was the last file we were working on.
+ #
+ # Then, should the task restart, we can instead continue from there.
+
+ # The first time this function gets called, "current_index" won't exist.
+ # So, we put a default value of 0 (otherwise we would get None).
+ last_processed_index: int = await context.state.get("current_index", 0)
+
+ for i, file in enumerate(files):
+        # Skip files already processed. Note that the file at the saved
+        # index itself may run again after a restart (at-least-once semantics).
+        if i < last_processed_index:
+            continue
+
+ await process_file(file) # Implementation left to the reader...
+
+ # Create a check point here in case the Worker reboots.
+ await context.state.set(current_index=i)
+
State associated with a task is stored on Redis, hence the asynchronous interface. All values are stored as JSON, meaning that you will need to (de)serialize complex values (such as BaseModel
classes from Pydantic) yourself.
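For example, a Pydantic model can be round-tripped through state like this (a sketch using the Pydantic v1 API that the library's own code uses):
+from pydantic import BaseModel
+from mognet import Context, task
+
+class Progress(BaseModel):
+    current_index: int
+
+@task(name="demo.pydantic_state")
+async def pydantic_state(context: Context):
+    # Serialize to a JSON string ourselves before storing...
+    await context.state.set(progress=Progress(current_index=3).json())
+
+    # ...and parse it back after reading.
+    raw = await context.state.get("progress", None)
+    progress = Progress.parse_raw(raw) if raw else None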
All state has a TTL; you can set the default TTL in the RedisStateBackendSettings class when you configure your Mognet App instance. After your task's function is done (either because it finished successfully, failed, or got revoked), its state is automatically cleared.
Task functions can also be paused (or rather, stopped, and then restarted).
Let's take the previous example, and assume that we don't want to process more than 5 files at a time; after that, we want to stop processing files for a while. Let's assume we want to do this because each file takes a long time to process, and its processing cannot be done via subtasks.
For that, we can use the Pause exception. Raising this exception from a task function will:

- Set the result's state to SUSPENDED on the Result Backend
- Send the Request message back to the Task Broker
+from typing import List
+
+from mognet import task, Context
+from mognet.exceptions.task_exceptions import Pause
+
+@task(name="demo.state")
+async def use_state(context: Context, files: List[str]):
+ # Let's assume that `files` is a list of file names, and each takes a long time to process.
+ # We can use state to store what was the last file we were working on.
+ #
+ # Then, should the task restart, we can instead continue from there.
+
+ # The first time this function gets called, "current_index" won't exist.
+ # So, we put a default value of 0 (otherwise we would get None).
+ last_processed_index: int = await context.state.get("current_index", 0)
+
+ processed_file_count = 0
+
+ for i, file in enumerate(files):
+ # Skip files already processed...
+ if i < last_processed_index:
+ continue
+
+ await process_file(file) # Implementation left to the reader...
+
+ # Create a check point here in case the Worker reboots.
+ await context.state.set(current_index=i)
+
+ processed_file_count += 1
+
+ if processed_file_count == 5:
+ # Here, the task function will stop.
+ #
+ # Make sure you don't catch this exception yourself!
+ #
+ # When the Request for this task is picked up again, the state will allow it
+ # to know where to resume from.
+ raise Pause()
+
app

App

Represents the Mognet application.
You can use these objects to:

- Create and abort tasks
- Check the status of tasks
- Configure the middleware that runs on key lifecycle events of the app and its tasks

mognet/app/app.py
class App:
+ """
+ Represents the Mognet application.
+
+ You can use these objects to:
+
+ - Create and abort tasks
+ - Check the status of tasks
+ - Configure the middleware that runs on key lifecycle events of the app and its tasks
+ """
+
+ # Where results are stored.
+ result_backend: BaseResultBackend
+
+ # Where task state is stored.
+ # Task state is information that a task can save
+ # during its execution, in case, for example, it gets
+ # interrupted.
+ state_backend: BaseStateBackend
+
+ # Task broker.
+ broker: BaseBroker
+
+ # Mapping of [service name] -> dependency object,
+ # should be accessed via Context#get_service.
+ services: Dict[Any, Callable]
+
+ # Holds references to all the tasks.
+ task_registry: TaskRegistry
+
+ _connected: bool
+
+ # Configuration used to start this app.
+ config: "AppConfig"
+
+ # Worker running in this app instance.
+ worker: Optional[Worker]
+
+ # Background tasks spawned by this app.
+ _consume_control_task: Optional[Future] = None
+ _heartbeat_task: Optional[Future] = None
+
+ _worker_task: Optional[Future]
+
+ _middleware: List[Middleware]
+
+ _loop: asyncio.AbstractEventLoop
+
+ def __init__(
+ self,
+ *,
+ name: str,
+ config: "AppConfig",
+ ) -> None:
+ self.name = name
+
+ self._connected = False
+
+ self.config = config
+
+ # Create the task registry and register it globally
+ reg = task_registry.get(None)
+
+ if reg is None:
+ reg = TaskRegistry()
+ reg.register_globally()
+
+ self.task_registry = reg
+
+ self._worker_task = None
+
+ self.services = {}
+
+ self._middleware = []
+
+ self._load_modules()
+ self.worker = None
+
+ # Event that gets set when the app is closed
+ self._run_result = None
+
+ def add_middleware(self, mw_inst: Middleware):
+ """
+ Adds middleware to this app.
+
+        Middleware is called in the order in which it was added
+ to the app.
+ """
+ if mw_inst in self._middleware:
+ return
+
+ self._middleware.append(mw_inst)
+
+ async def start(self):
+ """
+ Starts the app.
+ """
+ _log.info("Starting app %r", self.config.node_id)
+
+ self._loop = asyncio.get_event_loop()
+
+ self._run_result = asyncio.Future()
+
+ self._log_tasks_and_queues()
+
+ await self._call_on_starting_middleware()
+
+ await self.connect()
+
+ self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
+ self._consume_control_task = asyncio.create_task(self._consume_control_queue())
+
+ self.worker = Worker(app=self, middleware=self._middleware)
+ self._worker_task = asyncio.create_task(self.worker.run())
+
+ _log.info("Started")
+
+ await self._call_on_started_middleware()
+
+ return await self._run_result
+
+ async def get_current_status_of_nodes(
+ self,
+ ) -> AsyncGenerator[StatusResponseMessage, None]:
+ """
+ Query all nodes of this App and get their status.
+ """
+
+ request = QueryRequestMessage(name="Status")
+
+ responses = self.broker.send_query_message(
+ payload=MessagePayload(
+ id=str(request.id),
+ kind="Query",
+ payload=request,
+ )
+ )
+
+ try:
+ async for response in responses:
+ try:
+ yield StatusResponseMessage.parse_obj(response)
+ except asyncio.CancelledError:
+ break
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Could not parse status response %r", response, exc_info=exc
+ )
+ finally:
+ await responses.aclose()
+
+ async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
+ """
+ Submits a request for execution.
+
+ If a context is defined, it will be used to create a parent-child
+ relationship between the new-to-submit request and the one existing
+ in the context instance. This is later used to cancel the whole task tree.
+ """
+ if not self.result_backend:
+ raise ImproperlyConfigured("Result backend not defined")
+
+ if not self.broker:
+ raise ImproperlyConfigured("Broker not connected")
+
+ try:
+ if req.kwargs_repr is None:
+ req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
+ _log.debug("Set default kwargs_repr on Request %r", req)
+
+ res = Result(
+ self.result_backend,
+ id=req.id,
+ name=req.name,
+ state=ResultState.PENDING,
+ created=datetime.now(tz=timezone.utc),
+ request_kwargs_repr=req.kwargs_repr,
+ )
+
+ if context is not None:
+ # Set the parent-child relationship and update the request stack.
+ parent_request = context.request
+ res.parent_id = parent_request.id
+
+ req.stack = [*parent_request.stack, parent_request.id]
+
+ if res.parent_id is not None:
+ await self.result_backend.add_children(res.parent_id, req.id)
+
+ await self.result_backend.set(req.id, res)
+
+ # Store the metadata on the Result.
+ if req.metadata:
+ await res.set_metadata(**req.metadata)
+
+ await self._on_submitting(req, context=context)
+
+ payload = MessagePayload(
+ id=str(req.id),
+ kind="Request",
+ payload=req,
+ priority=req.priority,
+ )
+
+ _log.debug("Sending message %r", payload.id)
+
+ await self.broker.send_task_message(self._get_task_route(req), payload)
+
+ return res
+ except Exception as exc:
+ raise CouldNotSubmit(f"Could not submit {req!r}") from exc
+
+ def get_task_queue_names(self) -> Set[str]:
+ """
+ Return the names of the queues that are going to be consumed,
+ after applying defaults, inclusions, and exclusions.
+ """
+ all_queues = {*self.config.task_routes.values(), self.config.default_task_route}
+
+ _log.debug("All queues: %r", all_queues)
+
+ configured_queues = self.config.task_queues
+
+ configured_queues.ensure_valid()
+
+ if configured_queues.exclude:
+ _log.debug("Applying queue exclusions: %r", configured_queues.exclude)
+ return all_queues - configured_queues.exclude
+
+ if configured_queues.include:
+ _log.debug("Applying queue inclusions: %r", configured_queues.include)
+ return all_queues & configured_queues.include
+
+ _log.debug("No inclusions or exclusions applied")
+
+ return all_queues
+
+ @overload
+ def create_request(
+ self,
+ func: Callable[Concatenate["Context", _P], Awaitable[_Return]],
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> Request[_Return]:
+ """
+ Creates a Request object from the function that was decorated with @task,
+ and the provided arguments.
+
+ This overload is just to document async def function return values.
+ """
+ ...
+
+ @overload
+ def create_request(
+ self,
+ func: Callable[Concatenate["Context", _P], _Return],
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> Request[_Return]:
+ """
+ Creates a Request object from the function that was decorated with @task,
+ and the provided arguments.
+
+ This overload is just to document non-async def function return values.
+ """
+ ...
+
+ def create_request(
+ self,
+ func: Callable[Concatenate["Context", _P], Any],
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> Request:
+ """
+ Creates a Request object from the function that was decorated with @task,
+ and the provided arguments.
+ """
+ return Request(
+ name=self.task_registry.get_task_name(cast(Any, func)),
+ args=args,
+ kwargs=kwargs,
+ )
+
+ @overload
+ async def run(
+ self,
+ request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+ ) -> _Return:
+ """
+ Short-hand method for creating a Request from a function decorated with `@task`,
+ (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
+ """
+ ...
+
+ @overload
+ async def run(
+ self, request: "Request[_Return]", context: Optional[Context] = None
+ ) -> _Return:
+ """
+ Runs the request and waits for the result.
+
+ Call `submit` if you just want to send a request
+ without waiting for the result.
+ """
+
+ ...
+
+ async def run(self, request, *args, **kwargs) -> Any:
+
+ if not isinstance(request, Request):
+ request = self.create_request(*args, **kwargs)
+
+ res = await self.submit(request, *args, **kwargs)
+
+ return await res
+
+ async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
+ """
+ Revoke the execution of a request.
+
+ If the request is already completed, this method returns
+ the associated result as-is. Optionally, `force=True` may be set
+ in order to ignore the state check.
+
+ This will also revoke any request that's launched as a child of this one,
+ recursively.
+
+ Returns the cancelled result.
+ """
+ res = await self.result_backend.get_or_create(request_id)
+
+ if not force and res.done:
+ _log.warning(
+ "Attempting to cancel result %r that's already done, this is a no-op",
+ res.id,
+ )
+ return res
+
+ _log.info("Revoking request id=%r", res)
+
+ await res.revoke()
+
+ payload = MessagePayload(
+ id=str(uuid.uuid4()),
+ kind=Revoke.MESSAGE_KIND,
+ payload=Revoke(id=request_id),
+ )
+
+ await self.broker.send_control_message(payload)
+
+ child_count = await res.children.count()
+ if child_count:
+ _log.info("Revoking %r children of id=%r", child_count, res.id)
+
+ # Abort children.
+ async for child_id in res.children.iter_ids():
+ await self.revoke(child_id, force=force)
+
+ return res
+
+ async def connect(self):
+ """Connect this app and its components to their respective backends."""
+ if self._connected:
+ return
+
+ self.broker = self._create_broker()
+
+ self.result_backend = self._create_result_backend()
+ self.state_backend = self._create_state_backend()
+
+ self._connected = True
+
+ await self._setup_broker()
+
+ _log.debug("Connecting to result backend %s", self.result_backend)
+
+ await self.result_backend.connect()
+
+ _log.debug("Connected to result backend %s", self.result_backend)
+
+ _log.debug("Connecting to state backend %s", self.state_backend)
+
+ await self.state_backend.connect()
+
+ _log.debug("Connected to state backend %s", self.state_backend)
+
+ async def __aenter__(self):
+ await self.connect()
+
+ return self
+
+ async def __aexit__(self, *args, **kwargs):
+ await self.close()
+
+ @shield
+ async def close(self):
+ """Close this app and its components's backends."""
+
+ _log.info("Closing app")
+
+ await asyncio.shield(self._stop())
+
+ if self._run_result and not self._run_result.done():
+ self._run_result.set_result(None)
+
+ _log.info("Closed app")
+
+ async def _stop(self):
+ await self._call_on_stopping_middleware()
+
+ if self._heartbeat_task is not None:
+ self._heartbeat_task.cancel()
+
+ try:
+ await self._heartbeat_task
+ except BaseException: # pylint: disable=broad-except
+ pass
+
+ self._heartbeat_task = None
+
+ _log.debug("Closing queue listeners")
+
+ if self._consume_control_task:
+ self._consume_control_task.cancel()
+
+ try:
+ await self._consume_control_task
+ except asyncio.CancelledError:
+ pass
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug("Error shutting down control consumption task", exc_info=exc)
+
+ self._consume_control_task = None
+
+ # Disconnect from the broker, this should NACK
+ # all pending messages too.
+ _log.debug("Closing broker connection")
+ if self.broker:
+ await self.broker.close()
+
+ # Stop the worker
+ await self._stop_worker()
+
+ # Remove service instances
+ for svc in self.services:
+ if isinstance(svc, ClassService):
+ try:
+ svc.close()
+ await svc.wait_closed()
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error("Error closing service %r", svc, exc_info=exc)
+
+ self.services.clear()
+
+ # Finally, shut down the state and result backends.
+ _log.debug("Closing backends")
+ if self.result_backend:
+ try:
+ await self.result_backend.close()
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error("Error closing result backend", exc_info=exc)
+
+ if self.state_backend:
+ try:
+ await self.state_backend.close()
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error("Error closing state backend", exc_info=exc)
+
+ self._connected = False
+
+ await self._call_on_stopped_middleware()
+
+ async def purge_task_queues(self) -> Dict[str, int]:
+ """
+ Purge all known task queues.
+
+ Returns a dict where the keys are the names of the queues,
+ and the values are the number of messages purged.
+ """
+ deleted_per_queue = {}
+
+ for queue in self.get_task_queue_names():
+ _log.info("Purging task queue=%r", queue)
+ deleted_per_queue[queue] = await self.broker.purge_task_queue(queue)
+
+ return deleted_per_queue
+
+ async def purge_control_queue(self) -> int:
+ """
+ Purges the control queue related to this app.
+
+ Returns the number of messages purged.
+ """
+ return await self.broker.purge_control_queue()
+
+ @property
+ def loop(self) -> asyncio.AbstractEventLoop:
+ return self._loop
+
+ def _create_broker(self) -> BaseBroker:
+ return AmqpBroker(config=self.config.broker, app=self)
+
+ def _create_result_backend(self) -> BaseResultBackend:
+ return RedisResultBackend(self.config.result_backend, app=self)
+
+ def _create_state_backend(self) -> BaseStateBackend:
+ return RedisStateBackend(self.config.state_backend, app=self)
+
+ def _load_modules(self):
+ for module in self.config.imports:
+ importlib.import_module(module)
+
+ def _log_tasks_and_queues(self):
+
+ all_tasks = self.task_registry.registered_task_names
+
+ tasks_msg = "\n".join(
+ f"\t - {t!r} (queue={self._get_task_route(t)!r})" for t in all_tasks
+ )
+
+ _log.info("Registered %r tasks:\n%s", len(all_tasks), tasks_msg)
+
+ all_queues = self.get_task_queue_names()
+
+ queues_msg = "\n".join(f"\t - {q!r}" for q in all_queues)
+
+ _log.info("Registered %r queues:\n%s", len(all_queues), queues_msg)
+
+ async def _setup_broker(self):
+ _log.debug("Connecting to broker %s", self.broker)
+
+ await self.broker.connect()
+
+ _log.debug("Connected to broker %r", self.broker)
+
+ _log.debug("Setting up task queues")
+
+ for queue_name in self.get_task_queue_names():
+ await self.broker.setup_task_queue(TaskQueue(name=queue_name))
+
+ _log.debug("Setup queues")
+
+ async def _stop_worker(self):
+ if self.worker is None or not self._worker_task:
+ _log.debug("No worker running")
+ return
+
+ try:
+ _log.debug("Closing worker")
+ await self.worker.close()
+
+ if self._worker_task is not None:
+ self._worker_task.cancel()
+ await self._worker_task
+
+ _log.debug("Worker closed")
+ except asyncio.CancelledError:
+ pass
+ except Exception as worker_close_exc: # pylint: disable=broad-except
+ _log.error(
+ "Worker raised an exception while closing", exc_info=worker_close_exc
+ )
+ finally:
+ self.worker = None
+ self._worker_task = None
+
+ async def _background_heartbeat(self):
+ """
+ Background task that checks if the event loop was blocked
+ for too long.
+
+ A crude check, it asyncio.sleep()s and checks if the time difference
+ before and after sleeping is significantly higher. This could bring problems,
+ for example, with task brokers, that may need to send periodic keep-alive messages
+ to the broker in order to prevent connection drops.
+
+ Error messages are logged in case the event loop got blocked for too long.
+ """
+
+ while True:
+ current_ts = self.loop.time()
+
+ await asyncio.sleep(5)
+
+ next_ts = self.loop.time()
+
+ diff = next_ts - current_ts
+
+ if diff > 10:
+ _log.error(
+ "Event loop seemed blocked for %.2fs (>10s), this could bring issues. Consider using asyncio.run_in_executor to run CPU-bound work",
+ diff,
+ )
+ else:
+ _log.debug("Event loop heartbeat: %.2fs", diff)
+
+ async def _consume_control_queue(self):
+ """
+ Reads messages from the control queue and dispatches them.
+ """
+
+ await self.broker.setup_control_queue()
+
+ _log.debug("Listening on the control queue")
+
+ async for msg in self.broker.consume_control_queue():
+ try:
+ await self._process_control_message(msg)
+ except asyncio.CancelledError:
+ break
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Could not process control queue message %r", msg, exc_info=exc
+ )
+
+ async def _process_control_message(self, msg: IncomingMessagePayload):
+ _log.debug("Received control message id=%r", msg.id)
+
+ try:
+ if msg.kind == Revoke.MESSAGE_KIND:
+ abort = Revoke.parse_obj(msg.payload)
+
+ _log.debug("Received request to revoke request id=%r", abort.id)
+
+ if self.worker is None:
+ _log.debug("No worker running. Discarding revoke message.")
+ return
+
+ try:
+ # Cancel the task's execution and ACK it on the broker
+ # to prevent it from re-running.
+ await self.worker.cancel(
+ abort.id, message_action=MessageCancellationAction.ACK
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Error while cancelling request id=%r", abort.id, exc_info=exc
+ )
+
+ return
+
+ if msg.kind == "Query":
+ query = QueryRequestMessage.parse_obj(msg.payload)
+
+ if query.name == "Status":
+ # Get the status of this worker and reply to the incoming message
+
+ if self.worker is None:
+ _log.debug("No worker running for Status query")
+ running_request_ids = []
+ else:
+ running_request_ids = list(self.worker.running_tasks.keys())
+
+ reply = StatusResponseMessage(
+ node_id=self.config.node_id,
+ payload=StatusResponseMessage.Status(
+ running_request_ids=running_request_ids,
+ ),
+ )
+
+ payload = MessagePayload(
+ id=str(reply.id), kind=reply.kind, payload=reply
+ )
+
+ return await self.broker.send_reply(msg, payload)
+
+ _log.warning("Unknown query name=%r, discarding", query.name)
+ return
+
+ _log.warning("Unknown message kind=%r, discarding", msg.kind)
+ finally:
+ await msg.ack()
+
+ async def _on_submitting(self, req: "Request", context: Optional["Context"]):
+ for mw_inst in self._middleware:
+ try:
+ await mw_inst.on_request_submitting(req, context=context)
+ except Exception as mw_exc: # pylint: disable=broad-except
+ _log.error("Middleware failed", exc_info=mw_exc)
+
+ def _get_task_route(self, req: Union[str, Request]):
+ if isinstance(req, Request):
+ if req.queue_name is not None:
+ _log.debug(
+ "Request %r has a queue override to route to queue=%r",
+ req,
+ req.queue_name,
+ )
+ return req.queue_name
+
+ req = req.name
+
+ route = self.config.task_routes.get(req)
+
+ if route is not None:
+ _log.debug(
+ "Request %r has a config-set route to queue=%r",
+ req,
+ route,
+ )
+ return route
+
+ default_queue = self.config.default_task_route
+
+ _log.debug(
+ "Request %r has no route set, falling back to default queue=%r",
+ req,
+ default_queue,
+ )
+
+ return default_queue
+
+ async def _call_on_starting_middleware(self):
+ for mw in self._middleware:
+ try:
+ await mw.on_app_starting(self)
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug(
+ "Middleware %r failed on 'on_app_starting'", mw, exc_info=exc
+ )
+
+ async def _call_on_started_middleware(self):
+ for mw in self._middleware:
+ try:
+ await mw.on_app_started(self)
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug("Middleware %r failed on 'on_app_started'", mw, exc_info=exc)
+
+ async def _call_on_stopping_middleware(self):
+ for mw in self._middleware:
+ try:
+ await mw.on_app_stopping(self)
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug(
+ "Middleware %r failed on 'on_app_stopping'", mw, exc_info=exc
+ )
+
+ async def _call_on_stopped_middleware(self):
+ for mw in self._middleware:
+ try:
+ await mw.on_app_stopped(self)
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug("Middleware %r failed on 'on_app_stopped'", mw, exc_info=exc)
+
add_middleware(self, mw_inst)
+
+
+Adds middleware to this app.
Middleware is called in the order in which it was added to the app.
+ + +close(self)
+
+
+ async
+
+
Close this app and its components' backends.
+ +mognet/app/app.py
connect(self)
+
+
+ async
+
+
+Connect this app and its components to their respective backends.
+ +mognet/app/app.py
async def connect(self):
+ """Connect this app and its components to their respective backends."""
+ if self._connected:
+ return
+
+ self.broker = self._create_broker()
+
+ self.result_backend = self._create_result_backend()
+ self.state_backend = self._create_state_backend()
+
+ self._connected = True
+
+ await self._setup_broker()
+
+ _log.debug("Connecting to result backend %s", self.result_backend)
+
+ await self.result_backend.connect()
+
+ _log.debug("Connected to result backend %s", self.result_backend)
+
+ _log.debug("Connecting to state backend %s", self.state_backend)
+
+ await self.state_backend.connect()
+
+ _log.debug("Connected to state backend %s", self.state_backend)
+
create_request(self, func, *args, **kwargs)
+
+
+Creates a Request object from the function that was decorated with @task, +and the provided arguments.
+ +mognet/app/app.py
def create_request(
+ self,
+ func: Callable[Concatenate["Context", _P], Any],
+ *args: _P.args,
+ **kwargs: _P.kwargs,
+) -> Request:
+ """
+ Creates a Request object from the function that was decorated with @task,
+ and the provided arguments.
+ """
+ return Request(
+ name=self.task_registry.get_task_name(cast(Any, func)),
+ args=args,
+ kwargs=kwargs,
+ )
+
get_current_status_of_nodes(self)
+
+
+Query all nodes of this App and get their status.
+ +mognet/app/app.py
async def get_current_status_of_nodes(
+ self,
+) -> AsyncGenerator[StatusResponseMessage, None]:
+ """
+ Query all nodes of this App and get their status.
+ """
+
+ request = QueryRequestMessage(name="Status")
+
+ responses = self.broker.send_query_message(
+ payload=MessagePayload(
+ id=str(request.id),
+ kind="Query",
+ payload=request,
+ )
+ )
+
+ try:
+ async for response in responses:
+ try:
+ yield StatusResponseMessage.parse_obj(response)
+ except asyncio.CancelledError:
+ break
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Could not parse status response %r", response, exc_info=exc
+ )
+ finally:
+ await responses.aclose()
+
get_task_queue_names(self)
+
+
+Return the names of the queues that are going to be consumed, +after applying defaults, inclusions, and exclusions.
+ +mognet/app/app.py
def get_task_queue_names(self) -> Set[str]:
+ """
+ Return the names of the queues that are going to be consumed,
+ after applying defaults, inclusions, and exclusions.
+ """
+ all_queues = {*self.config.task_routes.values(), self.config.default_task_route}
+
+ _log.debug("All queues: %r", all_queues)
+
+ configured_queues = self.config.task_queues
+
+ configured_queues.ensure_valid()
+
+ if configured_queues.exclude:
+ _log.debug("Applying queue exclusions: %r", configured_queues.exclude)
+ return all_queues - configured_queues.exclude
+
+ if configured_queues.include:
+ _log.debug("Applying queue inclusions: %r", configured_queues.include)
+ return all_queues & configured_queues.include
+
+ _log.debug("No inclusions or exclusions applied")
+
+ return all_queues
+
purge_control_queue(self)
+
+
+ async
+
+
+Purges the control queue related to this app.
+Returns the number of messages purged.
+ + +purge_task_queues(self)
+
+
+ async
+
+
+Purge all known task queues.
+Returns a dict where the keys are the names of the queues, +and the values are the number of messages purged.
+ +mognet/app/app.py
async def purge_task_queues(self) -> Dict[str, int]:
+ """
+ Purge all known task queues.
+
+ Returns a dict where the keys are the names of the queues,
+ and the values are the number of messages purged.
+ """
+ deleted_per_queue = {}
+
+ for queue in self.get_task_queue_names():
+ _log.info("Purging task queue=%r", queue)
+ deleted_per_queue[queue] = await self.broker.purge_task_queue(queue)
+
+ return deleted_per_queue
+
revoke(self, request_id, *, force=False)
+
+
+ async
+
+
+Revoke the execution of a request.
+If the request is already completed, this method returns
+the associated result as-is. Optionally, force=True
may be set
+in order to ignore the state check.
This will also revoke any request that's launched as a child of this one, +recursively.
+Returns the cancelled result.
+ +mognet/app/app.py
async def revoke(self, request_id: uuid.UUID, *, force: bool = False) -> Result:
+ """
+ Revoke the execution of a request.
+
+ If the request is already completed, this method returns
+ the associated result as-is. Optionally, `force=True` may be set
+ in order to ignore the state check.
+
+ This will also revoke any request that's launched as a child of this one,
+ recursively.
+
+ Returns the cancelled result.
+ """
+ res = await self.result_backend.get_or_create(request_id)
+
+ if not force and res.done:
+ _log.warning(
+ "Attempting to cancel result %r that's already done, this is a no-op",
+ res.id,
+ )
+ return res
+
+ _log.info("Revoking request id=%r", res)
+
+ await res.revoke()
+
+ payload = MessagePayload(
+ id=str(uuid.uuid4()),
+ kind=Revoke.MESSAGE_KIND,
+ payload=Revoke(id=request_id),
+ )
+
+ await self.broker.send_control_message(payload)
+
+ child_count = await res.children.count()
+ if child_count:
+ _log.info("Revoking %r children of id=%r", child_count, res.id)
+
+ # Abort children.
+ async for child_id in res.children.iter_ids():
+ await self.revoke(child_id, force=force)
+
+ return res
+
start(self)
+
+
+ async
+
+
+Starts the app.
+ +mognet/app/app.py
async def start(self):
+ """
+ Starts the app.
+ """
+ _log.info("Starting app %r", self.config.node_id)
+
+ self._loop = asyncio.get_event_loop()
+
+ self._run_result = asyncio.Future()
+
+ self._log_tasks_and_queues()
+
+ await self._call_on_starting_middleware()
+
+ await self.connect()
+
+ self._heartbeat_task = asyncio.create_task(self._background_heartbeat())
+ self._consume_control_task = asyncio.create_task(self._consume_control_queue())
+
+ self.worker = Worker(app=self, middleware=self._middleware)
+ self._worker_task = asyncio.create_task(self.worker.run())
+
+ _log.info("Started")
+
+ await self._call_on_started_middleware()
+
+ return await self._run_result
+
submit(self, req, context=None)
+
+
+ async
+
+
+Submits a request for execution.
+If a context is defined, it will be used to create a parent-child +relationship between the new-to-submit request and the one existing +in the context instance. This is later used to cancel the whole task tree.
+ +mognet/app/app.py
async def submit(self, req: "Request", context: Optional[Context] = None) -> Result:
+ """
+ Submits a request for execution.
+
+ If a context is defined, it will be used to create a parent-child
+ relationship between the new-to-submit request and the one existing
+ in the context instance. This is later used to cancel the whole task tree.
+ """
+ if not self.result_backend:
+ raise ImproperlyConfigured("Result backend not defined")
+
+ if not self.broker:
+ raise ImproperlyConfigured("Broker not connected")
+
+ try:
+ if req.kwargs_repr is None:
+ req.kwargs_repr = format_kwargs_repr(req.args, req.kwargs)
+ _log.debug("Set default kwargs_repr on Request %r", req)
+
+ res = Result(
+ self.result_backend,
+ id=req.id,
+ name=req.name,
+ state=ResultState.PENDING,
+ created=datetime.now(tz=timezone.utc),
+ request_kwargs_repr=req.kwargs_repr,
+ )
+
+ if context is not None:
+ # Set the parent-child relationship and update the request stack.
+ parent_request = context.request
+ res.parent_id = parent_request.id
+
+ req.stack = [*parent_request.stack, parent_request.id]
+
+ if res.parent_id is not None:
+ await self.result_backend.add_children(res.parent_id, req.id)
+
+ await self.result_backend.set(req.id, res)
+
+ # Store the metadata on the Result.
+ if req.metadata:
+ await res.set_metadata(**req.metadata)
+
+ await self._on_submitting(req, context=context)
+
+ payload = MessagePayload(
+ id=str(req.id),
+ kind="Request",
+ payload=req,
+ priority=req.priority,
+ )
+
+ _log.debug("Sending message %r", payload.id)
+
+ await self.broker.send_task_message(self._get_task_route(req), payload)
+
+ return res
+ except Exception as exc:
+ raise CouldNotSubmit(f"Could not submit {req!r}") from exc
+
app_config

AppConfig (BaseModel, pydantic model)
+Configuration for a Mognet application.
+ +mognet/app/app_config.py
class AppConfig(BaseModel):
+ """
+ Configuration for a Mognet application.
+ """
+
+ # An ID for the node. Defaults to a string containing
+ # the current PID and the hostname.
+ node_id: str = Field(default_factory=_default_node_id)
+
+ # Configuration for the result backend.
+ result_backend: ResultBackendConfig
+
+ # Configuration for the state backend.
+ state_backend: StateBackendConfig
+
+ # Configuration for the task broker.
+ broker: BrokerConfig
+
+ # List of modules to import
+ imports: List[str] = Field(default_factory=list)
+
+ # Maximum number of tasks that this app can handle.
+ max_tasks: Optional[int] = None
+
+ # Maximum recursion depth for tasks that call other tasks.
+ max_recursion: int = 64
+
+ # Defines the number of times a task that unexpectedly
+ # failed (i.e., SIGKILL) can be retried.
+ max_retries: int = 3
+
+ # Default task route to send messages to.
+ default_task_route: str = "tasks"
+
+ # A mapping of [task name] -> [queue] that overrides the queue on which a task is listening.
+ # If a task is not here, it will default to the queue set in [default_task_route].
+ task_routes: Dict[str, str] = Field(default_factory=dict)
+
+ # Specify which queues to listen, or not listen, on.
+ task_queues: Queues = Field(default_factory=Queues)
+
+ # The minimum prefetch count. Task consumption will start with
+ # this value, and is then incremented based on the number of waiting
+ # tasks that are running.
+ # A higher value allows more tasks to run concurrently on this node.
+ minimum_concurrency: int = 1
+
+    # The maximum prefetch count. This helps ensure that not too many
+ # recursive tasks run on this node.
+ # Bear in mind that, if set, you can run into deadlocks if you have
+ # overly recursive tasks.
+ maximum_concurrency: Optional[int] = None
+
+ # Settings that can be passed to instances retrieved via
+ # Context#get_service()
+ services_settings: Dict[str, Any] = Field(default_factory=dict)
+
+ @classmethod
+ def from_file(cls, file_path: str) -> "AppConfig":
+ with open(file_path, "r", encoding="utf-8") as config_file:
+ return cls.parse_raw(config_file.read())
+
+ # Maximum number of attempts to connect
+ max_reconnect_retries: int = 5
+
+ # Time to wait between reconnects
+ reconnect_interval: float = 5
+
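A configuration instance can also be loaded from a JSON file via the from_file helper shown above, for example (the file name is hypothetical; its layout mirrors AppConfig's fields):
+from mognet.app.app_config import AppConfig  # module path per the source listing above
+
+config = AppConfig.from_file("mognet.json")
+print(config.default_task_route)  # "tasks" unless overridden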
backend
+Result Backends are used to retrieve Task Results from a persistent storage backend.
backend_config

Encoding (str, Enum)

RedisResultBackendSettings (BaseModel, pydantic model)
+Configuration for the Redis Result Backend
+ +mognet/backend/backend_config.py
class RedisResultBackendSettings(BaseModel):
+ """Configuration for the Redis Result Backend"""
+
+ url: str = "redis://localhost:6379/"
+
+ # TTL for the results.
+ result_ttl: Optional[int] = int(timedelta(days=21).total_seconds())
+
+ # TTL for the result values. This is set lower than `result_ttl` to keep
+ # the results themselves available for longer.
+ result_value_ttl: Optional[int] = int(timedelta(days=7).total_seconds())
+
+ # Encoding for the result values.
+ result_value_encoding: Optional[Encoding] = Encoding.GZIP
+
+ retry_connect_attempts: int = 10
+ retry_connect_timeout: float = 30
+
+ # Set the limit of connections on the Redis connection pool.
+ # DANGER! Setting this to too low a value WILL cause issues opening connections!
+ max_connections: Optional[int] = None
+
base_result_backend

BaseResultBackend

Base interface to implement a Result Backend.
+ +mognet/backend/base_result_backend.py
class BaseResultBackend(metaclass=ABCMeta):
+ """Base interface to implemenent a Result Backend."""
+
+ config: ResultBackendConfig
+ app: AppParameters
+
+ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
+ super().__init__()
+
+ self.config = config
+ self.app = app
+
+ @abstractmethod
+ async def get(self, result_id: UUID) -> Optional[Result]:
+ """
+        Get a Result by its ID.
+ If it doesn't exist, this method returns None.
+ """
+ raise NotImplementedError
+
+ async def get_many(self, *result_ids: UUID) -> List[Result]:
+ """
+ Get a list of Results by specifying their IDs.
+ Results that don't exist will be removed from this list.
+ """
+ all_results = await asyncio.gather(*[self.get(r_id) for r_id in result_ids])
+
+        return [r for r in all_results if r is not None]
+
+ async def get_or_create(self, result_id: UUID) -> Result:
+ """
+        Get a Result by its ID.
+ If it doesn't exist, this method creates one.
+
+ The returned Result will either be the existing one,
+ or the newly-created one.
+ """
+ res = await self.get(result_id)
+
+ if res is None:
+ res = Result(self, id=result_id)
+ await self.set(result_id, res)
+
+ return res
+
+ @abstractmethod
+ async def set(self, result_id: UUID, result: Result) -> None:
+ """
+ Save a Result.
+ """
+ raise NotImplementedError
+
+ async def wait(
+ self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
+ ) -> Result:
+ """
+ Wait until a result is ready.
+
+ Raises `asyncio.TimeoutError` if a timeout is set and exceeded.
+ """
+
+ async def waiter():
+ while True:
+ result = await self.get(result_id)
+
+ if result is not None and result.done:
+ return result
+
+ await asyncio.sleep(poll)
+
+ if timeout:
+ return await asyncio.wait_for(waiter(), timeout)
+
+ return await waiter()
+
+ @abstractmethod
+ async def get_children_count(self, parent_result_id: UUID) -> int:
+ """
+ Return the number of children of a Result.
+
+ Returns 0 if the Result doesn't exist.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ def iterate_children_ids(
+ self, parent_result_id: UUID, *, count: Optional[int] = None
+ ) -> AsyncGenerator[UUID, None]:
+ """
+ Get an AsyncGenerator for the IDs for the children of a Result.
+
+ The AsyncGenerator will be empty if the Result doesn't exist.
+ """
+ raise NotImplementedError
+
+ def iterate_children(
+ self, parent_result_id: UUID, *, count: Optional[int] = None
+ ) -> AsyncGenerator[Result, None]:
+ """
+ Get an AsyncGenerator for the children of a Result.
+
+ The AsyncGenerator will be empty if the Result doesn't exist.
+ """
+ raise NotImplementedError
+
+ async def __aenter__(self):
+ return self
+
+ async def __aexit__(self, *args, **kwargs):
+ return None
+
+ async def connect(self):
+ """
+ Explicit method to connect to the backend provided by
+ this Result backend.
+ """
+
+ async def close(self):
+ """
+ Explicit method to close the backend provided by
+ this Result backend.
+ """
+
+ @abstractmethod
+ async def add_children(self, result_id: UUID, *children: UUID) -> None:
+ """
+ Add children to a parent Result.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def get_value(self, result_id: UUID) -> ResultValueHolder:
+ """
+ Get the value of a Result.
+
+ If the value is lost, ResultValueLost is raised.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def set_value(self, result_id: UUID, value: ResultValueHolder) -> None:
+ """
+ Set the value of a Result.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
+ """
+ Get the metadata of a Result.
+
+ Returns an empty Dict if the Result doesn't exist.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
+ """
+ Set metadata on a Result.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def delete(self, result_id: UUID, include_children: bool = True) -> None:
+ """
+ Delete a Result.
+ """
+ raise NotImplementedError
+
+ @abstractmethod
+ async def set_ttl(
+ self, result_id: UUID, ttl: timedelta, include_children: bool = True
+ ) -> None:
+ """
+ Set expiration on a Result.
+
+ If include_children is True, children will have the same TTL set.
+ """
+ raise NotImplementedError
+
add_children(self, result_id, *children)
+
+
+ async
+
+
+close(self)
+
+
+ async
+
+
+connect(self)
+
+
+ async
+
+
+delete(self, result_id, include_children=True)
+
+
+ async
+
+
+get(self, result_id)
+
+
+ async
+
+
+Get a Result by it's ID. +If it doesn't exist, this method returns None.
+ + +get_children_count(self, parent_result_id)
+
+
+ async
+
+
+Return the number of children of a Result.
+Returns 0 if the Result doesn't exist.
+ + +get_many(self, *result_ids)
+
+
+ async
+
+
+Get a list of Results by specifying their IDs. +Results that don't exist will be removed from this list.
+ +mognet/backend/base_result_backend.py
async def get_many(self, *result_ids: UUID) -> List[Result]:
+ """
+ Get a list of Results by specifying their IDs.
+ Results that don't exist will be removed from this list.
+ """
+ all_results = await asyncio.gather(*[self.get(r_id) for r_id in result_ids])
+
+    return [r for r in all_results if r is not None]
+
get_metadata(self, result_id)
+
+
+ async
+
+
+Get the metadata of a Result.
+Returns an empty Dict if the Result doesn't exist.
+ + +get_or_create(self, result_id)
+
+
+ async
+
+
Get a Result by its ID. If it doesn't exist, this method creates one.
+The returned Result will either be the existing one, +or the newly-created one.
+ +mognet/backend/base_result_backend.py
async def get_or_create(self, result_id: UUID) -> Result:
+ """
+    Get a Result by its ID.
+ If it doesn't exist, this method creates one.
+
+ The returned Result will either be the existing one,
+ or the newly-created one.
+ """
+ res = await self.get(result_id)
+
+ if res is None:
+ res = Result(self, id=result_id)
+ await self.set(result_id, res)
+
+ return res
+
get_value(self, result_id)
+
+
+ async
+
+
+Get the value of a Result.
+If the value is lost, ResultValueLost is raised.
+ + +iterate_children(self, parent_result_id, *, count=None)
+
+
+Get an AsyncGenerator for the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+ +mognet/backend/base_result_backend.py
iterate_children_ids(self, parent_result_id, *, count=None)
+
+
+Get an AsyncGenerator for the IDs for the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+ +mognet/backend/base_result_backend.py
@abstractmethod
+def iterate_children_ids(
+ self, parent_result_id: UUID, *, count: Optional[int] = None
+) -> AsyncGenerator[UUID, None]:
+ """
+ Get an AsyncGenerator for the IDs for the children of a Result.
+
+ The AsyncGenerator will be empty if the Result doesn't exist.
+ """
+ raise NotImplementedError
+
set(self, result_id, result)
+
+
+ async
+
+
+set_metadata(self, result_id, **kwargs)
+
+
+ async
+
+
+set_ttl(self, result_id, ttl, include_children=True)
+
+
+ async
+
+
+Set expiration on a Result.
+If include_children is True, children will have the same TTL set.
+ +mognet/backend/base_result_backend.py
set_value(self, result_id, value)
+
+
+ async
+
+
+wait(self, result_id, timeout=None, poll=0.1)
+
+
+ async
+
+
+Wait until a result is ready.
+Raises asyncio.TimeoutError
if a timeout is set and exceeded.
mognet/backend/base_result_backend.py
async def wait(
+ self, result_id: UUID, timeout: Optional[float] = None, poll: float = 0.1
+) -> Result:
+ """
+ Wait until a result is ready.
+
+ Raises `asyncio.TimeoutError` if a timeout is set and exceeded.
+ """
+
+ async def waiter():
+ while True:
+ result = await self.get(result_id)
+
+ if result is not None and result.done:
+ return result
+
+ await asyncio.sleep(poll)
+
+ if timeout:
+ return await asyncio.wait_for(waiter(), timeout)
+
+ return await waiter()
+
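A usage sketch for wait() (the app wiring and request ID are assumed):
+import asyncio
+
+async def wait_up_to_a_minute(app, request_id):
+    try:
+        # Poll every 0.5s; asyncio.TimeoutError is raised after 60s.
+        return await app.result_backend.wait(request_id, timeout=60, poll=0.5)
+    except asyncio.TimeoutError:
+        return None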
memory_result_backend

MemoryResultBackend (BaseResultBackend)
+Result backend that "persists" results in memory. Useful for testing, +but this is not recommended for production setups.
+ +mognet/backend/memory_result_backend.py
class MemoryResultBackend(BaseResultBackend):
+ """
+ Result backend that "persists" results in memory. Useful for testing,
+ but this is not recommended for production setups.
+ """
+
+ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
+ super().__init__(config, app)
+
+ self._results: Dict[UUID, Result] = {}
+ self._result_tree: Dict[UUID, Set[UUID]] = {}
+ self._values: Dict[UUID, ResultValueHolder] = {}
+ self._metadata: Dict[UUID, Dict[str, Any]] = {}
+
+ async def get(self, result_id: UUID) -> Optional[Result]:
+ return self._results.get(result_id, None)
+
+ async def set(self, result_id: UUID, result: Result):
+ self._results[result_id] = result
+
+ async def get_children_count(self, parent_result_id: UUID) -> int:
+ return len(self._result_tree.get(parent_result_id, set()))
+
+    async def iterate_children_ids(
+        self, parent_result_id: UUID, *, count: Optional[int] = None
+    ) -> AsyncGenerator[UUID, None]:
+        # Use .get() so a missing Result yields an empty generator instead of
+        # raising KeyError (matching the documented behaviour).
+        children = self._result_tree.get(parent_result_id, set())
+
+        for idx, child in enumerate(children):
+            yield child
+
+            # Stop once `count` children have been yielded.
+            if count is not None and idx + 1 >= count:
+                break
+
+    async def iterate_children(
+        self, parent_result_id: UUID, *, count: Optional[int] = None
+    ) -> AsyncGenerator[Result, None]:
+ async for child_id in self.iterate_children_ids(parent_result_id, count=count):
+ child = self._results.get(child_id, None)
+
+ if child is not None:
+ yield child
+
+ async def add_children(self, result_id: UUID, *children: UUID) -> None:
+ self._result_tree.setdefault(result_id, set()).update(children)
+
+ async def get_value(self, result_id: UUID) -> ResultValueHolder:
+ value = self._values.get(result_id, None)
+
+ if value is None:
+ raise ResultValueLost(result_id)
+
+ return value
+
+ async def set_value(self, result_id: UUID, value: ResultValueHolder):
+ self._values[result_id] = value
+
+ async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
+ meta = self._metadata.get(result_id, {})
+ return meta
+
+ async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
+ self._metadata.setdefault(result_id, {}).update(kwargs)
+
+ async def delete(self, result_id: UUID, include_children: bool = True):
+ if include_children:
+ for child_id in self._result_tree.get(result_id, set()):
+ await self.delete(child_id, include_children=include_children)
+
+ self._results.pop(result_id, None)
+ self._metadata.pop(result_id, None)
+ self._values.pop(result_id, None)
+
+ async def set_ttl(
+ self, result_id: UUID, ttl: timedelta, include_children: bool = True
+ ):
+ pass
+
+ async def close(self):
+ self._metadata = {}
+ self._result_tree = {}
+ self._results = {}
+ self._values = {}
+
+ return await super().close()
+
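+A test-oriented sketch (hypothetical; building the ResultBackendConfig and AppParameters instances is elided):
+
+from uuid import uuid4
+
+async def smoke_test(config, app_parameters):
+    backend = MemoryResultBackend(config, app_parameters)
+
+    result_id = uuid4()
+    await backend.get_or_create(result_id)
+    await backend.set_metadata(result_id, attempt=1)
+    assert (await backend.get_metadata(result_id))["attempt"] == 1
+
+    await backend.close()  # clears all in-memory state
+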
add_children(self, result_id, *children) (async)
+
+close(self) (async)
+
+delete(self, result_id, include_children=True) (async)
+
+Delete a Result.
+
+mognet/backend/memory_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
+ if include_children:
+ for child_id in self._result_tree.get(result_id, set()):
+ await self.delete(child_id, include_children=include_children)
+
+ self._results.pop(result_id, None)
+ self._metadata.pop(result_id, None)
+ self._values.pop(result_id, None)
+
get(self, result_id) (async)
+
+get_children_count(self, parent_result_id) (async)
+
+get_metadata(self, result_id) (async)
+
+get_value(self, result_id) (async)
+
+Get the value of a Result.
+If the value is lost, ResultValueLost is raised.
+
+iterate_children(self, parent_result_id, *, count=None)
+
+Get an AsyncGenerator for the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+
+mognet/backend/memory_result_backend.py
+
+iterate_children_ids(self, parent_result_id, *, count=None)
+
+Get an AsyncGenerator for the IDs of the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+
+mognet/backend/memory_result_backend.py
set(self, result_id, result) (async)
+
+set_metadata(self, result_id, **kwargs) (async)
+
+set_ttl(self, result_id, ttl, include_children=True) (async)
+redis_result_backend
+
+RedisResultBackend (BaseResultBackend)
+
+Result backend that uses Redis for persistence.
+
+mognet/backend/redis_result_backend.py
class RedisResultBackend(BaseResultBackend):
+ """
+ Result backend that uses Redis for persistence.
+ """
+
+ def __init__(self, config: ResultBackendConfig, app: AppParameters) -> None:
+ super().__init__(config, app)
+
+ self._url = config.redis.url
+ self.__redis = None
+ self._connected = False
+
+ # Holds references to tasks which are spawned by .wait()
+ self._waiters: List[asyncio.Future] = []
+
+ # Attributes for @_retry
+ self._retry_connect_attempts = self.config.redis.retry_connect_attempts
+ self._retry_connect_timeout = self.config.redis.retry_connect_timeout
+ self._retry_lock = asyncio.Lock()
+
+ @property
+ def _redis(self) -> Redis:
+ if self.__redis is None:
+ raise NotConnected
+
+ return self.__redis
+
+ @_retry
+ async def get(self, result_id: UUID) -> Optional[Result]:
+ obj_key = self._format_key(result_id)
+
+ async with self._redis.pipeline(transaction=True) as pip:
+ # Since HGETALL returns an empty HASH for keys that don't exist,
+ # test if it exists at all and use that to check if we should return null.
+ pip.exists(obj_key)
+ pip.hgetall(obj_key)
+
+ exists, value, *_ = await shield(pip.execute())
+
+ if not exists:
+ return None
+
+ return self._decode_result(value)
+
+ @_retry
+ async def get_or_create(self, result_id: UUID) -> Result:
+ """
+ Gets a result, or creates one if it doesn't exist.
+ """
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ result_key = self._format_key(result_id)
+
+ pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
+ pip.hgetall(result_key)
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(result_key, self.config.redis.result_ttl)
+
+ # Also set the value, to a default holding an absence of result.
+ value_key = self._format_key(result_id, "value")
+
+ default_not_ready = ResultValueHolder.not_ready()
+ encoded = self._encode_result_value(default_not_ready)
+
+ if self.config.redis.result_value_ttl is not None:
+ pip.expire(value_key, self.config.redis.result_value_ttl)
+
+ for encoded_k, encoded_v in encoded.items():
+ pip.hsetnx(value_key, encoded_k, encoded_v)
+
+ existed, value, *_ = await shield(pip.execute())
+
+ if not existed:
+ _log.debug("Created result %r on key %r", result_id, result_key)
+
+ return self._decode_result(value)
+
+ def _encode_result_value(self, value: ResultValueHolder) -> Dict[str, bytes]:
+ contents = value.json().encode()
+ encoding = b"null"
+
+ if self.config.redis.result_value_encoding == Encoding.GZIP:
+ encoding = _json_bytes("gzip")
+ contents = gzip.compress(contents)
+
+ return {
+ "contents": contents,
+ "encoding": encoding,
+ "content_type": _json_bytes("application/json"),
+ }
+
+ def _decode_result_value(self, encoded: Dict[bytes, bytes]) -> ResultValueHolder:
+ if encoded.get(b"encoding") == _json_bytes("gzip"):
+ contents = gzip.decompress(encoded[b"contents"])
+ else:
+ contents = encoded[b"contents"]
+
+ if encoded.get(b"content_type") != _json_bytes("application/json"):
+ raise ValueError(f"Unknown content_type={encoded.get(b'content_type')!r}")
+
+ return ResultValueHolder.parse_raw(contents, content_type="application/json")
+
+ @_retry
+ async def set(self, result_id: UUID, result: Result):
+ key = self._format_key(result_id)
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ encoded = _encode_result(result)
+
+ pip.hset(key, None, None, encoded)
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(key, self.config.redis.result_ttl)
+
+ await shield(pip.execute())
+
+ def _format_key(self, result_id: UUID, subkey: str = None) -> str:
+ key = f"{self.app.name}.mognet.result.{str(result_id)}"
+
+ if subkey:
+ key = f"{key}/{subkey}"
+
+ _log.debug(
+ "Formatted result key=%r for id=%r and subkey=%r", key, subkey, result_id
+ )
+
+ return key
+
+ @_retry
+ async def add_children(self, result_id: UUID, *children: UUID):
+ if not children:
+ return
+
+ # If there are children to add, add them to the set
+ # on Redis using SADD
+ children_key = self._format_key(result_id, "children")
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.sadd(children_key, *_encode_children(children))
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(children_key, self.config.redis.result_ttl)
+
+ await shield(pip.execute())
+
+ async def get_value(self, result_id: UUID) -> ResultValueHolder:
+ value_key = self._format_key(result_id, "value")
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.exists(value_key)
+ pip.hgetall(value_key)
+
+ exists, contents = await shield(pip.execute())
+
+ if not exists:
+ raise ResultValueLost(result_id)
+
+ return self._decode_result_value(contents)
+
+ async def set_value(self, result_id: UUID, value: ResultValueHolder):
+ value_key = self._format_key(result_id, "value")
+
+ encoded = self._encode_result_value(value)
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.hset(value_key, None, None, encoded)
+
+ if self.config.redis.result_value_ttl is not None:
+ pip.expire(value_key, self.config.redis.result_value_ttl)
+
+ await shield(pip.execute())
+
+ async def delete(self, result_id: UUID, include_children: bool = True):
+ if include_children:
+ async for child_id in self.iterate_children_ids(result_id):
+ await self.delete(child_id, include_children=True)
+
+ key = self._format_key(result_id)
+ children_key = self._format_key(result_id, "children")
+ value_key = self._format_key(result_id, "value")
+ metadata_key = self._format_key(result_id, "metadata")
+
+ await shield(self._redis.delete(key, children_key, value_key, metadata_key))
+
+ async def set_ttl(
+ self, result_id: UUID, ttl: timedelta, include_children: bool = True
+ ):
+ if include_children:
+ async for child_id in self.iterate_children_ids(result_id):
+ await self.set_ttl(child_id, ttl, include_children=True)
+
+ key = self._format_key(result_id)
+ children_key = self._format_key(result_id, "children")
+ value_key = self._format_key(result_id, "value")
+ metadata_key = self._format_key(result_id, "metadata")
+
+ await shield(self._redis.expire(key, ttl))
+ await shield(self._redis.expire(children_key, ttl))
+ await shield(self._redis.expire(value_key, ttl))
+ await shield(self._redis.expire(metadata_key, ttl))
+
+ async def connect(self):
+ if self._connected:
+ return
+
+ self._connected = True
+
+ await self._connect()
+
+ async def close(self):
+ self._connected = False
+
+ await self._close_waiters()
+
+ await self._disconnect()
+
+ async def get_children_count(self, parent_result_id: UUID) -> int:
+ children_key = self._format_key(parent_result_id, "children")
+
+ return await shield(self._redis.scard(children_key))
+
+ async def iterate_children_ids(
+        self, parent_result_id: UUID, *, count: Optional[int] = None
+ ):
+ children_key = self._format_key(parent_result_id, "children")
+
+ raw_child_id: bytes
+ async for raw_child_id in self._redis.sscan_iter(children_key, count=count):
+ child_id = UUID(bytes=raw_child_id)
+ yield child_id
+
+ async def iterate_children(
+        self, parent_result_id: UUID, *, count: Optional[int] = None
+ ):
+ async for child_id in self.iterate_children_ids(parent_result_id, count=count):
+ child = await self.get(child_id)
+
+ if child is not None:
+ yield child
+
+ @_retry
+ async def wait(
+ self, result_id: UUID, timeout: Optional[float] = None, poll: float = 1
+ ) -> Result:
+ async def waiter():
+ key = self._format_key(result_id=result_id)
+
+ # Type def for the state key. It can (but shouldn't)
+ # be null.
+ t = Optional[ResultState]
+
+ while True:
+ raw_state = await shield(self._redis.hget(key, "state")) or b"null"
+
+ state = parse_raw_as(t, raw_state)
+
+ if state is None:
+ raise ResultValueLost(result_id)
+
+ if state in READY_STATES:
+ final_result = await self.get(result_id)
+
+ if final_result is None:
+ raise RuntimeError(
+ f"Result id={result_id!r} that previously existed no longer does"
+ )
+
+ return final_result
+
+ await asyncio.sleep(poll)
+
+ waiter_task = asyncio.create_task(
+ waiter(),
+ name=f"RedisResultBackend:wait_for:{result_id}",
+ )
+
+ if timeout:
+ waiter_task = asyncio.create_task(asyncio.wait_for(waiter_task, timeout))
+
+ self._waiters.append(waiter_task)
+
+ return await waiter_task
+
+ async def get_metadata(self, result_id: UUID) -> Dict[str, Any]:
+ key = self._format_key(result_id, "metadata")
+
+ value = await shield(self._redis.hgetall(key))
+
+ return _decode_json_dict(value)
+
+ async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
+ key = self._format_key(result_id, "metadata")
+
+ if not kwargs:
+ return
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.hset(key, None, None, _dict_to_json_dict(kwargs))
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(key, self.config.redis.result_ttl)
+
+ await shield(pip.execute())
+
+ def __repr__(self):
+ return f"RedisResultBackend(url={censor_credentials(self._url)!r})"
+
+ async def __aenter__(self):
+ await self.connect()
+
+ return self
+
+ async def __aexit__(self, *args, **kwargs):
+ await self.close()
+
+ async def _close_waiters(self):
+ """
+ Cancel any wait loop we have running.
+ """
+ while self._waiters:
+ waiter_task = self._waiters.pop()
+
+ try:
+ _log.debug("Cancelling waiter %r", waiter_task)
+
+ waiter_task.cancel()
+ await waiter_task
+ except asyncio.CancelledError:
+ pass
+ except Exception as exc: # pylint: disable=broad-except
+ _log.debug("Error on waiter task %r", waiter_task, exc_info=exc)
+
+ async def _create_redis(self):
+ _log.debug("Creating Redis connection")
+ redis: Redis = await from_url(
+ self._url,
+ max_connections=self.config.redis.max_connections,
+ )
+
+ return redis
+
+ @_retry
+ async def _connect(self):
+ if self.__redis is None:
+ self.__redis = await self._create_redis()
+
+ await shield(self._redis.ping())
+
+ async def _disconnect(self):
+ redis = self.__redis
+
+ if redis is not None:
+ self.__redis = None
+ _log.debug("Closing Redis connection")
+ await redis.close()
+
+ def _decode_result(self, json_dict: Dict[bytes, bytes]) -> Result:
+ # Load the dict of JSON values first; then update it with overrides.
+ value = _decode_json_dict(json_dict)
+ return Result(self, **value)
+
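+Because the class implements the async context manager protocol, a connection sketch looks like this (hypothetical helper; the config objects are assumed to be built elsewhere):
+
+async def read_result(config, app_parameters, result_id):
+    # __aenter__ connects and pings Redis; __aexit__ closes the connection
+    # and cancels any running wait() loops.
+    async with RedisResultBackend(config, app_parameters) as backend:
+        return await backend.get(result_id)
+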
close(self) (async)
+
+connect(self) (async)
+
+delete(self, result_id, include_children=True) (async)
+
+Delete a Result.
+
+mognet/backend/redis_result_backend.py
async def delete(self, result_id: UUID, include_children: bool = True):
+ if include_children:
+ async for child_id in self.iterate_children_ids(result_id):
+ await self.delete(child_id, include_children=True)
+
+ key = self._format_key(result_id)
+ children_key = self._format_key(result_id, "children")
+ value_key = self._format_key(result_id, "value")
+ metadata_key = self._format_key(result_id, "metadata")
+
+ await shield(self._redis.delete(key, children_key, value_key, metadata_key))
+
get_children_count(self, parent_result_id) (async)
+
+Return the number of children of a Result.
+Returns 0 if the Result doesn't exist.
+
+get_metadata(self, result_id) (async)
+
+Get the metadata of a Result.
+Returns an empty Dict if the Result doesn't exist.
+
+get_or_create(self, result_id) (async)
+
+Gets a result, or creates one if it doesn't exist.
+
+mognet/backend/redis_result_backend.py
@_retry
+async def get_or_create(self, result_id: UUID) -> Result:
+ """
+ Gets a result, or creates one if it doesn't exist.
+ """
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ result_key = self._format_key(result_id)
+
+ pip.hsetnx(result_key, "id", json.dumps(str(result_id)).encode())
+ pip.hgetall(result_key)
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(result_key, self.config.redis.result_ttl)
+
+ # Also set the value, to a default holding an absence of result.
+ value_key = self._format_key(result_id, "value")
+
+ default_not_ready = ResultValueHolder.not_ready()
+ encoded = self._encode_result_value(default_not_ready)
+
+ if self.config.redis.result_value_ttl is not None:
+ pip.expire(value_key, self.config.redis.result_value_ttl)
+
+ for encoded_k, encoded_v in encoded.items():
+ pip.hsetnx(value_key, encoded_k, encoded_v)
+
+ existed, value, *_ = await shield(pip.execute())
+
+ if not existed:
+ _log.debug("Created result %r on key %r", result_id, result_key)
+
+ return self._decode_result(value)
+
get_value(self, result_id) (async)
+
+Get the value of a Result.
+If the value is lost, ResultValueLost is raised.
+
+mognet/backend/redis_result_backend.py
async def get_value(self, result_id: UUID) -> ResultValueHolder:
+ value_key = self._format_key(result_id, "value")
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.exists(value_key)
+ pip.hgetall(value_key)
+
+ exists, contents = await shield(pip.execute())
+
+ if not exists:
+ raise ResultValueLost(result_id)
+
+ return self._decode_result_value(contents)
+
iterate_children(self, parent_result_id, *, count=None)
+
+Get an AsyncGenerator for the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+
+mognet/backend/redis_result_backend.py
+
+iterate_children_ids(self, parent_result_id, *, count=None)
+
+Get an AsyncGenerator for the IDs of the children of a Result.
+The AsyncGenerator will be empty if the Result doesn't exist.
+
+mognet/backend/redis_result_backend.py
async def iterate_children_ids(
+    self, parent_result_id: UUID, *, count: Optional[int] = None
+):
+ children_key = self._format_key(parent_result_id, "children")
+
+ raw_child_id: bytes
+ async for raw_child_id in self._redis.sscan_iter(children_key, count=count):
+ child_id = UUID(bytes=raw_child_id)
+ yield child_id
+
set_metadata(self, result_id, **kwargs) (async)
+
+Set metadata on a Result.
+
+mognet/backend/redis_result_backend.py
async def set_metadata(self, result_id: UUID, **kwargs: Any) -> None:
+ key = self._format_key(result_id, "metadata")
+
+ if not kwargs:
+ return
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.hset(key, None, None, _dict_to_json_dict(kwargs))
+
+ if self.config.redis.result_ttl is not None:
+ pip.expire(key, self.config.redis.result_ttl)
+
+ await shield(pip.execute())
+
set_ttl(self, result_id, ttl, include_children=True) (async)
+
+Set expiration on a Result.
+If include_children is True, children will have the same TTL set.
+
+mognet/backend/redis_result_backend.py
async def set_ttl(
+ self, result_id: UUID, ttl: timedelta, include_children: bool = True
+):
+ if include_children:
+ async for child_id in self.iterate_children_ids(result_id):
+ await self.set_ttl(child_id, ttl, include_children=True)
+
+ key = self._format_key(result_id)
+ children_key = self._format_key(result_id, "children")
+ value_key = self._format_key(result_id, "value")
+ metadata_key = self._format_key(result_id, "metadata")
+
+ await shield(self._redis.expire(key, ttl))
+ await shield(self._redis.expire(children_key, ttl))
+ await shield(self._redis.expire(value_key, ttl))
+ await shield(self._redis.expire(metadata_key, ttl))
+
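+A usage sketch (hypothetical helper; `backend` is a connected RedisResultBackend):
+
+from datetime import timedelta
+
+async def expire_soon(backend, result_id):
+    # Expires the result and, because include_children=True, its whole tree
+    # (children, value, and metadata keys).
+    await backend.set_ttl(result_id, timedelta(hours=1), include_children=True)
+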
set_value(self, result_id, value) (async)
+
+Set the value of a Result.
+
+mognet/backend/redis_result_backend.py
async def set_value(self, result_id: UUID, value: ResultValueHolder):
+ value_key = self._format_key(result_id, "value")
+
+ encoded = self._encode_result_value(value)
+
+ async with self._redis.pipeline(transaction=True) as pip:
+
+ pip.hset(value_key, None, None, encoded)
+
+ if self.config.redis.result_value_ttl is not None:
+ pip.expire(value_key, self.config.redis.result_value_ttl)
+
+ await shield(pip.execute())
+
broker (special)
+
+amqp_broker
+
+AmqpBroker (BaseBroker)
+
+mognet/broker/amqp_broker.py
class AmqpBroker(BaseBroker):
+
+ _task_channel: Channel
+ _control_channel: Channel
+
+ _task_queues: Dict[str, Queue]
+
+ _direct_exchange: Exchange
+ _control_exchange: Exchange
+
+ _retry = retryableasyncmethod(
+ _RETRYABLE_ERRORS,
+ max_attempts="_retry_connect_attempts",
+ wait_timeout="_retry_connect_timeout",
+ )
+
+ def __init__(self, app: "App", config: BrokerConfig) -> None:
+ super().__init__()
+
+ self._connected = False
+ self.__connection = None
+
+ self.config = config
+
+ self._task_queues = {}
+ self._control_queue = None
+
+ # Lock to prevent duplicate queue declaration
+ self._lock = Lock()
+
+ self.app = app
+
+ # Attributes for @retryableasyncmethod
+ self._retry_connect_attempts = self.config.amqp.retry_connect_attempts
+ self._retry_connect_timeout = self.config.amqp.retry_connect_timeout
+
+ # List of callbacks for when connection drops
+ self._on_connection_failed_callbacks: List[
+ Callable[[Optional[BaseException]], Awaitable]
+ ] = []
+
+ @property
+ def _connection(self) -> Connection:
+ if self.__connection is None:
+ raise NotConnected
+
+ return self.__connection
+
+ async def ack(self, delivery_tag: str):
+ await self._task_channel.channel.basic_ack(delivery_tag)
+
+ async def nack(self, delivery_tag: str):
+ await self._task_channel.channel.basic_nack(delivery_tag)
+
+ @_retry
+ async def set_task_prefetch(self, prefetch: int):
+ await self._task_channel.set_qos(prefetch_count=prefetch, global_=True)
+
+ @_retry
+ async def send_task_message(self, queue: str, payload: MessagePayload):
+ amqp_queue = self._task_queue_name(queue)
+
+ msg = Message(
+ body=payload.json().encode(),
+ content_type="application/json",
+ content_encoding="utf-8",
+ priority=payload.priority,
+ message_id=payload.id,
+ )
+
+ await self._direct_exchange.publish(msg, amqp_queue)
+
+ _log.debug(
+ "Message %r sent to queue=%r (amqp queue=%r)", payload.id, queue, amqp_queue
+ )
+
+ async def consume_tasks(
+ self, queue: str
+ ) -> AsyncGenerator[IncomingMessagePayload, None]:
+
+ amqp_queue = await self._get_or_create_task_queue(TaskQueue(name=queue))
+
+ async for message in self._consume(amqp_queue):
+ yield message
+
+ async def consume_control_queue(
+ self,
+ ) -> AsyncGenerator[IncomingMessagePayload, None]:
+
+ amqp_queue = await self._get_or_create_control_queue()
+
+ async for message in self._consume(amqp_queue):
+ yield message
+
+ @_retry
+ async def send_control_message(self, payload: MessagePayload):
+ msg = Message(
+ body=payload.json().encode(),
+ content_type="application/json",
+ content_encoding="utf-8",
+ message_id=payload.id,
+ expiration=timedelta(seconds=300),
+ )
+
+ # No queue name set because this is a fanout exchange.
+ await self._control_exchange.publish(msg, "")
+
+ @_retry
+ async def _send_query_message(self, payload: MessagePayload):
+ callback_queue = await self._task_channel.declare_queue(
+ name=self._callback_queue_name,
+ durable=False,
+ exclusive=False,
+ auto_delete=True,
+ arguments={
+ "x-expires": 30000,
+ "x-message-ttl": 30000,
+ },
+ )
+ await callback_queue.bind(self._direct_exchange)
+
+ msg = Message(
+ body=payload.json().encode(),
+ content_type="application/json",
+ content_encoding="utf-8",
+ message_id=payload.id,
+ expiration=timedelta(seconds=300),
+ reply_to=callback_queue.name,
+ )
+
+ await self._control_exchange.publish(msg, "")
+
+ return callback_queue
+
+ async def send_query_message(
+ self, payload: MessagePayload
+ ) -> AsyncGenerator[QueryResponseMessage, None]:
+
+ # Create a callback queue for getting the replies,
+ # then send the message to the control exchange (fanout).
+ # When done, delete the callback queue.
+
+ callback_queue = None
+ try:
+ callback_queue = await self._send_query_message(payload)
+
+ async with callback_queue.iterator() as iterator:
+ async for message in iterator:
+ async with message.process():
+ contents: dict = json.loads(message.body)
+ msg = _AmqpIncomingMessagePayload(
+ broker=self, incoming_message=message, **contents
+ )
+ yield QueryResponseMessage.parse_obj(msg.payload)
+ finally:
+ if callback_queue is not None:
+ await callback_queue.delete()
+
+ async def setup_control_queue(self):
+ await self._get_or_create_control_queue()
+
+ async def setup_task_queue(self, queue: TaskQueue):
+ await self._get_or_create_task_queue(queue)
+
+ @_retry
+ async def _create_connection(self):
+ connection = await aio_pika.connect_robust(
+ self.config.amqp.url,
+ reconnect_interval=self.app.config.reconnect_interval,
+ client_properties={
+ "connection_name": self.app.config.node_id,
+ },
+ )
+
+        # Add a callback for broadcasting unexpected connection drops
+ connection.add_close_callback(self._send_connection_failed_events)
+
+ return connection
+
+ def add_connection_failed_callback(
+ self, cb: Callable[[Optional[BaseException]], Awaitable]
+ ):
+ self._on_connection_failed_callbacks.append(cb)
+
+ def _send_connection_failed_events(self, connection, exc=None):
+ if not self._connected:
+ _log.debug(
+ "Not sending connection closed events because we are disconnected"
+ )
+ return
+
+ _log.error("AMQP connection %r failed", connection, exc_info=exc)
+
+ tasks = [cb(exc) for cb in self._on_connection_failed_callbacks]
+
+ _log.info(
+ "Notifying %r listeners of a disconnect",
+ len(tasks),
+ )
+
+ def notify_task_completion_callback(fut: asyncio.Future):
+ exc = fut.exception()
+
+ if exc and not fut.cancelled():
+ _log.error("Error notifying connection dropped", exc_info=exc)
+
+ for task in tasks:
+ notify_task = asyncio.create_task(task)
+ notify_task.add_done_callback(notify_task_completion_callback)
+
+ async def connect(self):
+ if self._connected:
+ return
+
+ self._connected = True
+
+ self.__connection = await self._create_connection()
+
+ # Use two separate channels with separate prefetch counts.
+ # This allows the task channel to increase the prefetch count
+ # without affecting the control channel,
+ # and allows the control channel to still receive messages, even if
+ # the task channel has reached the full prefetch count.
+ self._task_channel = await self._connection.channel()
+ await self.set_task_prefetch(1)
+
+ self._control_channel = await self._connection.channel()
+ await self.set_control_prefetch(4)
+
+ await self._create_exchanges()
+
+ _log.debug("Connected")
+
+ async def set_control_prefetch(self, prefetch: int):
+ await self._control_channel.set_qos(prefetch_count=prefetch, global_=False)
+
+ async def close(self):
+ self._connected = False
+
+ connection = self.__connection
+
+ if connection is not None:
+ self.__connection = None
+
+ _log.debug("Closing connections")
+ await connection.close()
+ _log.debug("Connection closed")
+
+ @_retry
+ async def send_reply(self, message: IncomingMessagePayload, reply: MessagePayload):
+ if not message.reply_to:
+ raise ValueError("Message has no reply_to set")
+
+ msg = Message(
+ body=reply.json().encode(),
+ content_type="application/json",
+ content_encoding="utf-8",
+ message_id=reply.id,
+ )
+
+ await self._direct_exchange.publish(msg, message.reply_to)
+
+ async def purge_task_queue(self, queue: str) -> int:
+ amqp_queue = self._task_queue_name(queue)
+
+ if amqp_queue not in self._task_queues:
+ _log.warning(
+ "Queue %r (amqp=%r) does not exist in this broker", queue, amqp_queue
+ )
+ return 0
+
+ result = await self._task_queues[amqp_queue].purge()
+
+ deleted_count: int = result.message_count
+
+ _log.info(
+ "Deleted %r messages from queue=%r (amqp=%r)",
+ deleted_count,
+ queue,
+ amqp_queue,
+ )
+
+ return deleted_count
+
+ async def purge_control_queue(self) -> int:
+ if not self._control_queue:
+ _log.debug("Not listening on any control queue, not purging it")
+ return 0
+
+ result = await self._control_queue.purge()
+
+ return result.message_count
+
+ def __repr__(self):
+ return f"AmqpBroker(url={censor_credentials(self.config.amqp.url)!r})"
+
+ async def __aenter__(self):
+ await self.connect()
+
+ return self
+
+ async def __aexit__(self, *args, **kwargs):
+ await self.close()
+
+ return None
+
+ async def _create_exchanges(self):
+ self._direct_exchange = await self._task_channel.declare_exchange(
+ self._direct_exchange_name,
+ type=ExchangeType.DIRECT,
+ durable=True,
+ )
+ self._control_exchange = await self._control_channel.declare_exchange(
+ self._control_exchange_name,
+ type=ExchangeType.FANOUT,
+ )
+
+ async def _consume(
+ self, amqp_queue: Queue
+ ) -> AsyncGenerator[IncomingMessagePayload, None]:
+
+ async with amqp_queue.iterator() as queue_iterator:
+ msg: IncomingMessage
+ async for msg in queue_iterator:
+
+ try:
+ contents: dict = json.loads(msg.body)
+
+ payload = _AmqpIncomingMessagePayload(
+ broker=self, incoming_message=msg, **contents
+ )
+
+ _log.debug("Successfully parsed message %r", payload.id)
+
+ yield payload
+ except asyncio.CancelledError:
+ raise
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Error parsing contents of message %r, discarding",
+ msg.correlation_id,
+ exc_info=exc,
+ )
+
+ try:
+ await asyncio.shield(msg.ack())
+ except Exception as ack_err:
+ _log.error(
+ "Could not ACK message %r for discarding",
+ msg.correlation_id,
+ exc_info=ack_err,
+ )
+
+ def _task_queue_name(self, name: str) -> str:
+ return f"{self.app.name}.{name}"
+
+ @property
+ def _control_queue_name(self) -> str:
+ return f"{self.app.name}.mognet.control.{self.app.config.node_id}"
+
+ @property
+ def _callback_queue_name(self) -> str:
+ return f"{self.app.name}.mognet.callback.{self.app.config.node_id}"
+
+ @property
+ def _control_exchange_name(self) -> str:
+ return f"{self.app.name}.mognet.control"
+
+ @property
+ def _direct_exchange_name(self) -> str:
+ return f"{self.app.name}.mognet.direct"
+
+ @_retry
+ async def _get_or_create_control_queue(self) -> Queue:
+ if self._control_queue is None:
+ async with self._lock:
+ if self._control_queue is None:
+ self._control_queue = await self._control_channel.declare_queue(
+ name=self._control_queue_name,
+ durable=False,
+ auto_delete=True,
+ arguments={
+ "x-expires": 30000,
+ "x-message-ttl": 30000,
+ },
+ )
+ await self._control_queue.bind(self._control_exchange)
+
+ _log.debug("Prepared control queue=%r", self._control_queue.name)
+
+ return self._control_queue
+
+ @_retry
+ async def _get_or_create_task_queue(self, queue: TaskQueue) -> Queue:
+ name = self._task_queue_name(queue.name)
+
+ if name not in self._task_queues:
+ async with self._lock:
+
+ if name not in self._task_queues:
+ _log.debug("Preparing queue %r as AMQP queue=%r", queue, name)
+
+ self._task_queues[name] = await self._task_channel.declare_queue(
+ name,
+ durable=True,
+ arguments={"x-max-priority": queue.max_priority},
+ )
+
+ await self._task_queues[name].bind(self._direct_exchange)
+
+ _log.debug(
+ "Prepared task queue=%r as AMQP queue=%r", queue.name, name
+ )
+
+ return self._task_queues[name]
+
+ async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
+ """
+ Get the stats of a task queue.
+ """
+
+ name = self._task_queue_name(task_queue_name)
+
+ # AMQP can close the Channel on us if we try accessing an object
+ # that doesn't exist. So, to avoid trouble when the same Channel
+ # is being used to consume a queue, use an ephemeral channel
+ # for this operation alone.
+ async with self._connection.channel() as channel:
+ try:
+ queue = await channel.get_queue(name, ensure=False)
+
+ declare_result = await queue.declare()
+ except (
+ aiormq.exceptions.ChannelNotFoundEntity,
+ aio_pika.exceptions.ChannelClosed,
+ ) as query_err:
+ raise QueueNotFound(task_queue_name) from query_err
+
+ return QueueStats(
+ queue_name=task_queue_name,
+ message_count=declare_result.message_count,
+ consumer_count=declare_result.consumer_count,
+ )
+
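+A consumer sketch (hypothetical; `app` and `broker_config` are assumed to exist, and "tasks" is a placeholder queue name):
+
+async def consume(app, broker_config):
+    async with AmqpBroker(app, broker_config) as broker:
+        # Declares the AMQP queue on first use, then yields messages as they arrive.
+        async for payload in broker.consume_tasks("tasks"):
+            print("received message", payload.id)
+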
task_queue_stats(self, task_queue_name) (async)
+
+Get the stats of a task queue.
+
+mognet/broker/amqp_broker.py
async def task_queue_stats(self, task_queue_name: str) -> QueueStats:
+ """
+ Get the stats of a task queue.
+ """
+
+ name = self._task_queue_name(task_queue_name)
+
+ # AMQP can close the Channel on us if we try accessing an object
+ # that doesn't exist. So, to avoid trouble when the same Channel
+ # is being used to consume a queue, use an ephemeral channel
+ # for this operation alone.
+ async with self._connection.channel() as channel:
+ try:
+ queue = await channel.get_queue(name, ensure=False)
+
+ declare_result = await queue.declare()
+ except (
+ aiormq.exceptions.ChannelNotFoundEntity,
+ aio_pika.exceptions.ChannelClosed,
+ ) as query_err:
+ raise QueueNotFound(task_queue_name) from query_err
+
+ return QueueStats(
+ queue_name=task_queue_name,
+ message_count=declare_result.message_count,
+ consumer_count=declare_result.consumer_count,
+ )
+
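+A usage sketch (hypothetical helper; `broker` is a connected AmqpBroker and "tasks" a placeholder queue name; QueueNotFound is the exception raised above):
+
+async def report_depth(broker):
+    try:
+        stats = await broker.task_queue_stats("tasks")
+        print(f"{stats.message_count} messages, {stats.consumer_count} consumers")
+    except QueueNotFound:
+        print("queue has not been declared yet")
+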
cli (special)
+
+exceptions
+
+GracefulShutdown (BaseException)
+
+If this exception is raised from a coroutine, the Mognet app running from the CLI will be gracefully closed.
+
+main
+
+callback(app, log_level='INFO', log_format='%(asctime)s:%(name)s:%(levelname)s:%(message)s')
+
+Mognet CLI
+
+mognet/cli/main.py
@main.callback()
+def callback(
+ app: str = typer.Argument(..., help="App module to import"),
+ log_level: LogLevel = typer.Option("INFO", metavar="log-level"),
+ log_format: str = typer.Option(
+ "%(asctime)s:%(name)s:%(levelname)s:%(message)s", metavar="log-format"
+ ),
+):
+ """Mognet CLI"""
+
+ logging.basicConfig(
+ level=getattr(logging, log_level.value),
+ format=log_format,
+ )
+
+ app_instance = _get_app(app)
+ state["app_instance"] = app_instance
+
models
+
+LogLevel (Enum)
+
+nodes
+
+status(format='text', text_label_format='{name}(id={id!r}, state={state!r})', json_indent=2, poll=None, timeout=30) (async)
+
+Query each node for their status
+
+mognet/cli/nodes.py
@group.command("status")
+@run_in_loop
+async def status(
+ format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
+ text_label_format: str = typer.Option(
+ "{name}(id={id!r}, state={state!r})",
+ metavar="text-label-format",
+ help="Label format for text format",
+ ),
+ json_indent: int = typer.Option(2, metavar="json-indent"),
+ poll: Optional[int] = typer.Option(
+ None,
+ metavar="poll",
+ help="Polling interval, in seconds (default=None)",
+ ),
+ timeout: int = typer.Option(
+ 30,
+ help="Timeout for querying nodes",
+ ),
+):
+ """Query each node for their status"""
+
+ async with state["app_instance"] as app:
+ while True:
+ each_node_status: List[StatusResponseMessage] = []
+
+ async def read_status():
+ async for node_status in app.get_current_status_of_nodes():
+ each_node_status.append(node_status)
+
+ try:
+ await asyncio.wait_for(read_status(), timeout=timeout)
+ except asyncio.TimeoutError:
+ pass
+
+ all_result_ids = set()
+
+ for node_status in each_node_status:
+ all_result_ids.update(node_status.payload.running_request_ids)
+
+ all_results_by_id = {
+ r.id: r
+ for r in await app.result_backend.get_many(
+ *all_result_ids,
+ )
+ if r is not None
+ }
+
+ report = _CliStatusReport()
+
+ for node_status in each_node_status:
+ running_requests = [
+ all_results_by_id[r]
+ for r in node_status.payload.running_request_ids
+ if r in all_results_by_id
+ ]
+ running_requests.sort(key=lambda r: r.created or now_utc())
+
+ report.node_status.append(
+ _CliStatusReport.NodeStatus(
+ node_id=node_status.node_id, running_requests=running_requests
+ )
+ )
+
+ if poll:
+ typer.clear()
+
+ if format == "text":
+ table_headers = ("Node name", "Running requests")
+
+ table_data = [
+ (
+ n.node_id,
+ "\n".join(
+ text_label_format.format(**r.dict())
+ for r in n.running_requests
+ )
+ or "(Empty)",
+ )
+ for n in report.node_status
+ ]
+
+ typer.echo(
+ f"{len(report.node_status)} nodes replied as of {datetime.now()}:"
+ )
+
+ typer.echo(tabulate.tabulate(table_data, headers=table_headers))
+
+ elif format == "json":
+ typer.echo(report.json(indent=json_indent, ensure_ascii=False))
+
+ if not poll:
+ break
+
+ await asyncio.sleep(poll)
+
queues
+
+purge(force=False) (async)
+
+Purge task and control queues
+
+mognet/cli/queues.py
@group.command("purge")
+@run_in_loop
+async def purge(force: bool = typer.Option(False)):
+ """Purge task and control queues"""
+
+ if not force:
+ typer.echo("Must pass --force")
+ raise typer.Exit(1)
+
+ async with state["app_instance"] as app:
+ await app.connect()
+
+ purged_task_counts = await app.purge_task_queues()
+ purged_control_count = await app.purge_control_queue()
+
+ typer.echo("Purged the following queues:")
+
+ for queue_name, count in purged_task_counts.items():
+ typer.echo(f"\t- {queue_name!r}: {count!r}")
+
+ typer.echo(f"Purged {purged_control_count!r} control messages")
+
run
+
+run(include_queues=None, exclude_queues=None)
+
+Run the app
+
+mognet/cli/run.py
@group.callback()
+def run(
+ include_queues: Optional[str] = typer.Option(
+ None,
+ metavar="include-queues",
+ help="Comma-separated list of the ONLY queues to listen on.",
+ ),
+ exclude_queues: Optional[str] = typer.Option(
+ None,
+ metavar="exclude-queues",
+        help="Comma-separated list of queues to NOT listen on.",
+ ),
+):
+ """Run the app"""
+
+ app = state["app_instance"]
+
+ # Allow overriding the queues this app listens on.
+ queues = app.config.task_queues
+
+ if include_queues is not None:
+ queues.include = set(q.strip() for q in include_queues.split(","))
+
+ if exclude_queues is not None:
+ queues.exclude = set(q.strip() for q in exclude_queues.split(","))
+
+ queues.ensure_valid()
+
+ async def start():
+ async with app:
+ await app.start()
+
+ async def stop(_: AbstractEventLoop):
+ _log.info("Going to close app as part of a shut down")
+ await app.close()
+
+ pending_exception_to_raise = SystemExit(0)
+
+ def custom_exception_handler(loop: AbstractEventLoop, context: dict):
+ """See: https://docs.python.org/3/library/asyncio-eventloop.html#error-handling-api"""
+
+ nonlocal pending_exception_to_raise
+
+ exc = context.get("exception")
+
+ if isinstance(exc, GracefulShutdown):
+ _log.debug("Got GracefulShutdown")
+ elif isinstance(exc, BaseException):
+ pending_exception_to_raise = exc
+
+ _log.error(
+ "Unhandled exception; stopping loop: %r %r",
+ context.get("message"),
+ context,
+ exc_info=pending_exception_to_raise,
+ )
+
+ loop.stop()
+
+ loop = asyncio.get_event_loop()
+ loop.set_exception_handler(custom_exception_handler)
+
+ aiorun.run(
+ start(), loop=loop, stop_on_unhandled_errors=False, shutdown_callback=stop
+ )
+
+ if pending_exception_to_raise is not None:
+ raise pending_exception_to_raise
+
+ return 0
+
run_in_loop
+
+run_in_loop(f)
+
+Utility to run a click/typer command function in an event loop
+(because they don't support it out of the box)
+
+mognet/cli/run_in_loop.py
tasks
+
+get(task_id, include_value=False) (async)
+
+Get a task's details
+
+mognet/cli/tasks.py
@group.command("get")
+@run_in_loop
+async def get(
+ task_id: UUID = typer.Argument(
+ ...,
+ metavar="id",
+ help="Task ID to get",
+ ),
+ include_value: bool = typer.Option(
+ False,
+ metavar="include-value",
+ help="If passed, the task's result (or exception) will be printed",
+ ),
+):
+ """Get a task's details"""
+
+ async with state["app_instance"] as app:
+ res = await app.result_backend.get(task_id)
+
+ if res is None:
+ _log.warning("Request %r does not exist", task_id)
+ raise typer.Exit(1)
+
+ table_data = [
+ ("ID", res.id),
+ ("Name", res.name),
+ ("Arguments", res.request_kwargs_repr),
+ ("State", res.state),
+ ("Number of starts", res.number_of_starts),
+ ("Number of stops", res.number_of_stops),
+ ("Unexpected retry count", res.unexpected_retry_count),
+ ("Parent", res.parent_id),
+ ("Created at", res.created),
+ ("Started at", res.started),
+ ("Time in queue", res.queue_time),
+ ("Finished at", res.finished),
+ ("Runtime duration", res.duration),
+ ("Node ID", res.node_id),
+ ("Metadata", await res.get_metadata()),
+ ]
+
+ if include_value:
+ try:
+ value = await res.value.get_raw_value()
+
+ if isinstance(value, _ExceptionInfo):
+ table_data.append(("Error raised", value.traceback))
+ else:
+ table_data.append(("Result value", repr(value)))
+
+ except ResultValueLost:
+ table_data.append(("Result value", "<Lost>"))
+
+ print(tabulate.tabulate(table_data))
+
revoke(task_id, force=False) (async)
+
+Revoke a task
+
+mognet/cli/tasks.py
@group.command("revoke")
+@run_in_loop
+async def revoke(
+ task_id: UUID = typer.Argument(
+ ...,
+ metavar="id",
+ help="Task ID to revoke",
+ ),
+ force: bool = typer.Option(
+ False,
+ metavar="force",
+ help="Attempt revoking anyway if the result is complete. Helps cleaning up cases where subtasks may have been spawned.",
+ ),
+):
+ """Revoke a task"""
+
+ async with state["app_instance"] as app:
+
+ res = await app.result_backend.get(task_id)
+
+ ret_code = 0
+
+ if res is None:
+ _log.warning("Request %r does not exist", task_id)
+ ret_code = 1
+
+ await app.revoke(task_id, force=force)
+
+ raise typer.Exit(ret_code)
+
tree(task_id, format='text', json_indent=2, text_label_format='{name}(id={id!r}, state={state!r})', max_depth=3, max_width=16, poll=None) (async)
+
+Get the tree (descendants) of a task
+
+mognet/cli/tasks.py
@group.command("tree")
+@run_in_loop
+async def tree(
+ task_id: UUID = typer.Argument(
+ ...,
+ metavar="id",
+ help="Task ID to get tree from",
+ ),
+ format: OutputFormat = typer.Option(OutputFormat.TEXT, metavar="format"),
+ json_indent: int = typer.Option(2, metavar="json-indent"),
+ text_label_format: str = typer.Option(
+ "{name}(id={id!r}, state={state!r})",
+ metavar="text-label-format",
+ help="Label format for text format",
+ ),
+ max_depth: int = typer.Option(3, metavar="max-depth"),
+ max_width: int = typer.Option(16, metavar="max-width"),
+ poll: Optional[int] = typer.Option(None, metavar="poll"),
+):
+ """Get the tree (descendants) of a task"""
+
+ async with state["app_instance"] as app:
+ while True:
+ result = await app.result_backend.get(task_id)
+
+ if result is None:
+ raise RuntimeError(f"Result for request id={task_id!r} does not exist")
+
+ _log.info("Building tree for result id=%r", result.id)
+
+ tree = await result.tree(max_depth=max_depth, max_width=max_width)
+
+ if poll:
+ typer.clear()
+
+ if format == "text":
+ t = treelib.Tree()
+
+ def build_tree(n: ResultTree, parent: Optional[ResultTree] = None):
+ t.create_node(
+ tag=text_label_format.format(**n.dict()),
+ identifier=n.result.id,
+ parent=None if parent is None else parent.result.id,
+ )
+
+ for c in n.children:
+ build_tree(c, parent=n)
+
+ build_tree(tree)
+
+ t.show()
+
+ if format == "json":
+ print(tree.json(indent=json_indent, ensure_ascii=False))
+
+ if not poll:
+ break
+
+ await asyncio.sleep(poll)
+
context (special)
+
+context
+
+Context
+
+Context for a request.
+Allows access to the App instance, task state,
+and the request that is part of this task execution.
+
+mognet/context/context.py
class Context:
+ """
+ Context for a request.
+
+ Allows access to the App instance, task state,
+ and the request that is part of this task execution.
+ """
+
+ app: "App"
+
+ state: "State"
+
+ request: "Request"
+
+ _dependencies: Set[UUID]
+
+ def __init__(
+ self,
+ app: "App",
+ request: "Request",
+ state: "State",
+ worker: "Worker",
+ ):
+ self.app = app
+ self.state = state
+ self.request = request
+ self._worker = worker
+
+ self._dependencies = set()
+
+ self.create_request = self.app.create_request
+
+ async def submit(self, request: "Request"):
+ """
+ Submits a new request as part of this one.
+
+ The difference from this method to the one defined in the `App` class
+ is that this one will submit the new request as a child request of
+ the one that's a part of this `Context` instance. This allows
+ the subrequests to be cancelled if the parent is also cancelled.
+ """
+ return await self.app.submit(request, self)
+
+ @overload
+ async def run(self, request: Request[_Return]) -> _Return:
+ """
+ Submits a Request to be run as part of this one (see `submit`), and waits for the result
+ """
+ ...
+
+ @overload
+ async def run(
+ self,
+ request: Callable[Concatenate["Context", _P], _Return],
+ *args: _P.args,
+ **kwargs: _P.kwargs
+ ) -> _Return:
+ """
+ Short-hand method for creating a Request from a function decorated with `@task`,
+ (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
+
+ This overload is for documenting non-async def functions.
+ """
+ ...
+
+ # This overload unwraps the Awaitable object
+ @overload
+ async def run(
+ self,
+ request: Callable[Concatenate["Context", _P], Awaitable[_Return]],
+ *args: _P.args,
+ **kwargs: _P.kwargs
+ ) -> _Return:
+ """
+ Short-hand method for creating a Request from a function decorated with `@task`,
+ (see `create_request`), submitting it (see `submit`) and waiting for the result (see `run(Request)`).
+
+ This overload is for documenting async def functions.
+ """
+ ...
+
+ async def run(self, request, *args, **kwargs):
+ """
+ Submits and runs a new request as part of this one.
+
+ See `submit` for the difference between this and the equivalent
+ `run` method on the `App` class.
+ """
+
+ if not isinstance(request, Request):
+ request = self.create_request(request, *args, **kwargs)
+
+ cancelled = False
+ try:
+ had_dependencies = bool(self._dependencies)
+
+ self._dependencies.add(request.id)
+
+ # If we transition from having no dependencies
+ # to having some, then we should suspend.
+ if not had_dependencies and self._dependencies:
+ await asyncio.shield(self._suspend())
+
+ self._log_dependencies()
+
+ return await self.app.run(request, self)
+        except asyncio.CancelledError:
+            cancelled = True
+            raise  # re-raise so cancellation propagates to the caller, as gather() does
+ finally:
+ self._dependencies.remove(request.id)
+
+ self._log_dependencies()
+
+ if not self._dependencies and not cancelled:
+ await asyncio.shield(self._resume())
+
+ def _log_dependencies(self):
+ _log.debug(
+ "Task %r is waiting on %r dependencies",
+ self.request,
+ len(self._dependencies),
+ )
+
+ async def gather(
+ self, *results_or_ids: Union["Result", UUID], return_exceptions: bool = False
+ ) -> List[Any]:
+ results = []
+ cancelled = False
+ try:
+ for result in results_or_ids:
+ if isinstance(result, UUID):
+ result = await self.app.result_backend.get(result)
+
+ results.append(result)
+
+ # If we transition from having no dependencies
+ # to having some, then we should suspend.
+ had_dependencies = bool(self._dependencies)
+ self._dependencies.update(r.id for r in results)
+
+ self._log_dependencies()
+
+ if not had_dependencies and self._dependencies:
+ await asyncio.shield(self._suspend())
+
+ return await asyncio.gather(*results, return_exceptions=return_exceptions)
+ except asyncio.CancelledError:
+ cancelled = True
+ raise
+ finally:
+ self._dependencies.difference_update(r.id for r in results)
+
+ self._log_dependencies()
+
+ if not self._dependencies and not cancelled:
+ await asyncio.shield(self._resume())
+
+ @overload
+ def get_service(
+ self, func: Type[ClassService[_Return]], *args, **kwargs
+ ) -> _Return:
+ ...
+
+ @overload
+ def get_service(
+ self,
+ func: Callable[Concatenate["Context", _P], _Return],
+ *args: _P.args,
+ **kwargs: _P.kwargs
+ ) -> _Return:
+ ...
+
+ def get_service(self, func, *args, **kwargs):
+ """
+ Get a service to use in the task function.
+ This can be used for dependency injection purposes.
+ """
+
+ if inspect.isclass(func) and issubclass(func, ClassService):
+ if func not in self.app.services:
+ # This cast() is only here to silence Pylance (because it thinks the class is abstract)
+ instance: ClassService = cast(Any, func)(self.app.config)
+ self.app.services[func] = instance.__enter__()
+
+ svc = self.app.services[func]
+ else:
+ svc = self.app.services.setdefault(func, func)
+
+ return svc(self, *args, **kwargs)
+
+ async def _suspend(self):
+ _log.debug("Suspending %r", self.request)
+
+ result = await self.get_result()
+
+ if result.state == ResultState.RUNNING:
+ await result.suspend()
+
+ await self._worker.add_waiting_task(self.request.id)
+
+ async def get_result(self):
+ """
+ Gets the Result associated with this task.
+
+ WARNING: Do not `await` the returned Result instance! You will run
+ into a deadlock (you will be awaiting yourself)
+ """
+ result = await self.app.result_backend.get(self.request.id)
+
+ if result is None:
+ raise ResultLost(self.request.id)
+
+ return result
+
+ def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
+ """
+ NOTE: ONLY TO BE USED WITH SYNC TASKS!
+
+ Utility function that will run the coroutine in the app's event loop
+ in a thread-safe way.
+
+ In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`
+
+ Use as follows:
+
+ ```
+    context.call_threadsafe(context.submit(...))
+ ```
+ """
+ return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()
+
+ async def set_metadata(self, **kwargs: Any):
+ """
+ Update metadata on the Result associated with the current task.
+ """
+
+ result = await self.get_result()
+ return await result.set_metadata(**kwargs)
+
+ async def _resume(self):
+ _log.debug("Resuming %r", self.request)
+
+ result = await self.get_result()
+
+ if result.state == ResultState.SUSPENDED:
+ await result.resume()
+
+ await self._worker.remove_suspended_task(self.request.id)
+
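+A fan-out sketch (hypothetical; `child_task` stands for any function registered with @task, and the import paths are assumed):
+
+from typing import List
+
+@task(name="demo.fan_out")
+async def fan_out(context: Context, items: List[str]):
+    requests = [context.create_request(child_task, item) for item in items]
+    for request in requests:
+        await context.submit(request)
+
+    # gather() accepts Results or UUIDs; passing the request IDs waits on every
+    # child. return_exceptions=True would collect failures instead of raising.
+    return await context.gather(*[request.id for request in requests])
+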
call_threadsafe(self, coro)
+
+NOTE: ONLY TO BE USED WITH SYNC TASKS!
+
+Utility function that will run the coroutine in the app's event loop
+in a thread-safe way.
+In reality this is a wrapper for asyncio.run_coroutine_threadsafe(...).
+Use as follows: context.call_threadsafe(context.submit(...))
+
+mognet/context/context.py
def call_threadsafe(self, coro: Awaitable[_Return]) -> _Return:
+ """
+ NOTE: ONLY TO BE USED WITH SYNC TASKS!
+
+ Utility function that will run the coroutine in the app's event loop
+ in a thread-safe way.
+
+ In reality this is a wrapper for `asyncio.run_coroutine_threadsafe(...)`
+
+ Use as follows:
+
+ ```
+    context.call_threadsafe(context.submit(...))
+ ```
+ """
+ return asyncio.run_coroutine_threadsafe(coro, loop=self.app.loop).result()
+
get_result(self) (async)
+
+Gets the Result associated with this task.
+WARNING: Do not await the returned Result instance! You will run
+into a deadlock (you will be awaiting yourself).
+
+mognet/context/context.py
async def get_result(self):
+ """
+ Gets the Result associated with this task.
+
+ WARNING: Do not `await` the returned Result instance! You will run
+ into a deadlock (you will be awaiting yourself)
+ """
+ result = await self.app.result_backend.get(self.request.id)
+
+ if result is None:
+ raise ResultLost(self.request.id)
+
+ return result
+
get_service(self, func, *args, **kwargs)
+
+Get a service to use in the task function.
+This can be used for dependency injection purposes.
+
+mognet/context/context.py
def get_service(self, func, *args, **kwargs):
+ """
+ Get a service to use in the task function.
+ This can be used for dependency injection purposes.
+ """
+
+ if inspect.isclass(func) and issubclass(func, ClassService):
+ if func not in self.app.services:
+ # This cast() is only here to silence Pylance (because it thinks the class is abstract)
+ instance: ClassService = cast(Any, func)(self.app.config)
+ self.app.services[func] = instance.__enter__()
+
+ svc = self.app.services[func]
+ else:
+ svc = self.app.services.setdefault(func, func)
+
+ return svc(self, *args, **kwargs)
+
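+A dependency-injection sketch (hypothetical; `make_http_client` is a placeholder for your own factory):
+
+def http_client(context: Context):
+    # A function-style service: stored in app.services and invoked with the
+    # Context (plus any extra arguments) on each get_service() call.
+    return make_http_client(context.app.config)
+
+@task(name="demo.fetch")
+async def fetch(context: Context, url: str):
+    client = context.get_service(http_client)
+    return await client.get(url)
+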
run(self, request, *args, **kwargs) (async)
+
+Submits and runs a new request as part of this one.
+See submit for the difference between this and the equivalent
+run method on the App class.
+
+mognet/context/context.py
async def run(self, request, *args, **kwargs):
+ """
+ Submits and runs a new request as part of this one.
+
+ See `submit` for the difference between this and the equivalent
+ `run` method on the `App` class.
+ """
+
+ if not isinstance(request, Request):
+ request = self.create_request(request, *args, **kwargs)
+
+ cancelled = False
+ try:
+ had_dependencies = bool(self._dependencies)
+
+ self._dependencies.add(request.id)
+
+ # If we transition from having no dependencies
+ # to having some, then we should suspend.
+ if not had_dependencies and self._dependencies:
+ await asyncio.shield(self._suspend())
+
+ self._log_dependencies()
+
+ return await self.app.run(request, self)
+    except asyncio.CancelledError:
+        cancelled = True
+        raise  # re-raise so cancellation propagates to the caller, as gather() does
+ finally:
+ self._dependencies.remove(request.id)
+
+ self._log_dependencies()
+
+ if not self._dependencies and not cancelled:
+ await asyncio.shield(self._resume())
+
set_metadata(self, **kwargs) (async)
+
+Update metadata on the Result associated with the current task.
+
+submit(self, request) (async)
+
+Submits a new request as part of this one.
+The difference from this method to the one defined in the App class
+is that this one will submit the new request as a child request of
+the one that's a part of this Context instance. This allows
+the subrequests to be cancelled if the parent is also cancelled.
+
+mognet/context/context.py
async def submit(self, request: "Request"):
+ """
+ Submits a new request as part of this one.
+
+ The difference from this method to the one defined in the `App` class
+ is that this one will submit the new request as a child request of
+ the one that's a part of this `Context` instance. This allows
+ the subrequests to be cancelled if the parent is also cancelled.
+ """
+ return await self.app.submit(request, self)
+
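+A parent/child sketch (hypothetical task names; import paths assumed):
+
+@task(name="demo.parent")
+async def parent(context: Context) -> int:
+    # run() submits `child` as a child request of this one and waits for its
+    # value; revoking the parent therefore also cancels the child.
+    return await context.run(child, 21)
+
+@task(name="demo.child")
+async def child(context: Context, n: int) -> int:
+    return n * 2
+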
decorators (special)
+
+task_decorator
+
+task(*, name=None)
+
+Register a function as a task that can be run.
+
+The name argument is recommended, but not required. It is used as an identifier
+for which task to run when creating Request objects.
+
+If the name is not provided, the function's full name (module + name) is used instead.
+Bear in mind that this means that if you rename the module or the function, things may break
+during rolling upgrades.
+
+mognet/decorators/task_decorator.py
def task(*, name: Optional[str] = None):
+ """
+ Register a function as a task that can be run.
+
+ The name argument is recommended, but not required. It is used as an identifier
+ for which task to run when creating Request objects.
+
+ If the name is not provided, the function's full name (module + name) is used instead.
+ Bear in mind that this means that if you rename the module or the function, things may break
+ during rolling upgrades.
+ """
+
+ def task_decorator(t: _T) -> _T:
+ reg = task_registry.get(None)
+
+ if reg is None:
+ _log.debug("No global task registry set. Creating one")
+
+ reg = TaskRegistry()
+ reg.register_globally()
+
+ reg.add_task_function(cast(Callable, t), name=name)
+
+ return t
+
+ return task_decorator
+
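+A registration sketch (the `from mognet import task` import path is an assumption):
+
+from mognet import task
+
+@task(name="demo.add")
+async def add(context, a: int, b: int) -> int:
+    # An explicit name keeps existing Requests valid even if this module or
+    # function is later renamed.
+    return a + b
+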
exceptions (special)
+
+base_exceptions
+
+ConnectionError (MognetError)
+
+CouldNotSubmit (MognetError)
+
+ImproperlyConfigured (MognetError)
+
+MognetError (Exception)
+
+NotConnected (ConnectionError)
+
+result_exceptions
+
+ResultLost (ResultError)
+
+Raised when the result itself was lost
+(potentially due to key eviction)
+
+mognet/exceptions/result_exceptions.py
class ResultLost(ResultError):
+ """
+ Raised when the result itself was lost
+ (potentially due to key eviction)
+ """
+
+ def __init__(self, result_id: UUID) -> None:
+ super().__init__(result_id)
+ self.result_id = result_id
+
+ def __str__(self) -> str:
+ return f"Result id={self.result_id!r} lost"
+
+ResultValueLost (ResultError)
+
+Raised when the value for a result was lost
+(potentially due to key eviction)
+
+mognet/exceptions/result_exceptions.py
class ResultValueLost(ResultError):
+ """
+ Raised when the value for a result was lost
+ (potentially due to key eviction)
+ """
+
+ def __init__(self, result_id: UUID) -> None:
+ super().__init__(result_id)
+ self.result_id = result_id
+
+ def __str__(self) -> str:
+ return f"Value for result id={self.result_id!r} lost"
+
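+A handling sketch (hypothetical helper; `backend` is any configured result backend):
+
+async def value_or_none(backend, result_id):
+    try:
+        return await backend.get_value(result_id)
+    except ResultValueLost:
+        # The key was evicted or expired; the value cannot be recovered.
+        return None
+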
+Revoked (ResultFailed)
+
+Raised when a task is revoked, either by timing out or by manual revocation.
+
+mognet/exceptions/result_exceptions.py
+
+task_exceptions
+
+InvalidErrorInfo (BaseModel) (pydantic-model)
+
+InvalidTaskArguments (Exception)
+
+Raised when the arguments to a task could not be validated.
+
+mognet/exceptions/task_exceptions.py
class InvalidTaskArguments(Exception):
+ """
+ Raised when the arguments to a task could not be validated.
+ """
+
+ def __init__(self, errors: List[InvalidErrorInfo]) -> None:
+ super().__init__(errors)
+ self.errors = errors
+
+ @classmethod
+ def from_validation_error(cls, validation_error: ValidationError):
+ return cls([InvalidErrorInfo.parse_obj(e) for e in validation_error.errors()])
+
+Pause (Exception)
+
+Tasks may raise this when they want to stop
+execution and have their message return to the Task Broker.
+Once the message is retrieved again, task execution will resume.
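+
+A hedged sketch of a task that suspends itself until an external condition holds; the task name and `upload_is_ready` are hypothetical:
+
+from mognet import task
+from mognet.exceptions.task_exceptions import Pause
+
+@task(name="example.wait_for_upload")
+async def wait_for_upload(context, upload_id: str) -> str:
+    if not await upload_is_ready(upload_id):  # hypothetical readiness check
+        # The message returns to the broker and the result is suspended;
+        # execution resumes once the message is redelivered.
+        raise Pause()
+    return upload_id
+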
+too_many_retries
+
+TooManyRetries (MognetError)
+
+Raised when a task is retried too many times due to unforeseen errors (e.g., SIGKILL).
+The number of retries for any particular task can be configured through the App's
+configuration, in `max_retries`.
+
+mognet/exceptions/too_many_retries.py
class TooManyRetries(MognetError):
+ """
+ Raised when a task is retried too many times due to unforeseen errors (e.g., SIGKILL).
+
+ The number of retries for any particular task can be configured through the App's
+ configuration, in `max_retries`.
+ """
+
+ def __init__(
+ self,
+ request_id: UUID,
+ actual_retries: int,
+ max_retries: int,
+ ) -> None:
+ super().__init__(request_id, actual_retries, max_retries)
+
+ self.request_id = request_id
+ self.max_retries = max_retries
+ self.actual_retries = actual_retries
+
+ def __str__(self) -> str:
+ return f"Task id={self.request_id!r} has been retried {self.actual_retries!r} times, which is more than the limit of {self.max_retries!r}"
+
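+A hedged sketch of how this surfaces to a consumer: the exception is stored on the `Result`, so it is raised when the value is fetched (`result` is assumed to be a `Result` whose task exceeded `max_retries`):
+
+from mognet.exceptions.too_many_retries import TooManyRetries
+
+try:
+    value = await result.get()
+except TooManyRetries as exc:
+    print(exc.request_id, exc.actual_retries, exc.max_retries)
+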
middleware
+
+ special
+
+middleware
+
+Middleware (Protocol)
+
+Defines middleware that can hook into different parts of a Mognet App's lifecycle.
+
+mognet/middleware/middleware.py
class Middleware(Protocol):
+ """
+ Defines middleware that can hook into different parts of a Mognet App's lifecycle.
+ """
+
+ async def on_app_starting(self, app: "App") -> None:
+ """
+ Called when the app is starting, but before it starts connecting to the backends.
+
+ For example, you can use this for some early initialization of singleton-type objects in your app.
+ """
+
+ async def on_app_started(self, app: "App") -> None:
+ """
+ Called when the app has started.
+
+ For example, you can use this for some early initialization of singleton-type objects in your app.
+ """
+
+ async def on_app_stopping(self, app: "App") -> None:
+ """
+ Called when the app is preparing to stop, but before it starts disconnecting.
+
+ For example, you can use this for cleaning up objects that were previously set up.
+ """
+
+ async def on_app_stopped(self, app: "App") -> None:
+ """
+ Called when the app has stopped.
+
+ For example, you can use this for cleaning up objects that were previously set up.
+ """
+
+ async def on_task_starting(self, context: "Context"):
+ """
+ Called when a task is starting.
+
+ You can use this, for example, to track a task in a database.
+ """
+
+ async def on_task_completed(
+ self, result: "Result", context: Optional["Context"] = None
+ ):
+ """
+ Called when a task has completed its execution.
+
+ You can use this, for example, to track a task in a database.
+ """
+
+ async def on_request_submitting(
+ self, request: "Request", context: Optional["Context"] = None
+ ):
+ """
+ Called when a Request object is going to be submitted to the Broker.
+
+ You can use this, for example, either to track the task in a database, or to modify
+ the Request object (e.g., to modify arguments or set metadata).
+ """
+
+ async def on_running_task_count_changed(self, running_task_count: int):
+ """
+ Called when the Worker's task count changes.
+
+ This can be used to determine when the Worker has nothing to do.
+ """
+
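+Since the hooks above have no-op default bodies, a subclass only needs to override the ones it cares about. A minimal, hypothetical sketch:
+
+import logging
+
+from mognet.middleware.middleware import Middleware
+
+class TaskLoggingMiddleware(Middleware):
+    """Hypothetical middleware that logs task lifecycle events."""
+
+    async def on_task_starting(self, context):
+        logging.info("Task %s starting", context.request.id)
+
+    async def on_task_completed(self, result, context=None):
+        logging.info("Task %s finished in state %s", result.id, result.state)
+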
on_app_started(self, app)
+
+
+ async
+
+
+Called when the app has started.
+For example, you can use this for some early initialization of singleton-type objects in your app.
+
+on_app_starting(self, app)
+
+
+ async
+
+
+Called when the app is starting, but before it starts connecting to the backends.
+For example, you can use this for some early initialization of singleton-type objects in your app.
+
+mognet/middleware/middleware.py
on_app_stopped(self, app)
+
+
+ async
+
+
+Called when the app has stopped.
+For example, you can use this for cleaning up objects that were previously set up.
+
+on_app_stopping(self, app)
+
+
+ async
+
+
+Called when the app is preparing to stop, but before it starts disconnecting.
+For example, you can use this for cleaning up objects that were previously set up.
+
+on_request_submitting(self, request, context=None)
+
+
+ async
+
+
+Called when a Request object is going to be submitted to the Broker.
+You can use this, for example, either to track the task in a database, or to modify
+the Request object (e.g., to modify arguments or set metadata).
+
+mognet/middleware/middleware.py
async def on_request_submitting(
+ self, request: "Request", context: Optional["Context"] = None
+):
+ """
+ Called when a Request object is going to be submitted to the Broker.
+
+ You can use this, for example, either to track the task in a database, or to modify
+ the Request object (e.g., to modify arguments or set metadata).
+ """
+
on_running_task_count_changed(self, running_task_count)
+
+
+ async
+
+
+Called when the Worker's task count changes.
+This can be used to determine when the Worker has nothing to do.
+
+on_task_completed(self, result, context=None)
+
+
+ async
+
+
+Called when a task has completed its execution.
+You can use this, for example, to track a task in a database.
+
+on_task_starting(self, context)
+
+
+ async
+
+
+Called when a task is starting.
+You can use this, for example, to track a task in a database.
+
+model
+
+ special
+
+result
+
+Result (BaseModel)
+
+ pydantic-model
+
+Represents the result of executing a `Request`.
+It contains, along with the return value (or raised exception),
+information on the resulting state, how many times it started,
+and timing information.
+
+mognet/model/result.py
class Result(BaseModel):
+ """
+ Represents the result of executing a [`Request`][mognet.Request].
+
+ It contains, along with the return value (or raised exception),
+ information on the resulting state, how many times it started,
+ and timing information.
+ """
+
+ id: UUID
+ name: Optional[str] = None
+ state: ResultState = ResultState.PENDING
+
+ number_of_starts: int = 0
+ number_of_stops: int = 0
+
+ parent_id: Optional[UUID] = None
+
+ created: Optional[datetime]
+ started: Optional[datetime]
+ finished: Optional[datetime]
+
+ node_id: Optional[str]
+
+ request_kwargs_repr: Optional[str]
+
+ _backend: "BaseResultBackend" = PrivateAttr()
+ _children: Optional[ResultChildren] = PrivateAttr()
+ _value: Optional[ResultValue] = PrivateAttr()
+
+ def __init__(self, backend: "BaseResultBackend", **data) -> None:
+ super().__init__(**data)
+ self._backend = backend
+ self._children = None
+ self._value = None
+
+ @property
+ def children(self) -> ResultChildren:
+ """Get an iterator on the children of this Result. Non-recursive."""
+
+ if self._children is None:
+ self._children = ResultChildren(self, self._backend)
+
+ return self._children
+
+ @property
+ def value(self) -> ResultValue:
+ """Get information about the value of this Result"""
+ if self._value is None:
+ self._value = ResultValue(self, self._backend)
+
+ return self._value
+
+ @property
+ def duration(self) -> Optional[timedelta]:
+ """
+ Returns the time it took to complete this result.
+
+ Returns None if the result did not start or finish.
+ """
+ if not self.started or not self.finished:
+ return None
+
+ return self.finished - self.started
+
+ @property
+ def queue_time(self) -> Optional[timedelta]:
+ """
+ Returns the time it took to start the task associated to this result.
+
+ Returns None if the task did not start.
+ """
+ if not self.created or not self.started:
+ return None
+
+ return self.started - self.created
+
+ @property
+ def done(self):
+ """
+ True if the result is in a terminal state (e.g., SUCCESS, FAILURE).
+ See `READY_STATES`.
+ """
+ return self.state in READY_STATES
+
+ @property
+ def successful(self):
+ """True if the result was successful."""
+ return self.state in SUCCESS_STATES
+
+ @property
+ def failed(self):
+ """True if the result failed or was revoked."""
+ return self.state in ERROR_STATES
+
+ @property
+ def revoked(self):
+ """True if the result was revoked."""
+ return self.state == ResultState.REVOKED
+
+ @property
+ def unexpected_retry_count(self) -> int:
+ """
+ Return the number of times the task associated with this result was retried
+ as a result of an unexpected error, such as a SIGKILL.
+ """
+ return max(0, self.number_of_starts - self.number_of_stops)
+
+ async def wait(self, *, timeout: Optional[float] = None, poll: float = 0.1) -> None:
+ """Wait for the task associated with this result to finish."""
+ updated_result = await self._backend.wait(self.id, timeout=timeout, poll=poll)
+
+ await self._refresh(updated_result)
+
+ async def revoke(self) -> "Result":
+ """
+ Revoke this Result.
+
+ This shouldn't be called directly; use the method on the App class instead,
+ as that will also revoke the children, recursively.
+ """
+ self.state = ResultState.REVOKED
+ self.number_of_stops += 1
+ self.finished = now_utc()
+ await self._backend.set(self.id, self)
+ return self
+
+ async def get(self) -> Any:
+ """
+ Gets the value of this `Result` instance.
+
+ Raises `ResultNotReady` if it's not ready yet.
+ Raises any stored exception if the result failed
+
+ Returns the stored value otherwise.
+
+ Use `value.get_raw_value()` if you want access to the raw value.
+ Call `wait` to wait for the value to be available.
+
+ Optionally, `await` the result instance.
+ """
+
+ if not self.done:
+ raise ResultNotReady()
+
+ value = await self.value.get_raw_value()
+
+ if self.state == ResultState.REVOKED:
+ raise Revoked(self)
+
+ if self.failed:
+ if value is None:
+ value = ResultFailed(self)
+
+ # Re-hydrate exceptions.
+ if isinstance(value, _ExceptionInfo):
+ raise value.exception
+
+ if not isinstance(value, BaseException):
+ value = Exception(value)
+
+ raise value
+
+ return value
+
+ async def set_result(
+ self,
+ value: Any,
+ state: ResultState = ResultState.SUCCESS,
+ ) -> "Result":
+ """
+ Set this Result to a success state, and store the value
+ which will be returned when one `get()`s this Result's value.
+ """
+ await self.value.set_raw_value(value)
+
+ self.finished = now_utc()
+
+ self.state = state
+ self.number_of_stops += 1
+
+ await self._update()
+
+ return self
+
+ async def set_error(
+ self,
+ exc: BaseException,
+ state: ResultState = ResultState.FAILURE,
+ ) -> "Result":
+ """
+ Set this Result to an error state, and store the exception
+ which will be raised if one attempts to `get()` this Result's
+ value.
+ """
+
+ _log.debug("Setting result id=%r to %r", self.id, state)
+
+ await self.value.set_raw_value(exc)
+
+ self.finished = now_utc()
+
+ self.state = state
+ self.number_of_stops += 1
+
+ await self._update()
+
+ return self
+
+ async def start(self, *, node_id: Optional[str] = None) -> "Result":
+ """
+ Sets this `Result` as RUNNING, and logs the event.
+ """
+ self.started = now_utc()
+ self.node_id = node_id
+
+ self.state = ResultState.RUNNING
+ self.number_of_starts += 1
+
+ await self._update()
+
+ return self
+
+ async def resume(self, *, node_id: Optional[str] = None) -> "Result":
+ if node_id is not None:
+ self.node_id = node_id
+
+ self.state = ResultState.RUNNING
+ self.number_of_starts += 1
+
+ await self._update()
+
+ return self
+
+ async def suspend(self) -> "Result":
+ """
+ Sets this `Result` as SUSPENDED, and logs the event.
+ """
+
+ self.state = ResultState.SUSPENDED
+ self.number_of_stops += 1
+
+ await self._update()
+
+ return self
+
+ async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
+ """
+ Gets the tree of this result.
+
+ :param max_depth: The maximum depth of the tree that's to be generated.
+ This filters out results whose recursion levels are greater than it.
+ """
+ from .result_tree import ResultTree
+
+ async def get_tree(result: Result, depth=1):
+ _log.debug(
+ "Getting tree of result id=%r, depth=%r max_depth=%r",
+ result.id,
+ depth,
+ max_depth,
+ )
+
+ node = ResultTree(result=result, children=[])
+
+ if depth >= max_depth and (await result.children.count()):
+ _log.warning(
+ "Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
+ result.id,
+ depth,
+ max_depth,
+ )
+ return node
+
+ children_count = await result.children.count()
+ if children_count > max_width:
+ _log.warning(
+ "Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
+ result.id,
+ children_count,
+ max_width,
+ )
+
+ async for child in result.children.iter_instances(count=max_width):
+ node.children.append(await get_tree(child, depth=depth + 1))
+
+ node.children.sort(key=lambda r: r.result.created or now_utc())
+
+ return node
+
+ return await get_tree(self, depth=1)
+
+ async def get_metadata(self) -> Dict[str, Any]:
+ """Get the metadata associated with this Result."""
+ return await self._backend.get_metadata(self.id)
+
+ async def set_metadata(self, **kwargs: Any) -> None:
+ """Set metadata on this Result."""
+ await self._backend.set_metadata(self.id, **kwargs)
+
+ async def _refresh(self, updated_result: Optional["Result"] = None):
+ updated_result = updated_result or await self._backend.get(self.id)
+
+ if updated_result is None:
+ raise RuntimeError("Result no longer present")
+
+ for k, v in updated_result.__dict__.items():
+ if k == "id":
+ continue
+
+ setattr(self, k, v)
+
+ async def _update(self):
+ await self._backend.set(self.id, self)
+
+ def __repr__(self):
+ v = f"Result[{self.name or 'unknown'}, id={self.id!r}, state={self.state!r}]"
+
+ if self.request_kwargs_repr is not None:
+ v += f"({self.request_kwargs_repr})"
+
+ return v
+
+ # Implemented for asyncio's `await` functionality.
+ def __hash__(self) -> int:
+ return hash(f"Result_{self.id}")
+
+ def __await__(self):
+ yield from self.wait().__await__()
+ value = yield from self.get().__await__()
+ return value
+
+ async def delete(self, include_children: bool = True):
+ """
+ Delete this Result from the backend.
+
+ By default, this will delete children too.
+ """
+ await self._backend.delete(self.id, include_children=include_children)
+
+ async def set_ttl(self, ttl: timedelta, include_children: bool = True):
+ """
+ Set TTL on this Result.
+
+ By default, this will set it on the children too.
+ """
+ await self._backend.set_ttl(self.id, ttl, include_children=include_children)
+
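+Putting the pieces together, a hedged consumer-side sketch; the task name is hypothetical, and it is assumed that `App.submit` returns the corresponding `Result`, as `Context.submit` suggests by forwarding its return value:
+
+result = await app.submit(Request(name="example.add", kwargs={"a": 1, "b": 2}))
+
+# Awaiting a Result waits for completion and then fetches its value
+# (or raises the stored exception), per __await__ above.
+value = await result
+
+# Timing properties are populated once the task has finished.
+if result.duration is not None:
+    print(f"Finished in {result.duration.total_seconds():.2f}s")
+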
children: ResultChildren
+
+
+ property
+ readonly
+
+
+Get an iterator on the children of this Result. Non-recursive.
+done
+
+
+ property
+ readonly
+
+
+True if the result is in a terminal state (e.g., SUCCESS, FAILURE).
+See `READY_STATES`.
+
+duration: Optional[datetime.timedelta]
+
+
+ property
+ readonly
+
+
+Returns the time it took to complete this result.
+Returns None if the result did not start or finish.
+failed
+
+
+ property
+ readonly
+
+
+True if the result failed or was revoked.
+queue_time: Optional[datetime.timedelta]
+
+
+ property
+ readonly
+
+
+Returns the time it took to start the task associated to this result.
+Returns None if the task did not start.
+revoked
+
+
+ property
+ readonly
+
+
+True if the result was revoked.
+successful
+
+
+ property
+ readonly
+
+
+True if the result was successful.
+unexpected_retry_count: int
+
+
+ property
+ readonly
+
+
+Return the number of times the task associated with this result was retried
+as a result of an unexpected error, such as a SIGKILL.
+value: ResultValue
+
+
+ property
+ readonly
+
+
+Get information about the value of this Result
+__hash__(self)
+
+
+ special
+
+
+__repr__(self)
+
+
+ special
+
+
+delete(self, include_children=True)
+
+
+ async
+
+
+Delete this Result from the backend.
+By default, this will delete children too.
+
+get(self)
+
+
+ async
+
+
+Gets the value of this `Result` instance.
+Raises `ResultNotReady` if it's not ready yet.
+Raises any stored exception if the result failed.
+Returns the stored value otherwise.
+Use `value.get_raw_value()` if you want access to the raw value.
+Call `wait` to wait for the value to be available.
+Optionally, `await` the result instance.
mognet/model/result.py
async def get(self) -> Any:
+ """
+ Gets the value of this `Result` instance.
+
+ Raises `ResultNotReady` if it's not ready yet.
+ Raises any stored exception if the result failed
+
+ Returns the stored value otherwise.
+
+ Use `value.get_raw_value()` if you want access to the raw value.
+ Call `wait` to wait for the value to be available.
+
+ Optionally, `await` the result instance.
+ """
+
+ if not self.done:
+ raise ResultNotReady()
+
+ value = await self.value.get_raw_value()
+
+ if self.state == ResultState.REVOKED:
+ raise Revoked(self)
+
+ if self.failed:
+ if value is None:
+ value = ResultFailed(self)
+
+ # Re-hydrate exceptions.
+ if isinstance(value, _ExceptionInfo):
+ raise value.exception
+
+ if not isinstance(value, BaseException):
+ value = Exception(value)
+
+ raise value
+
+ return value
+
get_metadata(self)
+
+ async
+
+Get the metadata associated with this Result.
+
+revoke(self)
+
+
+ async
+
+
+Revoke this Result.
+This shouldn't be called directly; use the method on the App class instead,
+as that will also revoke the children, recursively.
+
+mognet/model/result.py
async def revoke(self) -> "Result":
+ """
+ Revoke this Result.
+
+ This shouldn't be called directly; use the method on the App class instead,
+ as that will also revoke the children, recursively.
+ """
+ self.state = ResultState.REVOKED
+ self.number_of_stops += 1
+ self.finished = now_utc()
+ await self._backend.set(self.id, self)
+ return self
+
set_error(self, exc, state='FAILURE')
+
+
+ async
+
+
+Set this Result to an error state, and store the exception
+which will be raised if one attempts to `get()` this Result's
+value.
mognet/model/result.py
async def set_error(
+ self,
+ exc: BaseException,
+ state: ResultState = ResultState.FAILURE,
+) -> "Result":
+ """
+ Set this Result to an error state, and store the exception
+ which will be raised if one attempts to `get()` this Result's
+ value.
+ """
+
+ _log.debug("Setting result id=%r to %r", self.id, state)
+
+ await self.value.set_raw_value(exc)
+
+ self.finished = now_utc()
+
+ self.state = state
+ self.number_of_stops += 1
+
+ await self._update()
+
+ return self
+
set_metadata(self, **kwargs)
+
+ async
+
+Set metadata on this Result.
+
+set_result(self, value, state='SUCCESS')
+
+
+ async
+
+
+Set this Result to a success state, and store the value
+which will be returned when one `get()`s this Result's value.
mognet/model/result.py
async def set_result(
+ self,
+ value: Any,
+ state: ResultState = ResultState.SUCCESS,
+) -> "Result":
+ """
+ Set this Result to a success state, and store the value
+ which will be returned when one `get()`s this Result's value.
+ """
+ await self.value.set_raw_value(value)
+
+ self.finished = now_utc()
+
+ self.state = state
+ self.number_of_stops += 1
+
+ await self._update()
+
+ return self
+
set_ttl(self, ttl, include_children=True)
+
+
+ async
+
+
+Set TTL on this Result.
+By default, this will set it on the children too.
+
+start(self, *, node_id=None)
+
+
+ async
+
+
+Sets this `Result` as RUNNING, and logs the event.
+
+mognet/model/result.py
suspend(self)
+
+
+ async
+
+
+Sets this `Result` as SUSPENDED, and logs the event.
+
tree(self, max_depth=3, max_width=500)
+
+
+ async
+
+
+Gets the tree of this result.
+:param max_depth: The maximum depth of the tree that's to be generated.
+ This filters out results whose recursion levels are greater than it.
+
+mognet/model/result.py
async def tree(self, max_depth: int = 3, max_width: int = 500) -> "ResultTree":
+ """
+ Gets the tree of this result.
+
+ :param max_depth: The maximum depth of the tree that's to be generated.
+ This filters out results whose recursion levels are greater than it.
+ """
+ from .result_tree import ResultTree
+
+ async def get_tree(result: Result, depth=1):
+ _log.debug(
+ "Getting tree of result id=%r, depth=%r max_depth=%r",
+ result.id,
+ depth,
+ max_depth,
+ )
+
+ node = ResultTree(result=result, children=[])
+
+ if depth >= max_depth and (await result.children.count()):
+ _log.warning(
+ "Result id=%r has %r or more levels of children, which is more than the limit of %r. Results will be truncated",
+ result.id,
+ depth,
+ max_depth,
+ )
+ return node
+
+ children_count = await result.children.count()
+ if children_count > max_width:
+ _log.warning(
+ "Result id=%r has %r children, which is more than the limit of %r. Results will be truncated",
+ result.id,
+ children_count,
+ max_width,
+ )
+
+ async for child in result.children.iter_instances(count=max_width):
+ node.children.append(await get_tree(child, depth=depth + 1))
+
+ node.children.sort(key=lambda r: r.result.created or now_utc())
+
+ return node
+
+ return await get_tree(self, depth=1)
+
wait(self, *, timeout=None, poll=0.1)
+
+
+ async
+
+
+Wait for the task associated with this result to finish.
+
+mognet/model/result.py
+
+ResultChildren
+
+
+
+The children of a Result.
+
+mognet/model/result.py
class ResultChildren:
+ """The children of a Result."""
+
+ def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
+ self._result = result
+ self._backend = backend
+
+ async def count(self) -> int:
+ """The number of children."""
+ return await self._backend.get_children_count(self._result.id)
+
+ def iter_ids(self, *, count: Optional[int] = None) -> AsyncGenerator[UUID, None]:
+ """Iterate the IDs of the children, optionally limited to a set count."""
+ return self._backend.iterate_children_ids(self._result.id, count=count)
+
+ def iter_instances(
+ self, *, count: Optional[int] = None
+ ) -> AsyncGenerator["Result", None]:
+ """Iterate the instances of the children, optionally limited to a set count."""
+ return self._backend.iterate_children(self._result.id, count=count)
+
+ async def add(self, *children_ids: UUID):
+ """For internal use."""
+ await self._backend.add_children(self._result.id, *children_ids)
+
add(self, *children_ids)
+
+ async
+
+For internal use.
+
+count(self)
+
+ async
+
+The number of children.
+
+iter_ids(self, *, count=None)
+
+
+Iterate the IDs of the children, optionally limited to a set count.
+
+iter_instances(self, *, count=None)
+
+
+Iterate the instances of the children, optionally limited to a set count.
+
+ResultValue
+
+
+
+Represents information about the value of a Result.
+
+mognet/model/result.py
class ResultValue:
+ """
+ Represents information about the value of a Result.
+ """
+
+ def __init__(self, result: "Result", backend: "BaseResultBackend") -> None:
+ self._result = result
+ self._backend = backend
+
+ self._value_holder: Optional[ResultValueHolder] = None
+
+ async def get_value_holder(self) -> ResultValueHolder:
+ if self._value_holder is None:
+ self._value_holder = await self._backend.get_value(self._result.id)
+
+ return self._value_holder
+
+ async def get_raw_value(self) -> Any:
+ """Get the value. In case this is an exception, it won't be raised."""
+ holder = await self.get_value_holder()
+ return holder.deserialize()
+
+ async def set_raw_value(self, value: Any):
+ if isinstance(value, BaseException):
+ value = _ExceptionInfo.from_exception(value)
+
+ holder = ResultValueHolder(raw_value=value, value_type=_serialize_name(value))
+ await self._backend.set_value(self._result.id, holder)
+
get_raw_value(self)
+
+
+ async
+
+
+Get the value. In case this is an exception, it won't be raised.
+
+ResultValueHolder (BaseModel)
+
+
+
+
+ pydantic-model
+
+
+Holds information about the type of the Result's value, and the raw value itself.
+Use `deserialize()` to parse the value according to the type.
+
mognet/model/result.py
class ResultValueHolder(BaseModel):
+ """
+ Holds information about the type of the Result's value, and the raw value itself.
+
+ Use `deserialize()` to parse the value according to the type.
+ """
+
+ value_type: str
+ raw_value: Any
+
+ def deserialize(self) -> Any:
+ if self.raw_value is None:
+ return None
+
+ if self.value_type is not None:
+ cls = _get_attr(self.value_type)
+
+ value = parse_obj_as(cls, self.raw_value)
+ else:
+ value = self.raw_value
+
+ return value
+
+ @classmethod
+ def not_ready(cls):
+ """
+ Creates a value holder which is not ready yet.
+ """
+ value = _ExceptionInfo.from_exception(ResultNotReady())
+ return cls(value_type=_serialize_name(value), raw_value=value)
+
not_ready()
+
+
+ classmethod
+
+
+Creates a value holder which is not ready yet.
+
+result_state
+
+ResultState (str, Enum)
+
+States that a task execution, and its result, can be in.
+
+mognet/model/result_state.py
class ResultState(str, Enum):
+ """
+ States that a task execution, and its result, can be in.
+ """
+
+ # The task associated with this result has not yet started.
+ PENDING = "PENDING"
+
+ # The task associated with this result is currently running.
+ RUNNING = "RUNNING"
+
+ # The task associated with this result was suspended, either because
+ # it yielded subtasks, or because the worker it was on was shut down
+ # gracefully.
+ SUSPENDED = "SUSPENDED"
+
+ # The task associated with this result finished successfully.
+ SUCCESS = "SUCCESS"
+
+ # The task associated with this result failed.
+ FAILURE = "FAILURE"
+
+ # The task associated with this result was aborted.
+ REVOKED = "REVOKED"
+
+ # Invalid task
+ INVALID = "INVALID"
+
+ def __repr__(self):
+ return f"{self.name!r}"
+
result_tree
+
+ResultTree (BaseModel)
+
+ pydantic-model
+
+mognet/model/result_tree.py
class ResultTree(BaseModel):
+ result: "Result"
+ children: List["ResultTree"]
+
+ def __str__(self) -> str:
+ return f"{self.result.name}(id={self.result.id!r}, state={self.result.state!r}, node_id={self.result.node_id!r})"
+
+ def dict(self, **kwargs):
+ return {
+ "id": self.result.id,
+ "name": self.result.name,
+ "state": self.result.state,
+ "created": self.result.created,
+ "started": self.result.started,
+ "finished": self.result.finished,
+ "node_id": self.result.node_id,
+ "number_of_starts": self.result.number_of_starts,
+ "number_of_stops": self.result.number_of_stops,
+ "retry_count": self.result.unexpected_retry_count,
+ "children": [c.dict() for c in self.children],
+ }
+
__str__(self)
+
+
+ special
+
+
+dict(self, **kwargs)
+
+
+Generate a dictionary representation of the model, optionally specifying which fields to include or exclude.
+
+mognet/model/result_tree.py
def dict(self, **kwargs):
+ return {
+ "id": self.result.id,
+ "name": self.result.name,
+ "state": self.result.state,
+ "created": self.result.created,
+ "started": self.result.started,
+ "finished": self.result.finished,
+ "node_id": self.result.node_id,
+ "number_of_starts": self.result.number_of_starts,
+ "number_of_stops": self.result.number_of_stops,
+ "retry_count": self.result.unexpected_retry_count,
+ "children": [c.dict() for c in self.children],
+ }
+
primitives
+
+ special
+
+request
+
+Request (GenericModel, Generic)
+
+ pydantic-model
+
+mognet/primitives/request.py
class Request(GenericModel, Generic[TReturn]):
+ id: UUID = Field(default_factory=uuid4)
+ name: str
+
+ args: tuple = ()
+ kwargs: Dict[str, Any] = Field(default_factory=dict)
+
+ stack: List[UUID] = Field(default_factory=list)
+
+ # Metadata that's going to be put in the Result associated
+ # with this Request.
+ metadata: Dict[str, Any] = Field(default_factory=dict)
+
+ # Deadline to run this request.
+ # If it's a datetime, the deadline will be computed based on the difference
+ # to `datetime.now(tz=timezone.utc)`. If the deadline is already passed, the task
+ # is discarded and marked as `REVOKED`.
+ # If it's a timedelta, the task's coroutine will be given that long to run, after which
+ # it will be cancelled and marked as `REVOKED`.
+ # Note that, like with manual revoking, there is no guarantee that the timed out task will
+ # actually stop running.
+ deadline: Optional[Union[timedelta, datetime]] = None
+
+ # Overrides the queue the message will be sent to.
+ queue_name: Optional[str] = None
+
+ # Allow setting a kwargs representation for debugging purposes.
+ # This is stored on the corresponding Result.
+ # If not set, it's set when the request is submitted.
+ # Note that if there are arguments which contain sensitive data, this will leak their values,
+ # so you are responsible for ensuring such values are censored.
+ kwargs_repr: Optional[str] = None
+
+ # Task priority. The higher the value, the higher the priority.
+ priority: Priority = 5
+
+ def __repr__(self):
+ msg = f"{self.name}[id={self.id!r}]"
+
+ if self.kwargs_repr is not None:
+ msg += f"({self.kwargs_repr})"
+
+ return msg
+
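+A hedged sketch of constructing a request that uses the optional fields above (the task name is hypothetical):
+
+from datetime import timedelta
+
+req = Request(
+    name="example.process",
+    kwargs={"document_id": "doc-123"},
+    # A timedelta deadline bounds the coroutine's runtime; when exceeded,
+    # it is cancelled and the result is marked as REVOKED.
+    deadline=timedelta(minutes=5),
+    # Higher values mean higher priority (the default is 5).
+    priority=7,
+)
+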
service
+
+ special
+
+class_service
+
+ClassService (Generic)
+
+Base class for object-based services retrieved
+through Context#get_service().
+
+To get instances of a class-based service using
+Context#get_service(), pass the class itself,
+and an instance of the class will be returned.
+
+Note that the instances are singletons.
+
+The instances, when created, get access to the app's configuration.
+
+mognet/service/class_service.py
class ClassService(Generic[_TReturn], metaclass=ABCMeta):
+ """
+ Base class for object-based services retrieved
+ through Context#get_service()
+
+ To get instances of a class-based service using
+ Context#get_service(), pass the class itself,
+ and an instance of the class will be returned.
+
+ Note that the instances are *singletons*.
+
+ The instances, when created, get access to the app's configuration.
+ """
+
+ def __init__(self, config: "AppConfig") -> None:
+ self.config = config
+
+ @abstractmethod
+ def __call__(self, context: "Context", *args, **kwds) -> _TReturn:
+ raise NotImplementedError
+
+ def __enter__(self):
+ return self
+
+ def __exit__(self, exc_type, exc_val, exc_tb):
+ self.close()
+
+ def close(self):
+ pass
+
+ async def wait_closed(self):
+ pass
+
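+A hedged sketch of a class-based service; the names are hypothetical, and a real implementation would read its settings from `self.config`:
+
+import sqlite3
+
+from mognet.service.class_service import ClassService
+
+class DatabaseService(ClassService[sqlite3.Connection]):
+    """Hypothetical singleton service that hands out a SQLite connection."""
+
+    def __init__(self, config) -> None:
+        super().__init__(config)
+        self._conn = sqlite3.connect("app.db")
+
+    def __call__(self, context, *args, **kwds) -> sqlite3.Connection:
+        return self._conn
+
+    def close(self):
+        self._conn.close()
+
+Inside a task, `context.get_service(DatabaseService)` would then return the shared instance.
+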
state
+
+ special
+
+state
+
+State
+
+Represents state that can persist across task restarts.
+Has facilities for getting, setting, and removing values.
+The task's state is deleted when the task finishes.
+
+mognet/state/state.py
class State:
+ """
+ Represents state that can persist across task restarts.
+
+ Has facilities for getting, setting, and removing values.
+
+ The task's state is deleted when the task finishes.
+ """
+
+ request_id: UUID
+
+ def __init__(self, app: "App", request_id: UUID) -> None:
+ self._app = app
+ self.request_id = request_id
+
+ @property
+ def _backend(self):
+ return self._app.state_backend
+
+ async def get(self, key: str, default: Any = None) -> Any:
+ """Get a value."""
+ return await self._backend.get(self.request_id, key, default)
+
+ async def set(self, key: str, value: Any):
+ """Set a value."""
+ return await self._backend.set(self.request_id, key, value)
+
+ async def pop(self, key: str, default: Any = None) -> Any:
+ """Delete a value from the state and return it's value."""
+ return await self._backend.pop(self.request_id, key, default)
+
+ async def clear(self):
+ """Clear all values."""
+ return await self._backend.clear(self.request_id)
+
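+A hedged sketch of a task that persists a cursor across restarts; the task name and `process_item` are hypothetical, and `task` is assumed importable from the package root:
+
+from mognet import task
+
+@task(name="example.resumable")
+async def resumable(context, items: list):
+    # Resume from wherever a previous (suspended or retried) run stopped.
+    cursor = await context.state.get("cursor", 0)
+
+    for index in range(cursor, len(items)):
+        process_item(items[index])  # hypothetical unit of work
+        await context.state.set("cursor", index + 1)
+    # The state is deleted automatically once the task finishes.
+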
clear(self)
+
+ async
+
+Clear all values.
+
+get(self, key, default=None)
+
+ async
+
+Get a value.
+
+pop(self, key, default=None)
+
+ async
+
+Delete a value from the state and return its value.
+
+state_backend_config
+
+RedisStateBackendSettings (BaseModel)
+
+ pydantic-model
+
+Configuration for the Redis State Backend
+
+mognet/state/state_backend_config.py
class RedisStateBackendSettings(BaseModel):
+ """Configuration for the Redis State Backend"""
+
+ url: str = "redis://localhost:6379/"
+
+ # How long each task's state should live for.
+ state_ttl: int = 7200
+
+ # Set the limit of connections on the Redis connection pool.
+ # DANGER! Setting this to too low a value WILL cause issues opening connections!
+ max_connections: Optional[int] = None
+
testing
+
+ special
+
+pytest_integration
+
+create_app_fixture(app)
+
+Create a Pytest fixture for a Mognet application.
+
+mognet/testing/pytest_integration.py
def create_app_fixture(app: App):
+ """Create a Pytest fixture for a Mognet application."""
+
+ @pytest_asyncio.fixture
+ async def app_fixture():
+ async with app:
+ start_task = asyncio.create_task(app.start())
+ yield app
+ await app.close()
+
+ try:
+ start_task.cancel()
+ await start_task
+ except BaseException: # pylint: disable=broad-except
+ pass
+
+ return app_fixture
+
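+A hedged usage sketch; the module paths and the task are hypothetical, and `app_fixture.submit` is assumed to return a `Result`:
+
+# conftest.py
+from mognet.testing.pytest_integration import create_app_fixture
+from myproject.tasks import app  # hypothetical module holding the App
+
+app_fixture = create_app_fixture(app)
+
+# test_tasks.py
+import pytest
+from mognet import Request
+
+@pytest.mark.asyncio
+async def test_add(app_fixture):
+    result = await app_fixture.submit(Request(name="example.add", kwargs={"a": 1, "b": 2}))
+    assert await result == 3
+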
tools
+
+ special
+
+backports
+
+ special
+
+aioitertools
+
+Backport of https://github.com/RedRoserade/aioitertools/blob/f86552753e626cb71a3a305b9ec890f97d771e6b/aioitertools/asyncio.py#L93
+Should be upstreamed here: https://github.com/omnilib/aioitertools/pull/103
+
+as_generated(iterables, *, return_exceptions=False)
+
+
+Yield results from one or more async iterables, in the order they are produced.
+Like :func:`as_completed`, but for async iterators or generators instead of futures.
+Creates a separate task to drain each iterable, and a single queue for results.
+If ``return_exceptions`` is ``False``, then any exception will be raised, and
+pending iterables and tasks will be cancelled, and async generators will be closed.
+If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
+and execution will continue until all iterables have been fully consumed.
+Example::
+    async def generator(x):
+        for i in range(x):
+            yield i
+
+    gen1 = generator(10)
+    gen2 = generator(12)
+    async for value in as_generated([gen1, gen2]):
+        ...  # intermixed values yielded from gen1 and gen2
mognet/tools/backports/aioitertools.py
async def as_generated(
+ iterables: Iterable[AsyncIterable[T]],
+ *,
+ return_exceptions: bool = False,
+) -> AsyncIterable[T]:
+ """
+ Yield results from one or more async iterables, in the order they are produced.
+ Like :func:`as_completed`, but for async iterators or generators instead of futures.
+ Creates a separate task to drain each iterable, and a single queue for results.
+ If ``return_exceptions`` is ``False``, then any exception will be raised, and
+ pending iterables and tasks will be cancelled, and async generators will be closed.
+ If ``return_exceptions`` is ``True``, any exceptions will be yielded as results,
+ and execution will continue until all iterables have been fully consumed.
+ Example::
+ async def generator(x):
+ for i in range(x):
+ yield i
+ gen1 = generator(10)
+ gen2 = generator(12)
+ async for value in as_generated([gen1, gen2]):
+ ... # intermixed values yielded from gen1 and gen2
+ """
+
+ queue: asyncio.Queue[dict] = asyncio.Queue()
+
+ tailer_count: int = 0
+
+ async def tailer(iterable: AsyncIterable[T]) -> None:
+ nonlocal tailer_count
+
+ try:
+ async for item in iterable:
+ await queue.put({"value": item})
+ except asyncio.CancelledError:
+ if isinstance(iterable, AsyncGenerator): # pragma:nocover
+ with suppress(Exception):
+ await iterable.aclose()
+ raise
+ except Exception as exc: # pylint: disable=broad-except
+ await queue.put({"exception": exc})
+ finally:
+ tailer_count -= 1
+
+ if tailer_count == 0:
+ await queue.put({"done": True})
+
+ tasks = [asyncio.ensure_future(tailer(iter)) for iter in iterables]
+
+ if not tasks:
+ # Nothing to do
+ return
+
+ tailer_count = len(tasks)
+
+ try:
+ while True:
+ i = await queue.get()
+
+ if "value" in i:
+ yield i["value"]
+ elif "exception" in i:
+ if return_exceptions:
+ yield i["exception"]
+ else:
+ raise i["exception"]
+ elif "done" in i:
+ break
+ except (asyncio.CancelledError, GeneratorExit):
+ pass
+ finally:
+ for task in tasks:
+ if not task.done():
+ task.cancel()
+
+ for task in tasks:
+ with suppress(asyncio.CancelledError):
+ await task
+
kwargs_repr
+
+format_kwargs_repr(args, kwargs, *, value_max_length=64)
+
+Utility function to create an args + kwargs representation.
+
+mognet/tools/kwargs_repr.py
def format_kwargs_repr(
+ args: tuple,
+ kwargs: dict,
+ *,
+ value_max_length: Optional[int] = 64,
+) -> str:
+ """Utility function to create an args + kwargs representation."""
+
+ parts = []
+
+ for arg in args:
+ parts.append(_format_value(arg, max_length=value_max_length))
+
+ for arg_name, arg_value in kwargs.items():
+ parts.append(
+ f"{arg_name}={_format_value(arg_value, max_length=value_max_length)}"
+ )
+
+ return ", ".join(parts)
+
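+A hedged sketch of the output shape (the exact quoting and truncation depend on the private `_format_value` helper):
+
+from mognet.tools.kwargs_repr import format_kwargs_repr
+
+repr_str = format_kwargs_repr(("a", 1), {"flag": True})
+# Yields something like: 'a', 1, flag=True
+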
retries
+
+retryableasyncmethod(types, *, max_attempts, wait_timeout, lock=None, on_retry=None)
+
+Decorator to wrap an async method and make it retryable.
+
+mognet/tools/retries.py
def retryableasyncmethod(
+ types: Tuple[Type[BaseException], ...],
+ *,
+ max_attempts: Union[int, str],
+ wait_timeout: Union[float, str],
+ lock: Union[asyncio.Lock, str] = None,
+ on_retry: Union[Callable[[BaseException], Awaitable], str] = None,
+):
+ """
+ Decorator to wrap an async method and make it retryable.
+ """
+
+ def make_retryable(func: _T) -> _T:
+ if inspect.isasyncgenfunction(func):
+ raise TypeError("Async generator functions are not supported")
+
+ f: Any = cast(Any, func)
+
+ @wraps(f)
+ async def async_retryable_decorator(self, *args, **kwargs):
+ last_exc = None
+
+ retry = _noop
+ if isinstance(on_retry, str):
+ retry = getattr(self, on_retry)
+ elif callable(on_retry):
+ retry = on_retry
+
+ attempts: int
+ if isinstance(max_attempts, str):
+ attempts = getattr(self, max_attempts)
+ else:
+ attempts = max_attempts
+
+ timeout: float
+ if isinstance(wait_timeout, str):
+ timeout = getattr(self, wait_timeout)
+ else:
+ timeout = wait_timeout
+
+ retry_lock = None
+ if isinstance(lock, str):
+ retry_lock = getattr(self, lock)
+ elif lock is not None:
+ retry_lock = lock
+
+ # Use an exponential backoff, starting with 1s
+ # and with a maximum of whatever was configured
+ current_wait_timeout = min(1, timeout)
+
+ for attempt in range(1, attempts + 1):
+ try:
+ return await f(self, *args, **kwargs)
+ except types as exc:
+ _log.error("Attempt %r/%r failed", attempt, attempts, exc_info=exc)
+ last_exc = exc
+
+ _log.debug("Waiting %.2fs before next attempt", current_wait_timeout)
+
+ await asyncio.sleep(current_wait_timeout)
+
+ current_wait_timeout = min(current_wait_timeout * 2, timeout)
+
+ if retry_lock is not None:
+ if retry_lock.locked():
+ _log.debug("Already retrying, possibly on another method")
+ else:
+ async with retry_lock:
+ _log.debug("Calling retry method")
+ await retry(last_exc)
+ else:
+ await retry(last_exc)
+
+ if last_exc is None:
+ last_exc = Exception("All %r attempts failed" % attempts)
+
+ raise last_exc
+
+ return cast(_T, async_retryable_decorator)
+
+ return make_retryable
+
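+A hedged usage sketch: string arguments are resolved as attribute names on the instance at call time, per the `getattr` lookups above (the class and method are hypothetical):
+
+import asyncio
+
+from mognet.tools.retries import retryableasyncmethod
+
+class Client:
+    max_attempts = 3
+    retry_wait = 5.0
+
+    @retryableasyncmethod(
+        (ConnectionError, asyncio.TimeoutError),
+        max_attempts="max_attempts",  # read from self.max_attempts
+        wait_timeout="retry_wait",    # read from self.retry_wait
+    )
+    async def fetch(self):
+        ...
+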
worker
+
+ special
+
+worker
+
+MessageCancellationAction (str, Enum)
+
+Worker
+
+Workers are responsible for running the fetch -> run -> store result
+loop, for the task queues that are configured.
+
+mognet/worker/worker.py
class Worker:
+ """
+ Workers are responsible for running the fetch -> run -> store result
+ loop, for the task queues that are configured.
+ """
+
+ running_tasks: Dict[UUID, "_RequestProcessorHolder"]
+
+ # Set of tasks that are suspended
+ _waiting_tasks: Set[UUID]
+
+ app: "App"
+
+ def __init__(
+ self,
+ *,
+ app: "App",
+ middleware: List["Middleware"] = None,
+ ) -> None:
+ self.app = app
+ self.running_tasks = {}
+ self._waiting_tasks = set()
+ self._middleware = middleware or []
+
+ self._current_prefetch = 1
+
+ self._queue_consumption_tasks: List[AsyncGenerator] = []
+ self._consume_task = None
+
+ async def run(self):
+ _log.debug("Starting worker")
+
+ try:
+ self.app.broker.add_connection_failed_callback(self._handle_connection_lost)
+
+ await self.start_consuming()
+ except asyncio.CancelledError:
+ _log.debug("Stopping run")
+ return
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error("Error during consumption", exc_info=exc)
+
+ async def _handle_connection_lost(self, exc: BaseException = None):
+ _log.error("Handling connection lost event, stopping all tasks", exc_info=exc)
+
+ # No point in NACKing, because we have been disconnected
+ await self._cancel_all_tasks(message_action=MessageCancellationAction.NOTHING)
+
+ async def _cancel_all_tasks(self, *, message_action: MessageCancellationAction):
+ all_req_ids = list(self.running_tasks)
+
+ _log.debug("Cancelling all %r running tasks", len(all_req_ids))
+
+ try:
+ for req_id in all_req_ids:
+ await self.cancel(req_id, message_action=message_action)
+
+ await self._adjust_prefetch()
+ finally:
+ self._waiting_tasks.clear()
+
+ async def stop_consuming(self):
+ _log.debug("Closing queue consumption tasks")
+
+ consumers = self._queue_consumption_tasks
+ while consumers:
+ consumer = consumers.pop(0)
+
+ try:
+ await asyncio.wait_for(consumer.aclose(), 5)
+ except (asyncio.CancelledError, GeneratorExit, asyncio.TimeoutError):
+ pass
+ except Exception as consume_exc: # pylint: disable=broad-except
+ _log.debug("Error closing consumer", exc_info=consume_exc)
+
+ consume_task = self._consume_task
+ self._consume_task = None
+
+ if consume_task is not None:
+ _log.debug("Closing aggregation task")
+
+ try:
+ consume_task.cancel()
+ await asyncio.wait_for(consume_task, 15)
+
+ _log.debug("Closed consumption task")
+ except (asyncio.CancelledError, asyncio.TimeoutError):
+ pass
+ except Exception as consume_err: # pylint: disable=broad-except
+ _log.error("Error shutting down consumer task", exc_info=consume_err)
+
+ async def close(self):
+ """
+ Stops execution, cancelling all running tasks.
+ """
+
+ _log.debug("Closing worker")
+
+ await self.stop_consuming()
+
+ # Cancel and NACK all messages currently on this worker.
+ await self._cancel_all_tasks(message_action=MessageCancellationAction.NACK)
+
+ _log.debug("Closed worker")
+
+ def _remove_running_task(self, req_id: UUID):
+ fut = self.running_tasks.pop(req_id, None)
+
+ asyncio.create_task(self._emit_running_task_count_change())
+
+ return fut
+
+ def _add_running_task(self, req_id: UUID, holder: "_RequestProcessorHolder"):
+ self.running_tasks[req_id] = holder
+ asyncio.create_task(self._emit_running_task_count_change())
+
+ async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
+ """
+ Cancels, if any, the execution of a request.
+ Whoever calls this method is responsible for updating the result on the backend
+ accordingly.
+ """
+ fut = self._remove_running_task(req_id)
+
+ if fut is None:
+ _log.debug("Request id=%r is not running on this worker", req_id)
+ return
+
+ _log.info("Cancelling task %r", req_id)
+
+ result = await self.app.result_backend.get_or_create(req_id)
+
+ # Only suspend the result on the backend if it was running in our node.
+ if (
+ result.state == ResultState.RUNNING
+ and result.node_id == self.app.config.node_id
+ ):
+ await asyncio.shield(result.suspend())
+ _log.debug("Result for task %r suspended", req_id)
+
+ _log.debug("Waiting for coroutine of task %r to finish", req_id)
+
+ # Wait for the task to finish, this allows it to clean up.
+ try:
+ await asyncio.wait_for(fut.cancel(message_action=message_action), 15)
+ except asyncio.TimeoutError:
+ _log.warning(
+ "Handler for task id=%r took longer than 15s to shut down", req_id
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Handler for task=%r failed while shutting down", req_id, exc_info=exc
+ )
+
+ _log.debug("Stopped handler of task id=%r", req_id)
+
+ def _create_context(self, request: "Request") -> "Context":
+ if not self.app.state_backend:
+ raise RuntimeError("No state backend defined")
+
+ return Context(
+ self.app,
+ request,
+ State(self.app, request.id),
+ self,
+ )
+
+ async def _run_request(self, req: Request) -> None:
+ """
+ Processes a request, validating it before running.
+ """
+
+ _log.debug("Received request %r", req)
+
+ # Check if we're not trying to (re) start something which is already done,
+ # for cases when a request is cancelled before it's started.
+ # Even worse, check that we're not trying to start a request whose
+ # result might have been evicted.
+ result = await self.app.result_backend.get(req.id)
+
+ if result is None:
+ _log.error(
+ "Attempting to run task %r, but it's result doesn't exist on the backend. Discarding",
+ req,
+ )
+ await self.remove_suspended_task(req.id)
+ return
+
+ context = self._create_context(req)
+
+ if result.done:
+ _log.error(
+ "Attempting to re-run task %r, when it's already done with state %r. Discarding",
+ req,
+ result.state,
+ )
+ return await asyncio.shield(self._on_complete(context, result))
+
+ # Check if we should even start, because:
+ # 1. We might be in a crash loop (if the process gets killed without cleanup, the numbers won't match),
+ # 2. Infinite recursion
+ # 3. We might be part of a parent request that was revoked (or doesn't exist)
+ # 4. We might be too late.
+
+ # 1. Too many starts
+ retry_count = result.unexpected_retry_count
+
+ if retry_count > 0:
+ _log.warning(
+ "Task %r has been retried %r times (max=%r)",
+ req.id,
+ retry_count,
+ self.app.config.max_retries,
+ )
+
+ if retry_count > self.app.config.max_retries:
+ _log.error(
+ "Discarding task %r because it has exceeded the maximum retry count of %r",
+ req,
+ self.app.config.max_retries,
+ )
+
+ result = await result.set_error(
+ TooManyRetries(req.id, retry_count, self.app.config.max_retries)
+ )
+
+ return await asyncio.shield(self._on_complete(context, result))
+
+ if req.stack:
+
+ # 2. Recursion
+ if len(req.stack) > self.app.config.max_recursion:
+ result = await result.set_error(RecursionError())
+ return await asyncio.shield(self._on_complete(context, result))
+
+ # 3. Parent task(s) aborted (or doesn't exist)
+ for parent_id in reversed(req.stack):
+ parent_result = await self.app.result_backend.get(parent_id)
+
+ if parent_result is None:
+ result = await result.set_error(
+ Exception(f"Parent request id={parent_id} does not exist"),
+ state=ResultState.REVOKED,
+ )
+ return await asyncio.shield(self._on_complete(context, result))
+
+ if parent_result.state == ResultState.REVOKED:
+ result = await result.set_error(
+ Exception(f"Parent request id={parent_result.id} was revoked"),
+ state=ResultState.REVOKED,
+ )
+ return await asyncio.shield(self._on_complete(context, result))
+
+ # 4. Request arrived past the deadline.
+
+ if isinstance(req.deadline, datetime):
+ # One cannot compare naive and aware datetime,
+ # so create equivalent datetime objects.
+ now = datetime.now(tz=req.deadline.tzinfo)
+
+ if req.deadline < now:
+ _log.error(
+ "Request %r arrived too late. Deadline is %r, current date is %r. Marking it as REVOKED and discarding",
+ req,
+ req.deadline,
+ now,
+ )
+ result = await asyncio.shield(
+ result.set_error(asyncio.TimeoutError(), state=ResultState.REVOKED)
+ )
+ return await asyncio.shield(self._on_complete(context, result))
+
+ # Get the function for the task. Fail if the task is not registered in
+ # our app's context.
+ try:
+ task_function = self.app.task_registry.get_task_function(req.name)
+ except UnknownTask as unknown_task:
+ _log.error(
+ "Request %r is for an unknown task: %r",
+ req,
+ req.name,
+ exc_info=unknown_task,
+ )
+ result = await result.set_error(unknown_task, state=ResultState.INVALID)
+ return await asyncio.shield(self._on_complete(context, result))
+
+ # Mark this as running.
+ await result.start(node_id=self.app.config.node_id)
+
+ await asyncio.shield(self._on_starting(context))
+
+ try:
+ # Create a validated version of the function.
+ # This not only does argument validation, but it also parses the values
+ # into objects.
+ validated = ValidatedFunction(
+ task_function, config=_TaskFuncArgumentValidationConfig
+ )
+
+ # This does the model validation part.
+ model = validated.init_model_instance(context, *req.args, **req.kwargs)
+
+ if inspect.iscoroutinefunction(task_function):
+ fut = validated.execute(model)
+ else:
+ _log.debug(
+ "Handler for task %r is not a coroutine function, running in the loop's default executor",
+ req.name,
+ )
+
+ # Run non-coroutine functions inside an executor.
+ # This allows them to run without blocking the event loop
+ # (providing the GIL does not block it either)
+ fut = self.app.loop.run_in_executor(None, validated.execute, model)
+
+ except ValidationError as exc:
+ _log.error(
+ "Could not call task function %r because of a validation error",
+ task_function,
+ exc_info=exc,
+ )
+
+ invalid = InvalidTaskArguments.from_validation_error(exc)
+
+ result = await asyncio.shield(
+ result.set_error(invalid, state=ResultState.INVALID)
+ )
+
+ return await asyncio.shield(self._on_complete(context, result))
+
+ if req.deadline is not None:
+ if isinstance(req.deadline, datetime):
+ # One cannot compare naive and aware datetime,
+ # so create equivalent datetime objects.
+ now = datetime.now(tz=req.deadline.tzinfo)
+ timeout = (req.deadline - now).total_seconds()
+ else:
+ timeout = req.deadline.total_seconds()
+
+ _log.debug("Applying %.2fs timeout to request %r", timeout, req)
+
+ fut = asyncio.wait_for(fut, timeout=timeout)
+
+ # Start executing.
+ try:
+ value = await fut
+
+ if req.id in self.running_tasks:
+ await asyncio.shield(result.set_result(value))
+
+ _log.info(
+ "Request %r finished with status %r in %.2fs",
+ req,
+ result.state,
+ (result.duration or timedelta()).total_seconds(),
+ )
+
+ await asyncio.shield(self._on_complete(context, result))
+ except Pause:
+ _log.info(
+ "Handler for %r requested to be paused. Suspending it on the Result Backend and NACKing the message",
+ req,
+ )
+
+ holder = self.running_tasks.pop(req.id, None)
+ if holder is not None:
+ await asyncio.shield(result.suspend())
+ await asyncio.shield(holder.message.nack())
+ except asyncio.CancelledError:
+ _log.debug("Handler for task %r cancelled", req)
+
+ # Re-raise the cancellation, this will be caught in the parent function
+ # and prevent ack/nack
+ raise
+ except Exception as exc: # pylint: disable=broad-except
+ state = ResultState.FAILURE
+
+ # The task's coroutine may raise `asyncio.TimeoutError` itself, so there's
+ # no guarantee that the timeout we catch is actually related to the request's timeout.
+ # So, this heuristic is not the best.
+ # TODO: A way to improve it would be to double-check if the deadline itself is expired.
+ if req.deadline is not None and isinstance(exc, asyncio.TimeoutError):
+ state = ResultState.REVOKED
+
+ if req.id in self.running_tasks:
+ result = await asyncio.shield(result.set_error(exc, state=state))
+ await asyncio.shield(self._on_complete(context, result))
+
+ duration = result.duration
+
+ if duration is not None:
+ _log.error(
+ "Handler for task %r failed in %.2fs with state %r",
+ req,
+ duration.total_seconds(),
+ state,
+ exc_info=exc,
+ )
+ else:
+ _log.error(
+ "Handler for task %r failed with state %r",
+ req,
+ state,
+ exc_info=exc,
+ )
+
+ async def _on_complete(self, context: "Context", result: Result):
+ if result.done:
+ await context.state.clear()
+
+ await self.remove_suspended_task(context.request.id)
+
+ for middleware in self._middleware:
+ try:
+ _log.debug("Calling 'on_task_completed' middleware: %r", middleware)
+ await asyncio.shield(
+ middleware.on_task_completed(result, context=context)
+ )
+ except Exception as mw_exc: # pylint: disable=broad-except
+ _log.error("Middleware %r failed", middleware, exc_info=mw_exc)
+
+ async def _on_starting(self, context: "Context"):
+ _log.info("Starting task %r", context.request)
+
+ for middleware in self._middleware:
+ try:
+ _log.debug("Calling 'on_task_starting' middleware: %r", middleware)
+ await asyncio.shield(middleware.on_task_starting(context))
+ except Exception as mw_exc: # pylint: disable=broad-except
+ _log.error("Middleware %r failed", middleware, exc_info=mw_exc)
+
+ def _process_request_message(self, payload: IncomingMessagePayload) -> asyncio.Task:
+ """
+ Creates an asyncio.Task which will process the enclosed Request
+ in the background.
+
+ Returns said task, after adding completion handlers to it.
+ """
+ _log.debug("Parsing input of message id=%r as Request", payload.id)
+ req = Request.parse_obj(payload.payload)
+
+ async def request_processor():
+ try:
+ await self._run_request(req)
+
+ _log.debug("ACK message id=%r for request=%r", payload.id, req)
+ await asyncio.shield(payload.ack())
+ except asyncio.CancelledError:
+ _log.debug("Cancelled execution of request=%r", req)
+ return
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Fatal error processing request=%r, NAK message id=%r",
+ req,
+ payload.id,
+ exc_info=exc,
+ )
+ await asyncio.shield(payload.nack())
+
+ def on_processing_done(fut: Future):
+ self._remove_running_task(req.id)
+
+ exc = fut.exception()
+
+ if exc is not None and not fut.cancelled():
+ _log.error("Fatal error processing %r", req, exc_info=exc)
+ else:
+ _log.debug("Processed %r successfully", req)
+
+ task = asyncio.create_task(request_processor())
+ task.add_done_callback(on_processing_done)
+
+ holder = _RequestProcessorHolder(payload, req, task)
+
+ self._add_running_task(req.id, holder)
+
+ return task
+
+ def start_consuming(self):
+ if self._consume_task is not None:
+ return self._consume_task
+
+ self._consume_task = asyncio.create_task(self._start_consuming())
+
+ return self._consume_task
+
+ async def _start_consuming(self):
+
+ queues = self.app.get_task_queue_names()
+
+ _log.info("Going to consume %r queues", len(queues))
+
+ try:
+ await self._adjust_prefetch()
+
+ for queue in queues:
+ _log.info("Start consuming task queue=%r", queue)
+ self._queue_consumption_tasks.append(
+ self.app.broker.consume_tasks(queue)
+ )
+
+ async for payload in as_generated(self._queue_consumption_tasks):
+ try:
+ if payload.kind == "Request":
+ self._process_request_message(payload)
+ else:
+ raise ValueError(f"Unknown kind={payload.kind!r}")
+ except asyncio.CancelledError:
+ break
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Error processing message=%r, discarding it",
+ payload,
+ exc_info=exc,
+ )
+ await asyncio.shield(payload.ack())
+ finally:
+ _log.debug("Stopped consuming task queues")
+
+ async def add_waiting_task(self, task_id: UUID):
+ self._waiting_tasks.add(task_id)
+ await self._adjust_prefetch()
+
+ async def remove_suspended_task(self, task_id: UUID):
+ try:
+ self._waiting_tasks.remove(task_id)
+ except KeyError:
+ pass
+ await self._adjust_prefetch()
+
+ @property
+ def waiting_task_count(self):
+ return len(self._waiting_tasks)
+
+ async def _emit_running_task_count_change(self):
+ for middleware in self._middleware:
+ try:
+ _log.debug("Calling 'on_running_task_count_changed' on %r", middleware)
+ await middleware.on_running_task_count_changed(len(self.running_tasks))
+ except Exception as mw_exc: # pylint: disable=broad-except
+ _log.error(
+ "'on_running_task_count_changed' failed on %r",
+ middleware,
+ exc_info=mw_exc,
+ )
+
+ async def _adjust_prefetch(self):
+ if self._consume_task is None:
+ _log.debug("Not adjusting prefetch because not consuming the queue")
+ return
+
+ minimum_prefetch = self.app.config.minimum_concurrency
+
+ prefetch = self.waiting_task_count + minimum_prefetch
+
+ max_prefetch = self.app.config.maximum_concurrency
+
+ if max_prefetch is not None and prefetch >= max_prefetch:
+ _log.error(
+ "Maximum prefetch value of %r reached! No more tasks will be fetched on this node",
+ max_prefetch,
+ )
+ prefetch = max_prefetch
+
+ if prefetch == self._current_prefetch:
+ _log.debug(
+ "Current prefetch is the same as the new prefetch (%r), not adjusting it",
+ prefetch,
+ )
+ return
+
+ _log.debug(
+ "Currently have %r tasks suspended waiting for others. Setting prefetch=%r from previous=%r",
+ self.waiting_task_count,
+ prefetch,
+ self._current_prefetch,
+ )
+
+ self._current_prefetch = prefetch
+ await self.app.broker.set_task_prefetch(self._current_prefetch)
+
cancel(self, req_id, *, message_action)
+
+
+ async
+
+
+Cancels, if any, the execution of a request.
+Whoever calls this method is responsible for updating the result on the backend
+accordingly.
+
+mognet/worker/worker.py
async def cancel(self, req_id: UUID, *, message_action: MessageCancellationAction):
+ """
+ Cancels, if any, the execution of a request.
+ Whoever calls this method is responsible for updating the result on the backend
+ accordingly.
+ """
+ fut = self._remove_running_task(req_id)
+
+ if fut is None:
+ _log.debug("Request id=%r is not running on this worker", req_id)
+ return
+
+ _log.info("Cancelling task %r", req_id)
+
+ result = await self.app.result_backend.get_or_create(req_id)
+
+ # Only suspend the result on the backend if it was running in our node.
+ if (
+ result.state == ResultState.RUNNING
+ and result.node_id == self.app.config.node_id
+ ):
+ await asyncio.shield(result.suspend())
+ _log.debug("Result for task %r suspended", req_id)
+
+ _log.debug("Waiting for coroutine of task %r to finish", req_id)
+
+ # Wait for the task to finish, this allows it to clean up.
+ try:
+ await asyncio.wait_for(fut.cancel(message_action=message_action), 15)
+ except asyncio.TimeoutError:
+ _log.warning(
+ "Handler for task id=%r took longer than 15s to shut down", req_id
+ )
+ except Exception as exc: # pylint: disable=broad-except
+ _log.error(
+ "Handler for task=%r failed while shutting down", req_id, exc_info=exc
+ )
+
+ _log.debug("Stopped handler of task id=%r", req_id)
+
close(self)
+
+
+ async
+
+
+Stops execution, cancelling all running tasks.
+
+mognet/worker/worker.py
async def close(self):
+ """
+ Stops execution, cancelling all running tasks.
+ """
+
+ _log.debug("Closing worker")
+
+ await self.stop_consuming()
+
+ # Cancel and NACK all messages currently on this worker.
+ await self._cancel_all_tasks(message_action=MessageCancellationAction.NACK)
+
+ _log.debug("Closed worker")
+