Skip to content

Commit

Permalink
replaced validatedFunction with v2 validate_call. replaced missing pa…
Browse files Browse the repository at this point in the history
…rse_raw_as with validate_json

Signed-off-by: Tiago Santana <[email protected]>
  • Loading branch information
SantanaTiago committed May 21, 2024
1 parent 2136f5a commit 433f8fd
Show file tree
Hide file tree
Showing 3 changed files with 36 additions and 29 deletions.
4 changes: 2 additions & 2 deletions mognet/backend/redis_result_backend.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from typing import Any, AnyStr, Dict, Iterable, List, Optional, Set
from uuid import UUID

from pydantic.v1.tools import parse_raw_as
from pydantic import TypeAdapter
from redis.asyncio import Redis, from_url
from redis.exceptions import ConnectionError, TimeoutError

Expand Down Expand Up @@ -322,7 +322,7 @@ async def waiter():
while True:
raw_state = await shield(self._redis.hget(key, "state")) or b"null"

state = parse_raw_as(t, raw_state)
state = TypeAdapter(t).validate_json(raw_state)

if state is None:
raise ResultValueLost(result_id)
Expand Down
1 change: 0 additions & 1 deletion mognet/primitives/request.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from uuid import UUID, uuid4
from pydantic import BaseModel, Field

from pydantic.fields import Field
from typing_extensions import Annotated


Expand Down
60 changes: 34 additions & 26 deletions mognet/worker/worker.py
Original file line number Diff line number Diff line change
@@ -1,32 +1,25 @@
from datetime import datetime, timedelta
from enum import Enum
import inspect

from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause
import asyncio
import functools
import inspect
import logging
from asyncio.futures import Future
from mognet.broker.base_broker import IncomingMessagePayload
from typing import (
AsyncGenerator,
Optional,
Set,
TYPE_CHECKING,
Dict,
List,
)
from datetime import datetime, timedelta
from enum import Enum
from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Set
from uuid import UUID

from mognet.tools.backports.aioitertools import as_generated
from pydantic import ValidationError as V2ValidationError
from pydantic import validate_call

from mognet.broker.base_broker import IncomingMessagePayload
from mognet.context.context import Context
from mognet.exceptions.task_exceptions import InvalidTaskArguments, Pause
from mognet.exceptions.too_many_retries import TooManyRetries
from mognet.model.result import Result, ResultState
from mognet.primitives.request import Request
from mognet.state.state import State
from mognet.tasks.task_registry import UnknownTask
from pydantic.v1 import ValidationError
from pydantic.v1.decorator import ValidatedFunction
from mognet.primitives.request import Request
from mognet.tools.backports.aioitertools import as_generated

if TYPE_CHECKING:
from mognet.app.app import App
Expand Down Expand Up @@ -336,15 +329,13 @@ async def _run_request(self, req: Request) -> None:
# Create a validated version of the function.
# This not only does argument validation, but it also parses the values
# into objects.
validated = ValidatedFunction(

validated_func = validate_call(
task_function, config=_TaskFuncArgumentValidationConfig
)

# This does the model validation part.
model = validated.init_model_instance(context, *req.args, **req.kwargs)

if inspect.iscoroutinefunction(task_function):
fut = validated.execute(model)
fut = validated_func(context, *req.args, **req.kwargs)
else:
_log.debug(
"Handler for task %r is not a coroutine function, running in the loop's default executor",
Expand All @@ -354,9 +345,12 @@ async def _run_request(self, req: Request) -> None:
# Run non-coroutine functions inside an executor.
# This allows them to run without blocking the event loop
# (providing the GIL does not block it either)
fut = self.app.loop.run_in_executor(None, validated.execute, model)
fut = self.app.loop.run_in_executor(
None,
functools.partial(validated_func, context, *req.args, **req.kwargs),
)

except ValidationError as exc:
except V2ValidationError as exc:
_log.error(
"Could not call task function %r because of a validation error",
task_function,
Expand Down Expand Up @@ -415,6 +409,20 @@ async def _run_request(self, req: Request) -> None:
# Re-raise the cancellation, this will be caught in the parent function
# and prevent ack/nack
raise
except V2ValidationError as exc: # will enter for sync functions here.
_log.error(
"Could not call task function %r because of a validation error",
task_function,
exc_info=exc,
)

invalid = InvalidTaskArguments.from_validation_error(exc)

result = await asyncio.shield(
result.set_error(invalid, state=ResultState.INVALID)
)

return await asyncio.shield(self._on_complete(context, result))
except Exception as exc: # pylint: disable=broad-except
state = ResultState.FAILURE

Expand Down Expand Up @@ -678,4 +686,4 @@ async def cancel(self, *, message_action: MessageCancellationAction):
except asyncio.CancelledError:
pass
except Exception as exc:
_log.error("Task handler for %r failed", self.request, exc_info=exc)
_log.error("Task handler for %r failed", self.request, exc_info=exc)

0 comments on commit 433f8fd

Please sign in to comment.