Skip to content

Commit

Permalink
Allow using torchx_ env vars to set scheduler params (#915) (#915)
Browse files Browse the repository at this point in the history
Summary:

Scheduler params currently can only be set through the programmatic API and not through this is not useful for cases like scheduling on mast rc cluster. This diff now lets you do that.

Reviewed By: andywag

Differential Revision: D57640022
  • Loading branch information
manav-a authored May 24, 2024
1 parent 05ddf23 commit d3393fc
Show file tree
Hide file tree
Showing 3 changed files with 48 additions and 33 deletions.
13 changes: 12 additions & 1 deletion torchx/runner/api.py
Original file line number Diff line number Diff line change
Expand Up @@ -98,14 +98,25 @@ def __init__(
"""
self._name: str = name
self._scheduler_factories = scheduler_factories
self._scheduler_params: Dict[str, object] = scheduler_params or {}
self._scheduler_params: Dict[str, Any] = {
**(self._get_scheduler_params_from_env()),
**(scheduler_params or {}),
}
# pyre-fixme[24]: SchedulerOpts is a generic, and we don't have access to the corresponding type
self._scheduler_instances: Dict[str, Scheduler] = {}
self._apps: Dict[AppHandle, AppDef] = {}

# component_name -> map of component_fn_param_name -> user-specified default val encoded as str
self._component_defaults: Dict[str, Dict[str, str]] = component_defaults or {}

def _get_scheduler_params_from_env(self) -> Dict[str, str]:
scheduler_params = {}
for key, value in os.environ.items():
lower_case_key = key.lower()
if lower_case_key.startswith("torchx_"):
scheduler_params[lower_case_key.strip("torchx_")] = value
return scheduler_params

def __enter__(self) -> "Runner":
return self

Expand Down
65 changes: 34 additions & 31 deletions torchx/runner/test/api_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,15 @@
import datetime
import os
from contextlib import contextmanager
from typing import Generator, List, Mapping, Optional
from typing import cast, Generator, List, Mapping, Optional
from unittest.mock import MagicMock, patch

from torchx.runner import get_runner, Runner
from torchx.schedulers import SchedulerFactory
from torchx.schedulers.api import DescribeAppResponse, ListAppResponse, Scheduler
from torchx.schedulers.local_scheduler import (
create_scheduler,
LocalDirectoryImageProvider,
LocalScheduler,
)
from torchx.specs import AppDryRunInfo, CfgVal
from torchx.specs.api import (
Expand Down Expand Up @@ -64,7 +65,7 @@ def setUp(self) -> None:
def get_runner(self) -> Generator[Runner, None, None]:
with Runner(
SESSION_NAME,
scheduler_factories={"local_dir": LocalScheduler},
scheduler_factories={"local_dir": cast(SchedulerFactory, create_scheduler)},
scheduler_params={
"image_provider_class": LocalDirectoryImageProvider,
},
Expand All @@ -79,14 +80,14 @@ def test_validate_no_roles(self, _) -> None:

def test_validate_no_resource(self, _) -> None:
with self.get_runner() as runner:
role = Role(
"no resource",
image="no_image",
entrypoint="echo",
args=["hello_world"],
)
app = AppDef("no resource", roles=[role])
with self.assertRaises(ValueError):
role = Role(
"no resource",
image="no_image",
entrypoint="echo",
args=["hello_world"],
)
app = AppDef("no resource", roles=[role])
runner.run(app, scheduler="local_dir")

def test_validate_invalid_replicas(self, _) -> None:
Expand Down Expand Up @@ -129,7 +130,7 @@ def test_dryrun(self, _) -> None:
}
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role = Role(
name="touch",
Expand All @@ -149,7 +150,7 @@ def test_dryrun_env_variables(self, _) -> None:
scheduler_mock = MagicMock()
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -178,7 +179,7 @@ def test_dryrun_trackers_parent_run_id_as_paramenter(self, _) -> None:
expected_parent_run_id = "123"
with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -217,7 +218,7 @@ def test_dryrun_setup_trackers(self, config_trackers_mock: MagicMock, _) -> None

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -265,7 +266,7 @@ def test_dryrun_setup_trackers_as_env_variable(self, _) -> None:

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role1 = Role(
name="echo1",
Expand Down Expand Up @@ -333,8 +334,10 @@ def build_workspace_and_update_role(
name=SESSION_NAME,
# pyre-fixme[6]: scheduler factory type
scheduler_factories={
"no-build-img": lambda name: TestScheduler(build_new_img=False),
"builds-img": lambda name: TestScheduler(build_new_img=True),
"no-build-img": lambda name, **kwargs: TestScheduler(
build_new_img=False
),
"builds-img": lambda name, **kwargs: TestScheduler(build_new_img=True),
},
) as runner:
app = AppDef(
Expand Down Expand Up @@ -371,7 +374,7 @@ def test_describe(self, _) -> None:
name="sleep",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="sleep.sh",
entrypoint="sleep",
args=["60"],
)
app = AppDef("sleeper", roles=[role])
Expand All @@ -387,7 +390,7 @@ def test_status(self, _) -> None:
name="sleep",
image=str(self.tmpdir),
resource=resource.SMALL,
entrypoint="sleep.sh",
entrypoint="sleep",
args=["60"],
)
app = AppDef("sleeper", roles=[role])
Expand All @@ -414,7 +417,7 @@ def test_status_ui_url(self, json_dumps_mock: MagicMock, _) -> None:

with Runner(
name="test_ui_url_session",
scheduler_factories={"local_dir": lambda name: mock_scheduler},
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
) as runner:
role = Role(
"ignored",
Expand All @@ -438,7 +441,7 @@ def test_status_structured_msg(self, json_dumps_mock: MagicMock, _) -> None:

with Runner(
name="test_structured_msg",
scheduler_factories={"local_dir": lambda name: mock_scheduler},
scheduler_factories={"local_dir": lambda name, **kwargs: mock_scheduler},
) as runner:
role = Role(
"ignored",
Expand Down Expand Up @@ -485,7 +488,7 @@ def test_log_lines(self, _) -> None:

with Runner(
name=SESSION_NAME,
scheduler_factories={"local_dir": lambda name: scheduler_mock},
scheduler_factories={"local_dir": lambda name, **kwargs: scheduler_mock},
) as runner:
role_name = "trainer"
replica_id = 2
Expand Down Expand Up @@ -529,7 +532,7 @@ def test_list(self, _) -> None:
]
with Runner(
name=SESSION_NAME,
scheduler_factories={"kubernetes": lambda name: scheduler_mock},
scheduler_factories={"kubernetes": lambda name, **kwargs: scheduler_mock},
) as runner:
apps = runner.list("kubernetes")
self.assertEqual(apps, apps_expected)
Expand All @@ -541,8 +544,8 @@ def test_get_schedulers(self, json_dumps_mock: MagicMock, _) -> None:
json_dumps_mock.return_value = "{}"
local_sched_mock = MagicMock()
scheduler_factories = {
"local_dir": lambda name: local_dir_sched_mock,
"local": lambda name: local_sched_mock,
"local_dir": lambda name, **kwargs: local_dir_sched_mock,
"local": lambda name, **kwargs: local_sched_mock,
}
with Runner(
name="test_session", scheduler_factories=scheduler_factories
Expand Down Expand Up @@ -576,8 +579,8 @@ def test_run_from_module(self, _: str) -> None:
def test_run_from_file_no_function_found(self, _) -> None:
local_sched_mock = MagicMock()
schedulers = {
"local_dir": lambda name: local_sched_mock,
"local": lambda name: local_sched_mock,
"local_dir": lambda name, **kwargs: local_sched_mock,
"local": lambda name, **kwargs: local_sched_mock,
}
with Runner(name="test_session", scheduler_factories=schedulers) as runner:
component_path = get_full_path("distributed.py")
Expand All @@ -591,7 +594,7 @@ def test_runner_context_manager(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
with get_runner() as runner:
# force schedulers to load
Expand All @@ -602,17 +605,17 @@ def test_runner_context_manager_with_error(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
with self.assertRaisesRegex(RuntimeError, "foobar"):
with get_runner() as runner:
with get_runner():
raise RuntimeError("foobar")

def test_runner_manual_close(self, _) -> None:
mock_scheduler = MagicMock()
with patch(
GET_SCHEDULER_FACTORIES,
return_value={"local_dir": lambda name: mock_scheduler},
return_value={"local_dir": lambda name, **kwargs: mock_scheduler},
):
runner = get_runner()
# force schedulers to load
Expand Down
3 changes: 2 additions & 1 deletion torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -1184,11 +1184,12 @@ def create_scheduler(
session_name: str,
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
image_provider_class: Callable[[LocalOpts], ImageProvider] = CWDImageProvider,
**kwargs: Any,
) -> LocalScheduler:
return LocalScheduler(
session_name=session_name,
image_provider_class=CWDImageProvider,
image_provider_class=image_provider_class,
cache_size=cache_size,
extra_paths=extra_paths,
)

0 comments on commit d3393fc

Please sign in to comment.