Skip to content

Commit

Permalink
fix: task queue tasks should error when they raise an exception (#783)
Browse files Browse the repository at this point in the history
- Fix tasks being marked as completed when they raise an exception and
retry_for is not explicitly set
- Fix type hinting for retry_for in task_queue decorator
- Fix potential scoping issues with args and kwargs
- Remove unimplemented retry_for parameter from functions

Resolve BE-2128
  • Loading branch information
nickpetrovic authored Dec 12, 2024
1 parent 359157a commit 7ca84e7
Show file tree
Hide file tree
Showing 5 changed files with 17 additions and 16 deletions.
2 changes: 1 addition & 1 deletion sdk/pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "beta9"
version = "0.1.132"
version = "0.1.133"
description = ""
authors = ["beam.cloud <[email protected]>"]
packages = [
Expand Down
2 changes: 1 addition & 1 deletion sdk/src/beta9/__init__.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from . import env
from .abstractions import experimental
from .abstractions import experimental, integrations
from .abstractions.container import Container
from .abstractions.endpoint import ASGI as asgi
from .abstractions.endpoint import Endpoint as endpoint
Expand Down
4 changes: 0 additions & 4 deletions sdk/src/beta9/abstractions/function.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,8 +66,6 @@ class Function(RunnerAbstraction):
task_policy (TaskPolicy):
The task policy for the function. This helps manage the lifecycle of an individual task.
Setting values here will override timeout and retries.
retry_for (Optional[List[BaseException]]):
A list of exceptions that will trigger a retry.
Example:
```python
from beta9 import function, Image
Expand Down Expand Up @@ -101,7 +99,6 @@ def __init__(
secrets: Optional[List[str]] = None,
name: Optional[str] = None,
task_policy: TaskPolicy = TaskPolicy(),
retry_for: Optional[List[Exception]] = None,
) -> None:
super().__init__(
cpu=cpu,
Expand All @@ -120,7 +117,6 @@ def __init__(

self._function_stub: Optional[FunctionServiceStub] = None
self.syncer: FileSyncer = FileSyncer(self.gateway_stub)
self.retry_for = retry_for

def __call__(self, func):
return _CallableWrapper(func, self)
Expand Down
6 changes: 3 additions & 3 deletions sdk/src/beta9/abstractions/taskqueue.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import json
import os
import threading
from typing import Any, Callable, List, Optional, Union
from typing import Any, Callable, List, Optional, Type, Union

from .. import terminal
from ..abstractions.base.runner import (
Expand Down Expand Up @@ -131,7 +131,7 @@ def __init__(
autoscaler: Autoscaler = QueueDepthAutoscaler(),
task_policy: TaskPolicy = TaskPolicy(),
checkpoint_enabled: bool = False,
retry_for: Optional[List[BaseException]] = None,
retry_for: Optional[List[Type[Exception]]] = None,
) -> None:
super().__init__(
cpu=cpu,
Expand All @@ -155,7 +155,7 @@ def __init__(
checkpoint_enabled=checkpoint_enabled,
)
self._taskqueue_stub: Optional[TaskQueueServiceStub] = None
self.retry_for = retry_for
self.retry_for = retry_for or []

@property
def taskqueue_stub(self) -> TaskQueueServiceStub:
Expand Down
19 changes: 12 additions & 7 deletions sdk/src/beta9/runner/taskqueue.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
from concurrent.futures import CancelledError
from multiprocessing import Event, Process, set_start_method
from multiprocessing.synchronize import Event as TEvent
from typing import Any, List, NamedTuple, Union
from typing import Any, List, NamedTuple, Type, Union

import grpc

Expand Down Expand Up @@ -312,20 +312,21 @@ def process_tasks(self, channel: Channel) -> None:
duration = None

caught_exception = ""
args = task.args or []
kwargs = task.kwargs or {}

try:
args = task.args or []
kwargs = task.kwargs or {}

result = handler(context, *args, **kwargs)
except BaseException as e:
print(traceback.format_exc())

if type(e) in handler.parent_abstraction.retry_for:
task_status = TaskStatus.Error

if retry_on_errors(handler.parent_abstraction.retry_for, e):
print(f"retry_for error caught: {e!r}")
caught_exception = e.__class__.__name__
task_status = TaskStatus.Retry
else:
task_status = TaskStatus.Error

finally:
duration = time.time() - start_time

Expand Down Expand Up @@ -369,6 +370,10 @@ def process_tasks(self, channel: Channel) -> None:
monitor_task.cancel()


def retry_on_errors(errors: List[Type[Exception]], e: BaseException) -> bool:
    """Return whether the raised exception ``e`` should trigger a retry.

    Args:
        errors: Exception classes configured as retryable. An empty (or
            missing, i.e. ``None``) collection means nothing is retryable.
        e: The exception instance that was raised.

    Returns:
        True if the exact class of ``e`` is one of ``errors``, else False.
    """
    # Guard the unset case so this helper never raises while an exception
    # is already being handled by the caller.
    if not errors:
        return False

    # Exact class identity is kept from the original comparison: subclasses
    # of a listed exception do not match.
    raised = type(e)
    return any(raised is err for err in errors)


# Entry point when this module is executed as a script: construct a
# TaskQueueManager and call its run() method.
if __name__ == "__main__":
tq = TaskQueueManager()
tq.run()
Expand Down

0 comments on commit 7ca84e7

Please sign in to comment.