Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] [Feat-v2] Enable memory increase on OOM failure #3164

Draft
wants to merge 1 commit into
base: master
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions flytekit/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -243,6 +243,7 @@
from flytekit.core.reference import get_reference_entity
from flytekit.core.reference_entity import LaunchPlanReference, TaskReference, WorkflowReference
from flytekit.core.resources import Resources
from flytekit.core.retry import Backoff, OnOOM, Retry
from flytekit.core.schedule import CronSchedule, FixedRate
from flytekit.core.task import Secret, eager, reference_task, task
from flytekit.core.type_engine import BatchSize
Expand Down
28 changes: 25 additions & 3 deletions flytekit/core/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
flyte_entity_call_handler,
translate_inputs_to_literals,
)
from flytekit.core.retry import Retry
from flytekit.core.tracker import TrackedInstance
from flytekit.core.type_engine import TypeEngine, TypeTransformerFailedError
from flytekit.core.utils import timeit
Expand Down Expand Up @@ -122,7 +123,7 @@ class TaskMetadata(object):
cache_ignore_input_vars (Tuple[str, ...]): Input variables that should not be included when calculating hash for cache.
interruptible (Optional[bool]): Indicates that this task can be interrupted and/or scheduled on nodes with lower QoS guarantees that can include pre-emption.
deprecated (str): Can be used to provide a warning message for a deprecated task. An absence or empty string indicates that the task is active and not deprecated.
retries (int): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times.
retries (Union[int, Retry]): for retries=n; n > 0, on failures of this task, the task will be retried at-least n number of times.
timeout (Optional[Union[datetime.timedelta, int]]): The maximum duration for which one execution of this task should run. The execution will be terminated if the runtime exceeds this timeout.
pod_template_name (Optional[str]): The name of an existing PodTemplate resource in the cluster which will be used for this task.
generates_deck (bool): Indicates whether the task will generate a Deck URI.
Expand All @@ -135,7 +136,7 @@ class TaskMetadata(object):
cache_ignore_input_vars: Tuple[str, ...] = ()
interruptible: Optional[bool] = None
deprecated: str = ""
retries: int = 0
retries: Union[int, Retry] = 0
timeout: Optional[Union[datetime.timedelta, int]] = None
pod_template_name: Optional[str] = None
generates_deck: bool = False
Expand All @@ -158,7 +159,28 @@ def __post_init__(self):

@property
def retry_strategy(self) -> _literal_models.RetryStrategy:
return _literal_models.RetryStrategy(self.retries)
if isinstance(self.retries, int):
return _literal_models.RetryStrategy(retries=self.retries)
elif isinstance(self.retries, Retry):
if self.retries.on_oom is None:
on_oom = None
else:
if self.retries.on_oom.backoff is None:
backoff = None
else:
backoff = _literal_models.ExponentialBackoff(
exponent=self.retries.on_oom.backoff.exponent,
max=self.retries.on_oom.backoff.max,
)
on_oom = _literal_models.RetryOnOOM(
factor=self.retries.on_oom.factor,
limit=self.retries.on_oom.limit,
backoff=backoff,
)
return _literal_models.RetryStrategy(
self.retries.attempts,
on_oom=on_oom,
)

def to_taskmetadata_model(self) -> _task_model.TaskMetadata:
"""
Expand Down
72 changes: 72 additions & 0 deletions flytekit/core/retry.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
import datetime
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, List, Optional, Union
from typing import Literal as L

from flyteidl.core import tasks_pb2

if TYPE_CHECKING:
from kubernetes.client import V1PodSpec
from mashumaro.mixins.json import DataClassJSONMixin

from flytekit.core.constants import SHARED_MEMORY_MOUNT_NAME, SHARED_MEMORY_MOUNT_PATH
from flytekit.extras.accelerators import BaseAccelerator
from flytekit.models import task as task_models


@dataclass
class Backoff(DataClassJSONMixin):
"""
This class is used to specify the backoff strategy for retries.

.. code-block:: python

Backoff(exponent=2, max=timedelta(minutes=2)) # This is a backoff strategy with an exponent of 2 and a max of 2 minutes

"""

exponent: int = 0
max: Optional[datetime.timedelta] = None

def __post_init__(self):
if self.exponent < 0:
raise ValueError("Exponent must be a non-negative integer.")
if self.max is not None and self.max.total_seconds() < 0:
raise ValueError("Max must be a non-negative timedelta.")


@dataclass
class OnOOM(DataClassJSONMixin):
"""
This class is used to specify the behavior when a task runs out of memory.

.. code-block:: python

OnOOM(backoff=Backoff(exponent=2, max=timedelta(minutes=2)), factor=1.0, limit="0")
"""

factor: float = 1.2
limit: str = "0"
backoff: Optional[Backoff] = None

def __post_init__(self):
if self.factor <= 1.0:
raise ValueError("Factor must be a non-negative float.")


@dataclass
class Retry(DataClassJSONMixin):
"""
This class is used to specify the retry strategy for a task.

.. code-block:: python

Retry(attempts=3, on_oom=OnOOM(backoff=Backoff(exponent=2, max=timedelta(minutes=2)), factor=1.0, limit="0"))
"""

attempts: int = 0
on_oom: Optional[OnOOM] = None

def __post_init__(self):
if self.attempts < 0:
raise ValueError("Attempts must be a non-negative integer.")
7 changes: 4 additions & 3 deletions flytekit/core/task.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from flytekit.core.python_function_task import AsyncPythonFunctionTask, EagerAsyncPythonFunctionTask, PythonFunctionTask
from flytekit.core.reference_entity import ReferenceEntity, TaskReference
from flytekit.core.resources import Resources
from flytekit.core.retry import Retry
from flytekit.core.utils import str2bool
from flytekit.deck import DeckField
from flytekit.extras.accelerators import BaseAccelerator
Expand Down Expand Up @@ -99,7 +100,7 @@ def task(
_task_function: None = ...,
task_config: Optional[T] = ...,
cache: Union[bool, Cache] = ...,
retries: int = ...,
retries: Union[int, Retry] = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[datetime.timedelta, int] = ...,
Expand Down Expand Up @@ -137,7 +138,7 @@ def task(
_task_function: Callable[P, FuncOut],
task_config: Optional[T] = ...,
cache: Union[bool, Cache] = ...,
retries: int = ...,
retries: Union[int, Retry] = ...,
interruptible: Optional[bool] = ...,
deprecated: str = ...,
timeout: Union[datetime.timedelta, int] = ...,
Expand Down Expand Up @@ -174,7 +175,7 @@ def task(
_task_function: Optional[Callable[P, FuncOut]] = None,
task_config: Optional[T] = None,
cache: Union[bool, Cache] = False,
retries: int = 0,
retries: Union[int, Retry] = 0,
interruptible: Optional[bool] = None,
deprecated: str = "",
timeout: Union[datetime.timedelta, int] = 0,
Expand Down
110 changes: 107 additions & 3 deletions flytekit/models/literals.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,13 +15,104 @@
from flytekit.models.types import SchemaType as _SchemaType


class ExponentialBackoff(_common.FlyteIdlEntity):
def __init__(self, exponent: int, max: Optional[_timedelta]):
self._exponent = exponent
self._max = max

@property
def exponent(self):
"""
:rtype: int
"""
return self._exponent

@property
def max(self):
"""
:rtype: datetime.timedelta
"""
return self._max

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.literals_pb2.ExponentialBackoff
"""
backoff = _literals_pb2.ExponentialBackoff(max_exponent=self.exponent)
if self.max is not None:
backoff.max.FromTimedelta(self.max)
return backoff

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.core.literals_pb2.ExponentialBackoff pb2_object:
:rtype: ExponentialBackoff
"""
return cls(
exponent=pb2_object.max_exponent,
max=pb2_object.max.ToTimedelta() if pb2_object.HasField("max") else None,
)


class RetryOnOOM(_common.FlyteIdlEntity):
def __init__(self, factor: float, limit: str, backoff: Optional[ExponentialBackoff] = None):
self._factor = factor
self._limit = limit
self._backoff = backoff

@property
def factor(self):
"""
:rtype: float
"""
return self._factor

@property
def limit(self):
"""
:rtype: Text
"""
return self._limit

@property
def backoff(self):
"""
:rtype: ExponentialBackoff
"""
return self._backoff

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.literals_pb2.RetryOnOOM
"""
return _literals_pb2.RetryOnOOM(
factor=self.factor,
limit=self.limit,
backoff=self.backoff.to_flyte_idl() if self.backoff is not None else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.core.literals_pb2.RetryOnOOM pb2_object:
:rtype: RetryOnOOM
"""
return cls(
factor=pb2_object.factor,
limit=pb2_object.limit,
backoff=ExponentialBackoff.from_flyte_idl(pb2_object.backoff) if pb2_object.HasField("backoff") else None,
)


class RetryStrategy(_common.FlyteIdlEntity):
def __init__(self, retries: int):
def __init__(self, retries: int, on_oom: Optional[RetryOnOOM] = None):
"""
:param int retries: Number of retries to attempt on recoverable failures. If retries is 0, then
only one attempt will be made.
"""
self._retries = retries
self._on_oom = on_oom

@property
def retries(self):
Expand All @@ -31,19 +122,32 @@ def retries(self):
"""
return self._retries

@property
def on_oom(self):
"""
:rtype: RetryOnOOM
"""
return self._on_oom

def to_flyte_idl(self):
"""
:rtype: flyteidl.core.literals_pb2.RetryStrategy
"""
return _literals_pb2.RetryStrategy(retries=self.retries)
return _literals_pb2.RetryStrategy(
retries=self.retries,
on_oom=self.on_oom.to_flyte_idl() if self.on_oom is not None else None,
)

@classmethod
def from_flyte_idl(cls, pb2_object):
"""
:param flyteidl.core.literals_pb2.RetryStrategy pb2_object:
:rtype: RetryStrategy
"""
return cls(retries=pb2_object.retries)
return cls(
retries=pb2_object.retries,
on_oom=RetryOnOOM.from_flyte_idl(pb2_object.on_oom) if pb2_object.HasField("on_oom") else None,
)


class Primitive(_common.FlyteIdlEntity):
Expand Down
Loading