diff --git a/smartsim/_core/utils/helpers.py b/smartsim/_core/utils/helpers.py index 1133358a6..fef1e792f 100644 --- a/smartsim/_core/utils/helpers.py +++ b/smartsim/_core/utils/helpers.py @@ -43,11 +43,15 @@ from datetime import datetime from shutil import which +from typing_extensions import TypeAlias + if t.TYPE_CHECKING: from types import FrameType from typing_extensions import TypeVarTuple, Unpack + from smartsim.launchable.job import Job + _Ts = TypeVarTuple("_Ts") @@ -55,6 +59,23 @@ _HashableT = t.TypeVar("_HashableT", bound=t.Hashable) _TSignalHandlerFn = t.Callable[[int, t.Optional["FrameType"]], object] +_NestedJobSequenceType: TypeAlias = "t.Sequence[Job | _NestedJobSequenceType]" + + +def unpack(value: _NestedJobSequenceType) -> t.Generator[Job, None, None]: + """Unpack any iterable input in order to obtain a + single sequence of values + + :param value: Sequence containing elements of type Job or other + sequences that are also of type _NestedJobSequenceType + :return: flattened list of Jobs""" + + for item in value: + if isinstance(item, t.Iterable): + yield from unpack(item) + else: + yield item + def check_name(name: str) -> None: """ diff --git a/smartsim/experiment.py b/smartsim/experiment.py index fef046475..cda91c4d1 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -151,7 +151,7 @@ def __init__(self, name: str, exp_path: str | None = None): experiment """ - def start(self, *jobs: Job) -> tuple[LaunchedJobID, ...]: + def start(self, *jobs: Job | t.Sequence[Job]) -> tuple[LaunchedJobID, ...]: """Execute a collection of `Job` instances. :param jobs: A collection of other job instances to start @@ -159,11 +159,10 @@ def start(self, *jobs: Job) -> tuple[LaunchedJobID, ...]: jobs that can be used to query or alter the status of that particular execution of the job. """ - # Create the run id + jobs_ = list(_helpers.unpack(jobs)) run_id = datetime.datetime.now().replace(microsecond=0).isoformat() - # Generate the root path root = pathlib.Path(self.exp_path, run_id) - return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs) + return self._dispatch(Generator(root), dispatch.DEFAULT_DISPATCHER, *jobs_) def _dispatch( self, diff --git a/tests/_legacy/test_helpers.py b/tests/_legacy/test_helpers.py index 523ed7191..7b453905c 100644 --- a/tests/_legacy/test_helpers.py +++ b/tests/_legacy/test_helpers.py @@ -30,12 +30,32 @@ import pytest from smartsim._core.utils import helpers -from smartsim._core.utils.helpers import cat_arg_and_value +from smartsim._core.utils.helpers import cat_arg_and_value, unpack +from smartsim.entity.application import Application +from smartsim.launchable.job import Job +from smartsim.settings.launch_settings import LaunchSettings # The tests in this file belong to the group_a group pytestmark = pytest.mark.group_a +def test_unpack_iterates_over_nested_jobs_in_expected_order(wlmutils): + launch_settings = LaunchSettings(wlmutils.get_test_launcher()) + app = Application("app_name", exe="python") + job_1 = Job(app, launch_settings) + job_2 = Job(app, launch_settings) + job_3 = Job(app, launch_settings) + job_4 = Job(app, launch_settings) + job_5 = Job(app, launch_settings) + + assert ( + [job_1, job_2, job_3, job_4, job_5] + == list(unpack([job_1, [job_2, job_3], job_4, [job_5]])) + == list(unpack([job_1, job_2, [job_3, job_4], job_5])) + == list(unpack([job_1, [job_2, [job_3, job_4], job_5]])) + ) + + def test_double_dash_concat(): result = cat_arg_and_value("--foo", "FOO") assert result == "--foo=FOO" diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 73657801d..d88abeb20 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -34,23 +34,32 @@ import time import typing as t import uuid +from os import path as osp import pytest from smartsim._core import dispatch -from smartsim._core.control.interval import SynchronousTimeInterval from smartsim._core.control.launch_history import LaunchHistory from smartsim._core.utils.launcher import LauncherProtocol, create_job_id +from smartsim.builders.ensemble import Ensemble from smartsim.entity import entity +from smartsim.entity.application import Application from smartsim.error import errors from smartsim.experiment import Experiment from smartsim.launchable import job from smartsim.settings import launch_settings from smartsim.settings.arguments import launch_arguments from smartsim.status import InvalidJobStatus, JobStatus +from smartsim.types import LaunchedJobID pytestmark = pytest.mark.group_a +_ID_GENERATOR = (str(i) for i in itertools.count()) + + +def random_id(): + return next(_ID_GENERATOR) + @pytest.fixture def experiment(monkeypatch, test_dir, dispatcher): @@ -611,3 +620,111 @@ def test_experiment_stop_does_not_raise_on_unknown_job_id( assert stat == InvalidJobStatus.NEVER_STARTED after_cancel = exp.get_status(*all_known_ids) assert before_cancel == after_cancel + + +@pytest.mark.parametrize( + "job_list", + ( + pytest.param( + [ + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + ) + ], + id="(job1, (job2, job_3))", + ), + pytest.param( + [ + ( + Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + job.Job( + Application( + "test_name_2", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ), + ) + ], + id="((job1, job2), (job3, job4))", + ), + pytest.param( + [ + ( + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ) + ], + id="(job,)", + ), + pytest.param( + [ + [ + job.Job( + Application( + "test_name", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ( + Ensemble("ensemble-name", "echo", replicas=2).build_jobs( + launch_settings.LaunchSettings("local") + ), + job.Job( + Application( + "test_name_2", + exe="echo", + exe_args=["spam", "eggs"], + ), + launch_settings.LaunchSettings("local"), + ), + ), + ] + ], + id="[job_1, ((job_2, job_3), job_4)]", + ), + ), +) +def test_start_unpack( + test_dir: str, wlmutils, monkeypatch: pytest.MonkeyPatch, job_list: job.Job +): + """Test unpacking a sequences of jobs""" + + monkeypatch.setattr( + "smartsim._core.dispatch._LauncherAdapter.start", + lambda launch, exe, job_execution_path, env, out, err: random_id(), + ) + + exp = Experiment(name="exp_name", exp_path=test_dir) + exp.start(*job_list)