diff --git a/flytekit/__init__.py b/flytekit/__init__.py index 74a50aff7f..bb68c5416b 100644 --- a/flytekit/__init__.py +++ b/flytekit/__init__.py @@ -243,6 +243,7 @@ from flytekit.core.reference import get_reference_entity from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference from flytekit.core.resources import Resources +from flytekit.core.retry import Backoff, OnOOM, Retry from flytekit.core.schedule import CronSchedule, FixedRate from flytekit.core.task import Secret, eager, reference_task, task from flytekit.core.type_engine import BatchSize diff --git a/flytekit/core/base_task.py b/flytekit/core/base_task.py index 41da032fee..79f3643c85 100644 --- a/flytekit/core/base_task.py +++ b/flytekit/core/base_task.py @@ -66,6 +66,7 @@ flyte_entity_call_handler, translate_inputs_to_literals, ) +from flytekit.core.retry import Retry from flytekit.core.tracker import TrackedInstance from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError from flytekit.core.utils import timeit @@ -122,7 +123,7 @@ class TaskMetadata(object): cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache. interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees that can include pre-emption. deprecated (str): Can be used to provide a warning message for a deprecated task. An absence or empty string indicates that the task is active and not deprecated. - retries (int): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times. + retries (Union[int, Retry]): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times. timeout (Optional[Union[datetime.timedelta, int]]): The maximum duration for which one execution of this task should run. The execution will be terminated if the runtime exceeds this timeout. pod_template_name (Optional[str]): The name of an existing PodTemplate resource in the cluster which will be used for this task. generates_deck (bool): Indicates whether the task will generate a Deck URI. @@ -135,7 +136,7 @@ class TaskMetadata(object): cache_ignore_input_vars: Tuple[str, ...] = () interruptible: Optional[bool] = None deprecated: str = "" - retries: int = 0 + retries: Union[int, Retry] = 0 timeout: Optional[Union[datetime.timedelta, int]] = None pod_template_name: Optional[str] = None generates_deck: bool = False @@ -158,7 +159,28 @@ def __post_init__(self): @property def retry_strategy(self) -> _literal_models.RetryStrategy: - return _literal_models.RetryStrategy(self.retries) + if isinstance(self.retries, int): + return _literal_models.RetryStrategy(retries=self.retries) + elif isinstance(self.retries, Retry): + if self.retries.on_oom is None: + on_oom = None + else: + if self.retries.on_oom.backoff is None: + backoff = None + else: + backoff = _literal_models.ExponentialBackoff( + exponent=self.retries.on_oom.backoff.exponent, + max=self.retries.on_oom.backoff.max, + ) + on_oom = _literal_models.RetryOnOOM( + factor=self.retries.on_oom.factor, + limit=self.retries.on_oom.limit, + backoff=backoff, + ) + return _literal_models.RetryStrategy( + self.retries.attempts, + on_oom=on_oom, + ) def to_taskmetadata_model(self) -> _task_model.TaskMetadata: """ diff --git a/flytekit/core/retry.py b/flytekit/core/retry.py new file mode 100644 index 0000000000..b33a13df28 --- /dev/null +++ b/flytekit/core/retry.py @@ -0,0 +1,72 @@ +import datetime +from dataclasses import dataclass, fields +from typing import TYPE_CHECKING, List, Optional, Union +from typing import Literal as L + +from flyteidl.core import tasks_pb2 + +if TYPE_CHECKING: + from kubernetes.client import V1PodSpec +from mashumaro.mixins.json import DataClassJSONMixin + +from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH +from flytekit.extras.accelerators import BaseAccelerator +from flytekit.models import task as task_models + + +@dataclass +class Backoff(DataClassJSONMixin): + """ + This class is used to specify the backoff strategy for retries. + + .. code-block:: python + + Backoff(exponent=2, max=timedelta(minutes=2)) # This is a backoff strategy with an exponent of 2 and a max of 2 minutes + + """ + + exponent: int = 0 + max: Optional[datetime.timedelta] = None + + def __post_init__(self): + if self.exponent < 0: + raise ValueError("Exponent must be a non-negative integer.") + if self.max is not None and self.max.total_seconds() < 0: + raise ValueError("Max must be a non-negative timedelta.") + + +@dataclass +class OnOOM(DataClassJSONMixin): + """ + This class is used to specify the behavior when a task runs out of memory. + + .. code-block:: python + + OnOOM(backoff=Backoff(exponent=2, max=timedelta(minutes=2)), factor=1.0, limit="0") + """ + + factor: float = 1.2 + limit: str = "0" + backoff: Optional[Backoff] = None + + def __post_init__(self): + if self.factor <= 1.0: + raise ValueError("Factor must be a non-negative float.") + + +@dataclass +class Retry(DataClassJSONMixin): + """ + This class is used to specify the retry strategy for a task. + + .. code-block:: python + + Retry(attempts=3, on_oom=OnOOM(backoff=Backoff(exponent=2, max=timedelta(minutes=2)), factor=1.0, limit="0")) + """ + + attempts: int = 0 + on_oom: Optional[OnOOM] = None + + def __post_init__(self): + if self.attempts < 0: + raise ValueError("Attempts must be a non-negative integer.") diff --git a/flytekit/core/task.py b/flytekit/core/task.py index f39a133877..91511e09ab 100644 --- a/flytekit/core/task.py +++ b/flytekit/core/task.py @@ -18,6 +18,7 @@ from flytekit.core.python_function_task import AsyncPythonFunctionTask, EagerAsyncPythonFunctionTask, PythonFunctionTask from flytekit.core.reference_entity import ReferenceEntity, TaskReference from flytekit.core.resources import Resources +from flytekit.core.retry import Retry from flytekit.core.utils import str2bool from flytekit.deck import DeckField from flytekit.extras.accelerators import BaseAccelerator @@ -99,7 +100,7 @@ def task( _task_function: None = ..., task_config: Optional[T] = ..., cache: Union[bool, Cache] = ..., - retries: int = ..., + retries: Union[int, Retry] = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., timeout: Union[datetime.timedelta, int] = ..., @@ -137,7 +138,7 @@ def task( _task_function: Callable[P, FuncOut], task_config: Optional[T] = ..., cache: Union[bool, Cache] = ..., - retries: int = ..., + retries: Union[int, Retry] = ..., interruptible: Optional[bool] = ..., deprecated: str = ..., timeout: Union[datetime.timedelta, int] = ..., @@ -174,7 +175,7 @@ def task( _task_function: Optional[Callable[P, FuncOut]] = None, task_config: Optional[T] = None, cache: Union[bool, Cache] = False, - retries: int = 0, + retries: Union[int, Retry] = 0, interruptible: Optional[bool] = None, deprecated: str = "", timeout: Union[datetime.timedelta, int] = 0, diff --git a/flytekit/models/literals.py b/flytekit/models/literals.py index a4b5a1d359..7e2031de0b 100644 --- a/flytekit/models/literals.py +++ b/flytekit/models/literals.py @@ -15,13 +15,104 @@ from flytekit.models.types import SchemaType as _SchemaType +class ExponentialBackoff(_common.FlyteIdlEntity): + def __init__(self, exponent: int, max: Optional[_timedelta]): + self._exponent = exponent + self._max = max + + @property + def exponent(self): + """ + :rtype: int + """ + return self._exponent + + @property + def max(self): + """ + :rtype: datetime.timedelta + """ + return self._max + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.literals_pb2.ExponentialBackoff + """ + backoff = _literals_pb2.ExponentialBackoff(max_exponent=self.exponent) + if self.max is not None: + backoff.max.FromTimedelta(self.max) + return backoff + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.literals_pb2.ExponentialBackoff pb2_object: + :rtype: ExponentialBackoff + """ + return cls( + exponent=pb2_object.max_exponent, + max=pb2_object.max.ToTimedelta() if pb2_object.HasField("max") else None, + ) + + +class RetryOnOOM(_common.FlyteIdlEntity): + def __init__(self, factor: float, limit: str, backoff: Optional[ExponentialBackoff] = None): + self._factor = factor + self._limit = limit + self._backoff = backoff + + @property + def factor(self): + """ + :rtype: float + """ + return self._factor + + @property + def limit(self): + """ + :rtype: Text + """ + return self._limit + + @property + def backoff(self): + """ + :rtype: ExponentialBackoff + """ + return self._backoff + + def to_flyte_idl(self): + """ + :rtype: flyteidl.core.literals_pb2.RetryOnOOM + """ + return _literals_pb2.RetryOnOOM( + factor=self.factor, + limit=self.limit, + backoff=self.backoff.to_flyte_idl() if self.backoff is not None else None, + ) + + @classmethod + def from_flyte_idl(cls, pb2_object): + """ + :param flyteidl.core.literals_pb2.RetryOnOOM pb2_object: + :rtype: RetryOnOOM + """ + return cls( + factor=pb2_object.factor, + limit=pb2_object.limit, + backoff=ExponentialBackoff.from_flyte_idl(pb2_object.backoff) if pb2_object.HasField("backoff") else None, + ) + + class RetryStrategy(_common.FlyteIdlEntity): - def __init__(self, retries: int): + def __init__(self, retries: int, on_oom: Optional[RetryOnOOM] = None): """ :param int retries: Number of retries to attempt on recoverable failures. If retries is 0, then only one attempt will be made. """ self._retries = retries + self._on_oom = on_oom @property def retries(self): @@ -31,11 +122,21 @@ def retries(self): """ return self._retries + @property + def on_oom(self): + """ + :rtype: RetryOnOOM + """ + return self._on_oom + def to_flyte_idl(self): """ :rtype: flyteidl.core.literals_pb2.RetryStrategy """ - return _literals_pb2.RetryStrategy(retries=self.retries) + return _literals_pb2.RetryStrategy( + retries=self.retries, + on_oom=self.on_oom.to_flyte_idl() if self.on_oom is not None else None, + ) @classmethod def from_flyte_idl(cls, pb2_object): @@ -43,7 +144,10 @@ def from_flyte_idl(cls, pb2_object): :param flyteidl.core.literals_pb2.RetryStrategy pb2_object: :rtype: RetryStrategy """ - return cls(retries=pb2_object.retries) + return cls( + retries=pb2_object.retries, + on_oom=RetryOnOOM.from_flyte_idl(pb2_object.on_oom) if pb2_object.HasField("on_oom") else None, + ) class Primitive(_common.FlyteIdlEntity):