diff --git a/mognet/backend/redis_result_backend.py b/mognet/backend/redis_result_backend.py index 6d287aa..4a786ef 100644 --- a/mognet/backend/redis_result_backend.py +++ b/mognet/backend/redis_result_backend.py @@ -10,7 +10,7 @@ from typing import Any, AnyStr, Dict, Iterable, List, Optional, Set from uuid import UUID -from pydantic.v1.tools import parse_raw_as +from pydantic import TypeAdapter from redis.asyncio import Redis, from_url from redis.exceptions import ConnectionError, TimeoutError @@ -322,7 +322,7 @@ async def waiter(): while True: raw_state = await shield(self._redis.hget(key, "state")) or b"null" - state = parse_raw_as(t, raw_state) + state = TypeAdapter(t).validate_json(raw_state) if state is None: raise ResultValueLost(result_id) diff --git a/mognet/primitives/request.py b/mognet/primitives/request.py index 1e2fc47..557ebbc 100644 --- a/mognet/primitives/request.py +++ b/mognet/primitives/request.py @@ -3,7 +3,6 @@ from uuid import UUID, uuid4 from pydantic import BaseModel, Field -from pydantic.fields import Field from typing_extensions import Annotated diff --git a/mognet/worker/worker.py b/mognet/worker/worker.py index 024528d..83294ca 100644 --- a/mognet/worker/worker.py +++ b/mognet/worker/worker.py @@ -1,32 +1,25 @@ -from datetime import datetime, timedelta -from enum import Enum -import inspect - -from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause import asyncio +import functools +import inspect import logging from asyncio.futures import Future -from mognet.broker.base_broker import IncomingMessagePayload -from typing import ( - AsyncGenerator, - Optional, - Set, - TYPE_CHECKING, - Dict, - List, -) +from datetime import datetime, timedelta +from enum import Enum +from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Set from uuid import UUID -from mognet.tools.backports.aioitertools import as_generated +from pydantic import ValidationError as V2ValidationError +from pydantic import validate_call + from mognet.broker.base_broker import IncomingMessagePayload from mognet.context.context import Context +from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause from mognet.exceptions.too_many_retries import TooManyRetries from mognet.model.result import Result, ResultState +from mognet.primitives.request import Request from mognet.state.state import State from mognet.tasks.task_registry import UnknownTask -from pydantic.v1 import ValidationError -from pydantic.v1.decorator import ValidatedFunction -from mognet.primitives.request import Request +from mognet.tools.backports.aioitertools import as_generated if TYPE_CHECKING: from mognet.app.app import App @@ -336,15 +329,13 @@ async def _run_request(self, req: Request) -> None: # Create a validated version of the function. # This not only does argument validation, but it also parses the values # into objects. - validated = ValidatedFunction( + + validated_func = validate_call( task_function, config=_TaskFuncArgumentValidationConfig ) - # This does the model validation part. - model = validated.init_model_instance(context, *req.args, **req.kwargs) - if inspect.iscoroutinefunction(task_function): - fut = validated.execute(model) + fut = validated_func(context, *req.args, **req.kwargs) else: _log.debug( "Handler for task %r is not a coroutine function, running in the loop's default executor", @@ -354,9 +345,12 @@ async def _run_request(self, req: Request) -> None: # Run non-coroutine functions inside an executor. # This allows them to run without blocking the event loop # (providing the GIL does not block it either) - fut = self.app.loop.run_in_executor(None, validated.execute, model) + fut = self.app.loop.run_in_executor( + None, + functools.partial(validated_func, context, *req.args, **req.kwargs), + ) - except ValidationError as exc: + except V2ValidationError as exc: _log.error( "Could not call task function %r because of a validation error", task_function, @@ -415,6 +409,20 @@ async def _run_request(self, req: Request) -> None: # Re-raise the cancellation, this will be caught in the parent function # and prevent ack/nack raise + except V2ValidationError as exc: # will enter for sync functions here. + _log.error( + "Could not call task function %r because of a validation error", + task_function, + exc_info=exc, + ) + + invalid = InvalidTaskArguments.from_validation_error(exc) + + result = await asyncio.shield( + result.set_error(invalid, state=ResultState.INVALID) + ) + + return await asyncio.shield(self._on_complete(context, result)) except Exception as exc: # pylint: disable=broad-except state = ResultState.FAILURE @@ -678,4 +686,4 @@ async def cancel(self, *, message_action: MessageCancellationAction): except asyncio.CancelledError: pass except Exception as exc: - _log.error("Task handler for %r failed", self.request, exc_info=exc) + _log.error("Task handler for %r failed", self.request, exc_info=exc) \ No newline at end of file