From 77981af04bee42d847ec2a1996a835dc602b339d Mon Sep 17 00:00:00 2001 From: epwalsh Date: Sun, 6 Oct 2024 16:34:17 -0700 Subject: [PATCH] Fix --- CHANGELOG.md | 2 +- beaker/data_model/experiment_spec.py | 52 ++++++++++++++-------------- 2 files changed, 27 insertions(+), 27 deletions(-) diff --git a/CHANGELOG.md b/CHANGELOG.md index 397ed28..9089414 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -11,7 +11,7 @@ use patch releases for compatibility fixes instead. ### Added -- Added `retry` field to `TaskSpec`. +- Added `retry` field to `ExperimentSpec`. ## [v1.31.3](https://github.com/allenai/beaker-py/releases/tag/v1.31.3) - 2024-08-30 diff --git a/beaker/data_model/experiment_spec.py b/beaker/data_model/experiment_spec.py index 74241e2..0a2d216 100644 --- a/beaker/data_model/experiment_spec.py +++ b/beaker/data_model/experiment_spec.py @@ -16,9 +16,9 @@ "TaskResources", "Priority", "TaskContext", - "TaskRetrySpec", "TaskSpec", "SpecVersion", + "RetrySpec", "ExperimentSpec", "Constraints", ] @@ -320,18 +320,6 @@ def __setitem__(self, key: str, val: List[Any]) -> None: setattr(self, key, val) -class TaskRetrySpec(BaseModel, frozen=False): - """ - Defines the retry behavior of a task. - """ - - allowed_task_retries: int - """ - A positive integer specifying the maximum number of task retries allowed for the experiment, - with a max limit of 10. - """ - - class TaskSpec(BaseModel, frozen=False): """ A :class:`TaskSpec` defines a :class:`~beaker.data_model.experiment.Task` within an :class:`ExperimentSpec`. @@ -358,11 +346,6 @@ class TaskSpec(BaseModel, frozen=False): Context describes how and where this task should run. """ - retry: Optional[TaskRetrySpec] = None - """ - Defines the retry behavior of a task. - """ - constraints: Optional[Constraints] = None """ Each task can have many constraints. And each constraint can have many values. @@ -577,14 +560,6 @@ def with_context(self, **kwargs) -> "TaskSpec": """ return self.model_copy(deep=True, update={"context": TaskContext(**kwargs)}) - def with_retries(self, allowed_task_retries: int) -> "TaskSpec": - """ - Return a new :class:`TaskSpec` with the given number of retries. - """ - return self.model_copy( - deep=True, update={"retries": TaskRetrySpec(allowed_task_retries=allowed_task_retries)} - ) - def with_name(self, name: str) -> "TaskSpec": """ Return a new :class:`TaskSpec` with the given :data:`name`. @@ -731,6 +706,18 @@ class SpecVersion(StrEnum): v2_alpha = "v2-alpha" +class RetrySpec(BaseModel, frozen=False): + """ + Defines the retry behavior of an experiment. + """ + + allowed_task_retries: int + """ + A positive integer specifying the maximum number of task retries allowed for the experiment, + with a max limit of 10. + """ + + class ExperimentSpec(BaseModel, frozen=False): """ Experiments are the main unit of execution in Beaker. @@ -775,6 +762,11 @@ class ExperimentSpec(BaseModel, frozen=False): Long-form explanation for an experiment. """ + retry: Optional[RetrySpec] = None + """ + Defines the retry behavior of an experiment. + """ + @field_validator("tasks") def _validate_tasks(cls, v: List[TaskSpec]) -> List[TaskSpec]: task_names = set() @@ -908,6 +900,14 @@ def with_description(self, description: str) -> "ExperimentSpec": """ return self.model_copy(deep=True, update={"description": description}) + def with_retries(self, allowed_task_retries: int) -> "TaskSpec": + """ + Return a new :class:`ExperimentSpec` with the given number of retries. + """ + return self.model_copy( + deep=True, update={"retry": RetrySpec(allowed_task_retries=allowed_task_retries)} + ) + def validate(self): for task in self.tasks: if (task.image.beaker is None) == (task.image.docker is None):