diff --git a/smartsim/_core/control/interval.py b/smartsim/_core/control/interval.py index cdcfd47cb..c4e31c5d2 100644 --- a/smartsim/_core/control/interval.py +++ b/smartsim/_core/control/interval.py @@ -102,7 +102,7 @@ def new_interval(self) -> SynchronousTimeInterval: """ return type(self)(self.delta) - def wait(self) -> None: + def block(self) -> None: """Block the thread until the timeout completes :raises RuntimeError: The thread would be blocked forever diff --git a/smartsim/experiment.py b/smartsim/experiment.py index 3fa5d12b3..fcdade5c9 100644 --- a/smartsim/experiment.py +++ b/smartsim/experiment.py @@ -290,18 +290,15 @@ def wait( if not ids: raise ValueError("No job ids to wait on provided") self._poll_for_statuses( - ids, - TERMINAL_STATUSES, - timeout=_interval.SynchronousTimeInterval(timeout), - verbose=verbose, + ids, TERMINAL_STATUSES, timeout=timeout, verbose=verbose ) def _poll_for_statuses( self, ids: t.Sequence[LaunchedJobID], statuses: t.Collection[JobStatus], - timeout: _interval.SynchronousTimeInterval | None = None, - interval: _interval.SynchronousTimeInterval | None = None, + timeout: float | None = None, + interval: float = 5.0, verbose: bool = True, ) -> dict[LaunchedJobID, JobStatus | InvalidJobStatus]: """Poll the experiment's launchers for the statuses of the launched @@ -322,8 +319,8 @@ def _poll_for_statuses( """ terminal = frozenset(itertools.chain(statuses, InvalidJobStatus)) log = logger.info if verbose else lambda *_, **__: None - method_timeout = timeout or _interval.SynchronousTimeInterval(None) - iter_timeout = interval or _interval.SynchronousTimeInterval(5.0) + method_timeout = _interval.SynchronousTimeInterval(timeout) + iter_timeout = _interval.SynchronousTimeInterval(interval) final: dict[LaunchedJobID, JobStatus | InvalidJobStatus] = {} def is_finished( @@ -349,7 +346,7 @@ def is_finished( iter_timeout if iter_timeout.remaining < method_timeout.remaining else method_timeout - ).wait() + ).block() if ids: raise TimeoutError( f"Job ID(s) {', '.join(map(str, ids))} failed to reach " diff --git a/tests/test_experiment.py b/tests/test_experiment.py index 795adf542..29e2626cc 100644 --- a/tests/test_experiment.py +++ b/tests/test_experiment.py @@ -471,11 +471,7 @@ def __call__(self, *args, **kwargs): "smartsim.experiment.logger.info", lambda s: mock_log.write(f"{s}\n") ) final_statuses = exp._poll_for_statuses( - [id_], - different_statuses, - timeout=SynchronousTimeInterval(10), - interval=SynchronousTimeInterval(0), - verbose=verbose, + [id_], different_statuses, timeout=10, interval=0, verbose=verbose ) assert final_statuses == {id_: new_status} @@ -501,8 +497,8 @@ def test_poll_status_raises_when_called_with_infinite_iter_wait( exp._poll_for_statuses( [id_], [], - timeout=SynchronousTimeInterval(10), - interval=SynchronousTimeInterval(None), + timeout=10, + interval=float("inf"), ) @@ -523,6 +519,6 @@ def test_poll_for_status_raises_if_ids_not_found_within_timeout( exp._poll_for_statuses( [id_], different_statuses, - timeout=SynchronousTimeInterval(1), - interval=SynchronousTimeInterval(0), + timeout=1, + interval=0, ) diff --git a/tests/test_intervals.py b/tests/test_intervals.py index a98bfa22e..1b865867f 100644 --- a/tests/test_intervals.py +++ b/tests/test_intervals.py @@ -65,7 +65,7 @@ def test_sync_timeout_can_block_thread(): """Test that the sync timeout can block the calling thread""" timeout = 1 now = time.perf_counter() - SynchronousTimeInterval(timeout).wait() + SynchronousTimeInterval(timeout).block() later = time.perf_counter() assert abs(later - now - timeout) <= 0.25 @@ -78,17 +78,10 @@ def test_sync_timeout_infinte(): assert t.remaining == float("inf") assert t.infinite with pytest.raises(RuntimeError, match="block thread forever"): - t.wait() + t.block() def test_sync_timeout_raises_on_invalid_value(monkeypatch): """Cannot make a sync time interval with a negative time delta""" with pytest.raises(ValueError): - t = SynchronousTimeInterval(-1) - now = time.perf_counter() - monkeypatch.setattr( - time, "perf_counter", lambda *_, **__: now + 365 * 24 * 60 * 60 - ) - assert t._delta == float("inf") - assert t.expired == False - assert t.remaining == float("inf") + SynchronousTimeInterval(-1)