From 7ca84e79b46bb20eae2c68dae6a7d3d49529a02b Mon Sep 17 00:00:00 2001 From: nickpetrovic <4001122+nickpetrovic@users.noreply.github.com> Date: Thu, 12 Dec 2024 16:35:32 -0500 Subject: [PATCH] fix: task queue tasks should error when they raise an exception (#783) - Fix tasks being marked as completed when they raise an exception and retry_for is not explicitly set - Fix type hinting for retry_for in task_queue decorator - Fix potential scoping issues with args and kwargs - Remove unimplemented retry_for parameter from functions Resolve BE-2128 --- sdk/pyproject.toml | 2 +- sdk/src/beta9/__init__.py | 2 +- sdk/src/beta9/abstractions/function.py | 4 ---- sdk/src/beta9/abstractions/taskqueue.py | 6 +++--- sdk/src/beta9/runner/taskqueue.py | 19 ++++++++++++------- 5 files changed, 17 insertions(+), 16 deletions(-) diff --git a/sdk/pyproject.toml b/sdk/pyproject.toml index f1125b0db..d52010879 100644 --- a/sdk/pyproject.toml +++ b/sdk/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "beta9" -version = "0.1.132" +version = "0.1.133" description = "" authors = ["beam.cloud "] packages = [ diff --git a/sdk/src/beta9/__init__.py b/sdk/src/beta9/__init__.py index cbdca68a2..e80dd5a3b 100644 --- a/sdk/src/beta9/__init__.py +++ b/sdk/src/beta9/__init__.py @@ -1,5 +1,5 @@ from . import env -from .abstractions import experimental +from .abstractions import experimental, integrations from .abstractions.container import Container from .abstractions.endpoint import ASGI as asgi from .abstractions.endpoint import Endpoint as endpoint diff --git a/sdk/src/beta9/abstractions/function.py b/sdk/src/beta9/abstractions/function.py index 65ec006fb..42f2252e4 100644 --- a/sdk/src/beta9/abstractions/function.py +++ b/sdk/src/beta9/abstractions/function.py @@ -66,8 +66,6 @@ class Function(RunnerAbstraction): task_policy (TaskPolicy): The task policy for the function. This helps manage the lifecycle of an individual task. Setting values here will override timeout and retries. - retry_for (Optional[List[BaseException]]): - A list of exceptions that will trigger a retry. Example: ```python from beta9 import function, Image @@ -101,7 +99,6 @@ def __init__( secrets: Optional[List[str]] = None, name: Optional[str] = None, task_policy: TaskPolicy = TaskPolicy(), - retry_for: Optional[List[Exception]] = None, ) -> None: super().__init__( cpu=cpu, @@ -120,7 +117,6 @@ def __init__( self._function_stub: Optional[FunctionServiceStub] = None self.syncer: FileSyncer = FileSyncer(self.gateway_stub) - self.retry_for = retry_for def __call__(self, func): return _CallableWrapper(func, self) diff --git a/sdk/src/beta9/abstractions/taskqueue.py b/sdk/src/beta9/abstractions/taskqueue.py index 159b9a12b..eb979d5bd 100644 --- a/sdk/src/beta9/abstractions/taskqueue.py +++ b/sdk/src/beta9/abstractions/taskqueue.py @@ -1,7 +1,7 @@ import json import os import threading -from typing import Any, Callable, List, Optional, Union +from typing import Any, Callable, List, Optional, Type, Union from .. import terminal from ..abstractions.base.runner import ( @@ -131,7 +131,7 @@ def __init__( autoscaler: Autoscaler = QueueDepthAutoscaler(), task_policy: TaskPolicy = TaskPolicy(), checkpoint_enabled: bool = False, - retry_for: Optional[List[BaseException]] = None, + retry_for: Optional[List[Type[Exception]]] = None, ) -> None: super().__init__( cpu=cpu, @@ -155,7 +155,7 @@ def __init__( checkpoint_enabled=checkpoint_enabled, ) self._taskqueue_stub: Optional[TaskQueueServiceStub] = None - self.retry_for = retry_for + self.retry_for = retry_for or [] @property def taskqueue_stub(self) -> TaskQueueServiceStub: diff --git a/sdk/src/beta9/runner/taskqueue.py b/sdk/src/beta9/runner/taskqueue.py index f8a0c9010..3553a75fb 100644 --- a/sdk/src/beta9/runner/taskqueue.py +++ b/sdk/src/beta9/runner/taskqueue.py @@ -9,7 +9,7 @@ from concurrent.futures import CancelledError from multiprocessing import Event, Process, set_start_method from multiprocessing.synchronize import Event as TEvent -from typing import Any, List, NamedTuple, Union +from typing import Any, List, NamedTuple, Type, Union import grpc @@ -312,20 +312,21 @@ def process_tasks(self, channel: Channel) -> None: duration = None caught_exception = "" + args = task.args or [] + kwargs = task.kwargs or {} try: - args = task.args or [] - kwargs = task.kwargs or {} - result = handler(context, *args, **kwargs) except BaseException as e: print(traceback.format_exc()) - if type(e) in handler.parent_abstraction.retry_for: + task_status = TaskStatus.Error + + if retry_on_errors(handler.parent_abstraction.retry_for, e): + print(f"retry_for error caught: {e!r}") caught_exception = e.__class__.__name__ task_status = TaskStatus.Retry - else: - task_status = TaskStatus.Error + finally: duration = time.time() - start_time @@ -369,6 +370,10 @@ def process_tasks(self, channel: Channel) -> None: monitor_task.cancel() +def retry_on_errors(errors: List[Type[Exception]], e: BaseException) -> bool: + return any([err for err in errors if type(e) is err]) + + if __name__ == "__main__": tq = TaskQueueManager() tq.run()