diff --git a/CHANGELOG.md b/CHANGELOG.md index d847d10..6b328cf 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -9,6 +9,10 @@ use patch releases for compatibility fixes instead. ## Unreleased +### Added + +- Added support for passing human-readable `synchronized_start_timeout` duration strings to `TaskSpec.new()`, like "10sec", "1m", etc. + ## [v1.26.9](https://github.com/allenai/beaker-py/releases/tag/v1.26.9) - 2024-05-02 ### Fixed diff --git a/beaker/data_model/experiment_spec.py b/beaker/data_model/experiment_spec.py index 2357326..7674956 100644 --- a/beaker/data_model/experiment_spec.py +++ b/beaker/data_model/experiment_spec.py @@ -4,6 +4,7 @@ from ..aliases import PathOrStr from ..exceptions import * +from ..util import parse_duration from .base import BaseModel, StrEnum, field_validator, model_validator __all__ = [ @@ -472,10 +473,11 @@ def new( else: constraints = Constraints(cluster=[cluster]) - # Allow setting the timeout using seconds, rather than nanoseconds. + # Allow setting the timeout as a string rather than nanoseconds, and assume a string + # without units means seconds. synchronized_start_timeout_str = kwargs.pop("synchronized_start_timeout", None) if synchronized_start_timeout_str is not None: - synchronized_start_timeout = int(synchronized_start_timeout_str * 1_000_000_000) + synchronized_start_timeout = parse_duration(synchronized_start_timeout_str) kwargs["synchronized_start_timeout"] = synchronized_start_timeout return TaskSpec( diff --git a/beaker/util.py b/beaker/util.py index af08c79..41578f5 100644 --- a/beaker/util.py +++ b/beaker/util.py @@ -117,6 +117,46 @@ def format_since(since: Union[datetime, timedelta, str]) -> str: return since +def parse_duration(dur: str) -> int: + """ + Parse a duration string into nanoseconds. + """ + dur_normalized = dur.replace(" ", "").lower() + match = re.match(r"^([0-9.e-]+)([a-z]*)$", dur_normalized) + if not match: + raise ValueError(f"invalid duration string '{dur}'") + + value_str, unit = match.group(1), match.group(2) + try: + value = float(value_str) + except ValueError: + raise ValueError(f"invalid duration string '{dur}'") + + if not unit: + # assume seconds + unit = "s" + + if unit in ("ns", "nanosecond", "nanoseconds"): + # nanoseconds + return int(value) + elif unit in ("µs", "microsecond", "microseconds"): + return int(value * 1_000) + elif unit in ("ms", "millisecond", "milliseconds"): + # milliseconds + return int(value * 1_000_000) + elif unit in ("s", "sec", "second", "seconds"): + # seconds + return int(value * 1_000_000_000) + elif unit in ("m", "min", "minute", "minutes"): + # minutes + return int(value * 60_000_000_000) + elif unit in ("h", "hr", "hour", "hours"): + # hours + return int(value * 3_600_000_000_000) + else: + raise ValueError(f"invalid duration string '{dur}'") + + TIMESTAMP_RE = re.compile(rb"^([0-9]{4}-[0-9]{2}-[0-9]{2}T[0-9]{2}:[0-9]{2}:[0-9]{2}.[0-9]+Z)(.*)$") diff --git a/tests/util_test.py b/tests/util_test.py index 5db83fe..f7b136b 100644 --- a/tests/util_test.py +++ b/tests/util_test.py @@ -49,3 +49,11 @@ def test_format_cursor(): cursor = 100 formatted = format_cursor(100) assert int.from_bytes(base64.urlsafe_b64decode(formatted), "little") == cursor + + +def test_parse_duration(): + assert parse_duration("1") == 1_000_000_000 + assert parse_duration("1s") == 1_000_000_000 + assert parse_duration("1sec") == 1_000_000_000 + assert parse_duration("1m") == 60 * 1_000_000_000 + assert parse_duration("1h") == 60 * 60 * 1_000_000_000