Skip to content

Commit

Permalink
change _create_experiment_unit signature to return Awaitable to allow…
Browse files Browse the repository at this point in the history
… for async calls

PiperOrigin-RevId: 424823349
Change-Id: I123b056b7be1b69b1424cd23c97f06b2e8b5c9ff
GitOrigin-RevId: 0c82144864fb7d89f01d1d5dcfcb5cc0a4da01ae
  • Loading branch information
DeepMind Team authored and andrewluchen committed Jan 28, 2022
1 parent 4613b38 commit c6029a4
Show file tree
Hide file tree
Showing 3 changed files with 37 additions and 12 deletions.
22 changes: 19 additions & 3 deletions xmanager/xm/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -610,17 +610,33 @@ def add(self, job, args=None, role=WorkUnitRole()):
An awaitable that would be fulfilled when the job is launched.
"""
# pyformat: enable
experiment_unit = self._create_experiment_unit(args, role)
experiment_unit_future = self._create_experiment_unit(args, role)

async def launch():
experiment_unit = await experiment_unit_future
await experiment_unit.add(job, args)
return experiment_unit

return asyncio.wrap_future(self._create_task(launch()))

@abc.abstractmethod
def _create_experiment_unit(self, args: Optional[Mapping[str, Any]],
role: ExperimentUnitRole) -> ExperimentUnit:
def _create_experiment_unit(
self, args: Optional[Mapping[str, Any]],
role: ExperimentUnitRole) -> Awaitable[ExperimentUnit]:
"""Creates a new experiment unit.
Synchronously starts the experiment unit creation, ensuring that IDs would
be assigned in invocation order. The operation itself may run asynchronously
in background.
Args:
args: Executable unit arguments, to be show as a part of hyper-parameter
sweep.
role: Executable unit role: whether to create a work or auxiliary unit.
Returns:
An awaitable to the creation result.
"""
raise NotImplementedError

def _create_task(self, task: Awaitable[Any]) -> futures.Future:
Expand Down
9 changes: 6 additions & 3 deletions xmanager/xm/testing.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
# limitations under the License.
"""Utilities for testing core objects."""

import asyncio
from concurrent import futures
from typing import Any, Awaitable, Callable, List, Mapping, Optional

Expand Down Expand Up @@ -68,14 +69,16 @@ def __init__(self) -> None:
self.launched_jobs_args = []
self._work_units = []

def _create_experiment_unit(self, args,
role=core.WorkUnitRole()) -> TestWorkUnit:
def _create_experiment_unit(
self, args, role=core.WorkUnitRole()) -> Awaitable[TestWorkUnit]:
"""Creates a new WorkUnit instance for the experiment."""
future = asyncio.Future()
work_unit = TestWorkUnit(self, self._work_unit_id_predictor,
self._create_task, self.launched_jobs,
self.launched_jobs_args, args)
self._work_units.append(work_unit)
return work_unit
future.set_result(work_unit)
return future

@property
def work_unit_count(self) -> int:
Expand Down
18 changes: 12 additions & 6 deletions xmanager/xm_local/experiment.py
Original file line number Diff line number Diff line change
Expand Up @@ -249,11 +249,12 @@ def __init__(self, experiment_title: str) -> None:
self._experiment_units = []
self._work_unit_count = 0

def _create_experiment_unit(self, args: Optional[Mapping[str, Any]],
role: xm.ExperimentUnitRole) -> xm.ExperimentUnit:
def _create_experiment_unit(
self, args: Optional[Mapping[str, Any]],
role: xm.ExperimentUnitRole) -> Awaitable[xm.ExperimentUnit]:
"""Creates a new WorkUnit instance for the experiment."""

def create_work_unit(role: xm.WorkUnitRole):
def create_work_unit(role: xm.WorkUnitRole) -> Awaitable[xm.ExperimentUnit]:
work_unit = LocalWorkUnit(
self,
self._experiment_title,
Expand All @@ -268,10 +269,13 @@ def create_work_unit(role: xm.WorkUnitRole):
self.experiment_id,
work_unit.work_unit_id,
)
return work_unit
future = asyncio.Future()
future.set_result(work_unit)
return future

# TODO: Support `role.termination_delay_secs`.
def create_auxiliary_unit(role: xm.AuxiliaryUnitRole):
def create_auxiliary_unit(
role: xm.AuxiliaryUnitRole) -> Awaitable[xm.ExperimentUnit]:
auxiliary_unit = LocalAuxiliaryUnit(
self,
self._experiment_title,
Expand All @@ -280,7 +284,9 @@ def create_auxiliary_unit(role: xm.AuxiliaryUnitRole):
role,
)
self._experiment_units.append(auxiliary_unit)
return auxiliary_unit
future = asyncio.Future()
future.set_result(auxiliary_unit)
return future

return pattern_matching.match(create_work_unit, create_auxiliary_unit)(role)

Expand Down

0 comments on commit c6029a4

Please sign in to comment.