
Commit d713116

add options to create_scheduler so that the get_runner method is fully configurable (#767)

* add options to create_scheduler so that the get_runner method is fully configurable. I also added a note to each scheduler's __init__ method to help with maintainability. Signed-off-by: Kevin <[email protected]>
* add tests for create_scheduler params. Signed-off-by: Kevin <[email protected]>
---------
Signed-off-by: Kevin <[email protected]>
1 parent a69bb05 · commit d713116

14 files changed (+81, -13 lines)
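With this change, keyword arguments forwarded from the runner layer now reach each scheduler's constructor instead of being swallowed by **kwargs. Below is a minimal sketch of the intent, assuming torchx.runner.get_runner forwards its scheduler params to the per-scheduler create_scheduler factories; the session name and mock client are illustrative, not taken from the commit.

# Sketch only: "my_session" and the MagicMock stand-in are illustrative.
from unittest.mock import MagicMock

from torchx.runner import get_runner

# Before this commit, an injected client was dropped by the factories' **kwargs;
# now it is threaded through create_scheduler into the scheduler's __init__.
runner = get_runner("my_session", ray_client=MagicMock())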

torchx/schedulers/aws_batch_scheduler.py

Lines changed: 13 additions & 1 deletion
@@ -408,6 +408,7 @@ def __init__(
         log_client: Optional[Any] = None,
         docker_client: Optional["DockerClient"] = None,
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("aws_batch", session_name, docker_client=docker_client)
 
         # pyre-fixme[4]: Attribute annotation cannot be `Any`.
@@ -796,7 +797,18 @@ def _stream_events(
                 yield event["message"] + "\n"
 
 
-def create_scheduler(session_name: str, **kwargs: object) -> AWSBatchScheduler:
+def create_scheduler(
+    session_name: str,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    client: Optional[Any] = None,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    log_client: Optional[Any] = None,
+    docker_client: Optional["DockerClient"] = None,
+    **kwargs: object,
+) -> AWSBatchScheduler:
     return AWSBatchScheduler(
         session_name=session_name,
+        client=client,
+        log_client=log_client,
+        docker_client=docker_client,
     )
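For illustration only, a hedged sketch of how the new AWS Batch factory parameters might be used to inject pre-configured boto3 clients; the region and session name are placeholders, not values from this commit.

import boto3

from torchx.schedulers.aws_batch_scheduler import create_scheduler

# Placeholder region; any pre-built boto3 clients can be injected here.
batch_client = boto3.client("batch", region_name="us-west-2")
logs_client = boto3.client("logs", region_name="us-west-2")

scheduler = create_scheduler(
    "my_session", client=batch_client, log_client=logs_client
)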

torchx/schedulers/docker_scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -168,6 +168,7 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
     """
 
     def __init__(self, session_name: str) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("docker", session_name)
 
     def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:

torchx/schedulers/gcp_batch_scheduler.py

Lines changed: 8 additions & 1 deletion
@@ -142,6 +142,7 @@ def __init__(
         # pyre-fixme[2]: Parameter annotation cannot be `Any`.
         client: Optional[Any] = None,
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         Scheduler.__init__(self, "gcp_batch", session_name)
         # pyre-fixme[4]: Attribute annotation cannot be `Any`.
         self.__client = client
@@ -474,7 +475,13 @@ def _cancel_existing(self, app_id: str) -> None:
         self._client.delete_job(request=request)
 
 
-def create_scheduler(session_name: str, **kwargs: object) -> GCPBatchScheduler:
+def create_scheduler(
+    session_name: str,
+    # pyre-fixme[2]: Parameter annotation cannot be `Any`.
+    client: Optional[Any] = None,
+    **kwargs: object,
+) -> GCPBatchScheduler:
     return GCPBatchScheduler(
         session_name=session_name,
+        client=client,
     )
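A usage sketch under the assumption that the GCP Batch scheduler expects a google.cloud.batch_v1 client here (the parameter is only annotated as Optional[Any] in the source); the session name is a placeholder.

from google.cloud import batch_v1

from torchx.schedulers.gcp_batch_scheduler import create_scheduler

# Inject a pre-configured Batch client instead of letting the scheduler build one lazily.
client = batch_v1.BatchServiceClient()
scheduler = create_scheduler("my_session", client=client)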

torchx/schedulers/kubernetes_mcad_scheduler.py

Lines changed: 9 additions & 1 deletion
@@ -897,6 +897,7 @@ def __init__(
         client: Optional["ApiClient"] = None,
         docker_client: Optional["DockerClient"] = None,
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("kubernetes_mcad", session_name, docker_client=docker_client)
 
         self._client = client
@@ -1232,9 +1233,16 @@ def list(self) -> List[ListAppResponse]:
         ]
 
 
-def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesMCADScheduler:
+def create_scheduler(
+    session_name: str,
+    client: Optional["ApiClient"] = None,
+    docker_client: Optional["DockerClient"] = None,
+    **kwargs: Any,
+) -> KubernetesMCADScheduler:
     return KubernetesMCADScheduler(
         session_name=session_name,
+        client=client,
+        docker_client=docker_client,
     )
 
torchx/schedulers/kubernetes_scheduler.py

Lines changed: 9 additions & 1 deletion
@@ -540,6 +540,7 @@ def __init__(
         client: Optional["ApiClient"] = None,
         docker_client: Optional["DockerClient"] = None,
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("kubernetes", session_name, docker_client=docker_client)
 
         self._client = client
@@ -777,9 +778,16 @@ def list(self) -> List[ListAppResponse]:
         ]
 
 
-def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesScheduler:
+def create_scheduler(
+    session_name: str,
+    client: Optional["ApiClient"] = None,
+    docker_client: Optional["DockerClient"] = None,
+    **kwargs: Any,
+) -> KubernetesScheduler:
     return KubernetesScheduler(
         session_name=session_name,
+        client=client,
+        docker_client=docker_client,
     )
 
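A hedged sketch of injecting a kubernetes ApiClient through the new factory parameters; the kubeconfig context and session name are placeholders.

from kubernetes import client, config

from torchx.schedulers.kubernetes_scheduler import create_scheduler

# Placeholder context name; load the desired kubeconfig before building the ApiClient.
config.load_kube_config(context="my-context")
scheduler = create_scheduler("my_session", client=client.ApiClient())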

torchx/schedulers/local_scheduler.py

Lines changed: 9 additions & 2 deletions
@@ -556,6 +556,7 @@ def __init__(
         cache_size: int = 100,
         extra_paths: Optional[List[str]] = None,
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("local", session_name)
 
         # TODO T72035686 replace dict with a proper LRUCache data structure
@@ -1124,9 +1125,15 @@ def __next__(self) -> str:
         return line
 
 
-def create_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
+def create_scheduler(
+    session_name: str,
+    cache_size: int = 100,
+    extra_paths: Optional[List[str]] = None,
+    **kwargs: Any,
+) -> LocalScheduler:
     return LocalScheduler(
         session_name=session_name,
-        cache_size=kwargs.get("cache_size", 100),
         image_provider_class=CWDImageProvider,
+        cache_size=cache_size,
+        extra_paths=extra_paths,
     )
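A short usage sketch of the now first-class LocalScheduler factory parameters (the values are placeholders); previously only cache_size was honored, via kwargs.get("cache_size", 100).

from torchx.schedulers.local_scheduler import create_scheduler

# cache_size and extra_paths are plain keyword arguments now.
scheduler = create_scheduler("my_session", cache_size=50, extra_paths=["/opt/extra"])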

torchx/schedulers/lsf_scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -438,6 +438,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
     """
 
     def __init__(self, session_name: str) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("lsf", session_name)
 
     def _run_opts(self) -> runopts:

torchx/schedulers/ray_scheduler.py

Lines changed: 5 additions & 2 deletions
@@ -153,6 +153,7 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
     def __init__(
         self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
     ) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("ray", session_name)
 
         # w/o Final None check in _get_ray_client does not work as it pyre assumes mutability
@@ -441,10 +442,12 @@ def list(self) -> List[ListAppResponse]:
         ]
 
 
-def create_scheduler(session_name: str, **kwargs: Any) -> "RayScheduler":
+def create_scheduler(
+    session_name: str, ray_client: Optional[JobSubmissionClient] = None, **kwargs: Any
+) -> "RayScheduler":
     if not has_ray():  # pragma: no cover
         raise ModuleNotFoundError(
             "Ray is not installed in the current Python environment."
         )
 
-    return RayScheduler(session_name=session_name)
+    return RayScheduler(session_name=session_name, ray_client=ray_client)
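A hedged sketch of passing an existing Ray JobSubmissionClient through the factory; the dashboard address and session name are placeholders.

from ray.job_submission import JobSubmissionClient

from torchx.schedulers.ray_scheduler import create_scheduler

# Point the scheduler at an already-running Ray cluster's job submission endpoint.
ray_client = JobSubmissionClient("http://127.0.0.1:8265")
scheduler = create_scheduler("my_session", ray_client=ray_client)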

torchx/schedulers/slurm_scheduler.py

Lines changed: 1 addition & 0 deletions
@@ -321,6 +321,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
     """
 
     def __init__(self, session_name: str) -> None:
+        # NOTE: make sure any new init options are supported in create_scheduler(...)
         super().__init__("slurm", session_name)
 
     def _run_opts(self) -> runopts:

torchx/schedulers/test/aws_batch_scheduler_test.py

Lines changed: 9 additions & 1 deletion
@@ -95,8 +95,16 @@ def paginate(self, *_1: Any, **_2: Any) -> Iterable[Dict[str, Any]]:
 
 class AWSBatchSchedulerTest(unittest.TestCase):
     def test_create_scheduler(self) -> None:
-        scheduler = create_scheduler("foo")
+        client = MagicMock()
+        log_client = MagicMock()
+        docker_client = MagicMock()
+        scheduler = create_scheduler(
+            "foo", client=client, log_client=log_client, docker_client=docker_client
+        )
         self.assertIsInstance(scheduler, AWSBatchScheduler)
+        self.assertEqual(scheduler._client, client)
+        self.assertEqual(scheduler._log_client, log_client)
+        self.assertEqual(scheduler._docker_client, docker_client)
 
     def test_submit_dryrun_with_share_id(self) -> None:
         app = _test_app()

torchx/schedulers/test/gcp_batch_scheduler_test.py

Lines changed: 3 additions & 1 deletion
@@ -57,8 +57,10 @@ def mock_rand() -> Generator[None, None, None]:
 
 class GCPBatchSchedulerTest(unittest.TestCase):
     def test_create_scheduler(self) -> None:
-        scheduler = create_scheduler("foo")
+        client = MagicMock()
+        scheduler = create_scheduler("foo", client=client)
         self.assertIsInstance(scheduler, GCPBatchScheduler)
+        self.assertEqual(scheduler._client, client)
 
     @mock_rand()
     def test_submit_dryrun(self) -> None:

torchx/schedulers/test/kubernetes_mcad_scheduler_test.py

Lines changed: 5 additions & 1 deletion
@@ -159,10 +159,14 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:
 
 class KubernetesMCADSchedulerTest(unittest.TestCase):
     def test_create_scheduler(self) -> None:
-        scheduler = create_scheduler("foo")
+        client = MagicMock()
+        docker_client = MagicMock()
+        scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
         self.assertIsInstance(
             scheduler, kubernetes_mcad_scheduler.KubernetesMCADScheduler
         )
+        self.assertEquals(client, scheduler._client)
+        self.assertEquals(docker_client, scheduler._docker_client)
 
     def test_app_to_resource_resolved_macros(self) -> None:
         app = _test_app()

torchx/schedulers/test/kubernetes_scheduler_test.py

Lines changed: 5 additions & 1 deletion
@@ -92,8 +92,12 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:
 
 class KubernetesSchedulerTest(unittest.TestCase):
     def test_create_scheduler(self) -> None:
-        scheduler = create_scheduler("foo")
+        client = MagicMock()
+        docker_client = MagicMock
+        scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
         self.assertIsInstance(scheduler, kubernetes_scheduler.KubernetesScheduler)
+        self.assertEquals(scheduler._docker_client, docker_client)
+        self.assertEquals(scheduler._client, client)
 
     def test_app_to_resource_resolved_macros(self) -> None:
         app = _test_app()

torchx/schedulers/test/local_scheduler_test.py

Lines changed: 3 additions & 1 deletion
@@ -174,9 +174,11 @@ def test_get_entrypoint(self) -> None:
         self.assertEqual(self.provider.get_entrypoint("asdf", role), "entrypoint.sh")
 
     def test_create_scheduler(self) -> None:
-        sched = create_scheduler("foo")
+        sched = create_scheduler("foo", cache_size=20, extra_paths=["foo"])
         self.assertEqual(sched.session_name, "foo")
         self.assertEqual(sched._image_provider_class, CWDImageProvider)
+        self.assertEqual(sched._cache_size, 20)
+        self.assertEqual(len(sched._extra_paths), 1)
 
 
 LOCAL_SCHEDULER_MAKE_UNIQUE = "torchx.schedulers.local_scheduler.make_unique"
