Skip to content

Commit

Permalink
add options to create_scheduler so that the get_runner method is fully configurable (#767)
Browse files Browse the repository at this point in the history

* add options to create_scheduler so that the get_runner method is fully configurable

I also added a note to each scheduler's __init__ method to help with maintainability

Signed-off-by: Kevin <[email protected]>

* add tests for create_scheduler params

Signed-off-by: Kevin <[email protected]>

---------

Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice authored Oct 3, 2023
1 parent a69bb05 commit d713116
Show file tree
Hide file tree
Showing 14 changed files with 81 additions and 13 deletions.
14 changes: 13 additions & 1 deletion torchx/schedulers/aws_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -408,6 +408,7 @@ def __init__(
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("aws_batch", session_name, docker_client=docker_client)

# pyre-fixme[4]: Attribute annotation cannot be `Any`.
Expand Down Expand Up @@ -796,7 +797,18 @@ def _stream_events(
yield event["message"] + "\n"


def create_scheduler(session_name: str, **kwargs: object) -> AWSBatchScheduler:
def create_scheduler(
session_name: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
log_client: Optional[Any] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: object,
) -> AWSBatchScheduler:
return AWSBatchScheduler(
session_name=session_name,
client=client,
log_client=log_client,
docker_client=docker_client,
)
1 change: 1 addition & 0 deletions torchx/schedulers/docker_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,7 @@ class DockerScheduler(DockerWorkspaceMixin, Scheduler[DockerOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("docker", session_name)

def schedule(self, dryrun_info: AppDryRunInfo[DockerJob]) -> str:
Expand Down
9 changes: 8 additions & 1 deletion torchx/schedulers/gcp_batch_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,7 @@ def __init__(
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
Scheduler.__init__(self, "gcp_batch", session_name)
# pyre-fixme[4]: Attribute annotation cannot be `Any`.
self.__client = client
Expand Down Expand Up @@ -474,7 +475,13 @@ def _cancel_existing(self, app_id: str) -> None:
self._client.delete_job(request=request)


def create_scheduler(session_name: str, **kwargs: object) -> GCPBatchScheduler:
def create_scheduler(
session_name: str,
# pyre-fixme[2]: Parameter annotation cannot be `Any`.
client: Optional[Any] = None,
**kwargs: object,
) -> GCPBatchScheduler:
return GCPBatchScheduler(
session_name=session_name,
client=client,
)
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_mcad_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -897,6 +897,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes_mcad", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -1232,9 +1233,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesMCADScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesMCADScheduler:
return KubernetesMCADScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
10 changes: 9 additions & 1 deletion torchx/schedulers/kubernetes_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -540,6 +540,7 @@ def __init__(
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("kubernetes", session_name, docker_client=docker_client)

self._client = client
Expand Down Expand Up @@ -777,9 +778,16 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> KubernetesScheduler:
def create_scheduler(
session_name: str,
client: Optional["ApiClient"] = None,
docker_client: Optional["DockerClient"] = None,
**kwargs: Any,
) -> KubernetesScheduler:
return KubernetesScheduler(
session_name=session_name,
client=client,
docker_client=docker_client,
)


Expand Down
11 changes: 9 additions & 2 deletions torchx/schedulers/local_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -556,6 +556,7 @@ def __init__(
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("local", session_name)

# TODO T72035686 replace dict with a proper LRUCache data structure
Expand Down Expand Up @@ -1124,9 +1125,15 @@ def __next__(self) -> str:
return line


def create_scheduler(session_name: str, **kwargs: Any) -> LocalScheduler:
def create_scheduler(
session_name: str,
cache_size: int = 100,
extra_paths: Optional[List[str]] = None,
**kwargs: Any,
) -> LocalScheduler:
return LocalScheduler(
session_name=session_name,
cache_size=kwargs.get("cache_size", 100),
image_provider_class=CWDImageProvider,
cache_size=cache_size,
extra_paths=extra_paths,
)
1 change: 1 addition & 0 deletions torchx/schedulers/lsf_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,6 +438,7 @@ class LsfScheduler(Scheduler[LsfOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("lsf", session_name)

def _run_opts(self) -> runopts:
Expand Down
7 changes: 5 additions & 2 deletions torchx/schedulers/ray_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,7 @@ class RayScheduler(TmpDirWorkspaceMixin, Scheduler[RayOpts]):
def __init__(
self, session_name: str, ray_client: Optional[JobSubmissionClient] = None
) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("ray", session_name)

# w/o Final None check in _get_ray_client does not work as it pyre assumes mutability
Expand Down Expand Up @@ -441,10 +442,12 @@ def list(self) -> List[ListAppResponse]:
]


def create_scheduler(session_name: str, **kwargs: Any) -> "RayScheduler":
def create_scheduler(
session_name: str, ray_client: Optional[JobSubmissionClient] = None, **kwargs: Any
) -> "RayScheduler":
if not has_ray(): # pragma: no cover
raise ModuleNotFoundError(
"Ray is not installed in the current Python environment."
)

return RayScheduler(session_name=session_name)
return RayScheduler(session_name=session_name, ray_client=ray_client)
1 change: 1 addition & 0 deletions torchx/schedulers/slurm_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -321,6 +321,7 @@ class SlurmScheduler(DirWorkspaceMixin, Scheduler[SlurmOpts]):
"""

def __init__(self, session_name: str) -> None:
# NOTE: make sure any new init options are supported in create_scheduler(...)
super().__init__("slurm", session_name)

def _run_opts(self) -> runopts:
Expand Down
10 changes: 9 additions & 1 deletion torchx/schedulers/test/aws_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,8 +95,16 @@ def paginate(self, *_1: Any, **_2: Any) -> Iterable[Dict[str, Any]]:

class AWSBatchSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
log_client = MagicMock()
docker_client = MagicMock()
scheduler = create_scheduler(
"foo", client=client, log_client=log_client, docker_client=docker_client
)
self.assertIsInstance(scheduler, AWSBatchScheduler)
self.assertEqual(scheduler._client, client)
self.assertEqual(scheduler._log_client, log_client)
self.assertEqual(scheduler._docker_client, docker_client)

def test_submit_dryrun_with_share_id(self) -> None:
app = _test_app()
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/test/gcp_batch_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -57,8 +57,10 @@ def mock_rand() -> Generator[None, None, None]:

class GCPBatchSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
scheduler = create_scheduler("foo", client=client)
self.assertIsInstance(scheduler, GCPBatchScheduler)
self.assertEqual(scheduler._client, client)

@mock_rand()
def test_submit_dryrun(self) -> None:
Expand Down
6 changes: 5 additions & 1 deletion torchx/schedulers/test/kubernetes_mcad_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,14 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:

class KubernetesMCADSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
docker_client = MagicMock()
scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
self.assertIsInstance(
scheduler, kubernetes_mcad_scheduler.KubernetesMCADScheduler
)
self.assertEquals(client, scheduler._client)
self.assertEquals(docker_client, scheduler._docker_client)

def test_app_to_resource_resolved_macros(self) -> None:
app = _test_app()
Expand Down
6 changes: 5 additions & 1 deletion torchx/schedulers/test/kubernetes_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -92,8 +92,12 @@ def _test_app(num_replicas: int = 1) -> specs.AppDef:

class KubernetesSchedulerTest(unittest.TestCase):
def test_create_scheduler(self) -> None:
scheduler = create_scheduler("foo")
client = MagicMock()
docker_client = MagicMock
scheduler = create_scheduler("foo", client=client, docker_client=docker_client)
self.assertIsInstance(scheduler, kubernetes_scheduler.KubernetesScheduler)
self.assertEquals(scheduler._docker_client, docker_client)
self.assertEquals(scheduler._client, client)

def test_app_to_resource_resolved_macros(self) -> None:
app = _test_app()
Expand Down
4 changes: 3 additions & 1 deletion torchx/schedulers/test/local_scheduler_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,9 +174,11 @@ def test_get_entrypoint(self) -> None:
self.assertEqual(self.provider.get_entrypoint("asdf", role), "entrypoint.sh")

def test_create_scheduler(self) -> None:
sched = create_scheduler("foo")
sched = create_scheduler("foo", cache_size=20, extra_paths=["foo"])
self.assertEqual(sched.session_name, "foo")
self.assertEqual(sched._image_provider_class, CWDImageProvider)
self.assertEqual(sched._cache_size, 20)
self.assertEqual(len(sched._extra_paths), 1)


LOCAL_SCHEDULER_MAKE_UNIQUE = "torchx.schedulers.local_scheduler.make_unique"
Expand Down

0 comments on commit d713116

Please sign in to comment.