From c6029a4c3092653d230144a31d6263c949c3fb22 Mon Sep 17 00:00:00 2001 From: DeepMind Team Date: Fri, 28 Jan 2022 02:49:59 -0800 Subject: [PATCH] change _create_experiment_unit signature to return Awaitable to allow for async calls PiperOrigin-RevId: 424823349 Change-Id: I123b056b7be1b69b1424cd23c97f06b2e8b5c9ff GitOrigin-RevId: 0c82144864fb7d89f01d1d5dcfcb5cc0a4da01ae --- xmanager/xm/core.py | 22 +++++++++++++++++++--- xmanager/xm/testing.py | 9 ++++++--- xmanager/xm_local/experiment.py | 18 ++++++++++++------ 3 files changed, 37 insertions(+), 12 deletions(-) diff --git a/xmanager/xm/core.py b/xmanager/xm/core.py index 6a6e3a9..9a8fe2e 100644 --- a/xmanager/xm/core.py +++ b/xmanager/xm/core.py @@ -610,17 +610,33 @@ def add(self, job, args=None, role=WorkUnitRole()): An awaitable that would be fulfilled when the job is launched. """ # pyformat: enable - experiment_unit = self._create_experiment_unit(args, role) + experiment_unit_future = self._create_experiment_unit(args, role) async def launch(): + experiment_unit = await experiment_unit_future await experiment_unit.add(job, args) return experiment_unit return asyncio.wrap_future(self._create_task(launch())) @abc.abstractmethod - def _create_experiment_unit(self, args: Optional[Mapping[str, Any]], - role: ExperimentUnitRole) -> ExperimentUnit: + def _create_experiment_unit( + self, args: Optional[Mapping[str, Any]], + role: ExperimentUnitRole) -> Awaitable[ExperimentUnit]: + """Creates a new experiment unit. + + Synchronously starts the experiment unit creation, ensuring that IDs would + be assigned in invocation order. The operation itself may run asynchronously + in background. + + Args: + args: Executable unit arguments, to be show as a part of hyper-parameter + sweep. + role: Executable unit role: whether to create a work or auxiliary unit. + + Returns: + An awaitable to the creation result. + """ raise NotImplementedError def _create_task(self, task: Awaitable[Any]) -> futures.Future: diff --git a/xmanager/xm/testing.py b/xmanager/xm/testing.py index 9f62548..08639f2 100644 --- a/xmanager/xm/testing.py +++ b/xmanager/xm/testing.py @@ -13,6 +13,7 @@ # limitations under the License. """Utilities for testing core objects.""" +import asyncio from concurrent import futures from typing import Any, Awaitable, Callable, List, Mapping, Optional @@ -68,14 +69,16 @@ def __init__(self) -> None: self.launched_jobs_args = [] self._work_units = [] - def _create_experiment_unit(self, args, - role=core.WorkUnitRole()) -> TestWorkUnit: + def _create_experiment_unit( + self, args, role=core.WorkUnitRole()) -> Awaitable[TestWorkUnit]: """Creates a new WorkUnit instance for the experiment.""" + future = asyncio.Future() work_unit = TestWorkUnit(self, self._work_unit_id_predictor, self._create_task, self.launched_jobs, self.launched_jobs_args, args) self._work_units.append(work_unit) - return work_unit + future.set_result(work_unit) + return future @property def work_unit_count(self) -> int: diff --git a/xmanager/xm_local/experiment.py b/xmanager/xm_local/experiment.py index 141d5c8..a6c17ce 100644 --- a/xmanager/xm_local/experiment.py +++ b/xmanager/xm_local/experiment.py @@ -249,11 +249,12 @@ def __init__(self, experiment_title: str) -> None: self._experiment_units = [] self._work_unit_count = 0 - def _create_experiment_unit(self, args: Optional[Mapping[str, Any]], - role: xm.ExperimentUnitRole) -> xm.ExperimentUnit: + def _create_experiment_unit( + self, args: Optional[Mapping[str, Any]], + role: xm.ExperimentUnitRole) -> Awaitable[xm.ExperimentUnit]: """Creates a new WorkUnit instance for the experiment.""" - def create_work_unit(role: xm.WorkUnitRole): + def create_work_unit(role: xm.WorkUnitRole) -> Awaitable[xm.ExperimentUnit]: work_unit = LocalWorkUnit( self, self._experiment_title, @@ -268,10 +269,13 @@ def create_work_unit(role: xm.WorkUnitRole): self.experiment_id, work_unit.work_unit_id, ) - return work_unit + future = asyncio.Future() + future.set_result(work_unit) + return future # TODO: Support `role.termination_delay_secs`. - def create_auxiliary_unit(role: xm.AuxiliaryUnitRole): + def create_auxiliary_unit( + role: xm.AuxiliaryUnitRole) -> Awaitable[xm.ExperimentUnit]: auxiliary_unit = LocalAuxiliaryUnit( self, self._experiment_title, @@ -280,7 +284,9 @@ def create_auxiliary_unit(role: xm.AuxiliaryUnitRole): role, ) self._experiment_units.append(auxiliary_unit) - return auxiliary_unit + future = asyncio.Future() + future.set_result(auxiliary_unit) + return future return pattern_matching.match(create_work_unit, create_auxiliary_unit)(role)