diff --git a/torchx/schedulers/kubernetes_scheduler.py b/torchx/schedulers/kubernetes_scheduler.py index c06eebc8a..14bce2192 100644 --- a/torchx/schedulers/kubernetes_scheduler.py +++ b/torchx/schedulers/kubernetes_scheduler.py @@ -38,12 +38,14 @@ Any, cast, Dict, + Generic, Iterable, List, Mapping, Optional, Tuple, TYPE_CHECKING, + TypeVar, ) import torchx @@ -453,8 +455,7 @@ def app_to_resource( @dataclass -class KubernetesJob: - images_to_push: Dict[str, Tuple[str, str]] +class BaseKubernetesJob: resource: Dict[str, object] def __str__(self) -> str: @@ -472,9 +473,14 @@ class KubernetesOpts(TypedDict, total=False): priority_class: Optional[str] -class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]): +KO = TypeVar("KO", bound=KubernetesOpts) + + +class BaseKubernetesScheduler(Generic[KO], Scheduler[KO]): """ - KubernetesScheduler is a TorchX scheduling interface to Kubernetes. + This scheduler implements the torchx scheduler interface for Kubernetes jobs. + This is the base class that can be shared between different Kubernetes scheduler + implementations that need to support different types of workspaces. Important: Volcano is required to be installed on the Kubernetes cluster. TorchX requires gang scheduling for multi-replica/multi-role execution @@ -543,31 +549,16 @@ class KubernetesScheduler(DockerWorkspaceMixin, Scheduler[KubernetesOpts]): request by a small amount to account for the node reserved CPU and memory. If you run into scheduling issues you may need to reduce the requested CPU and memory from the host values. - - **Compatibility** - - .. compatibility:: - type: scheduler - features: - cancel: true - logs: true - distributed: true - describe: | - Partial support. KubernetesScheduler will return job and replica - status but does not provide the complete original AppSpec. - workspaces: true - mounts: true - elasticity: Requires Volcano >1.6 """ def __init__( self, + backend: str, session_name: str, client: Optional["ApiClient"] = None, - docker_client: Optional["DockerClient"] = None, ) -> None: # NOTE: make sure any new init options are supported in create_scheduler(...) - super().__init__("kubernetes", session_name, docker_client=docker_client) + super().__init__(backend, session_name) self._client = client @@ -604,23 +595,22 @@ def _get_active_context(self) -> Dict[str, Any]: contexts, active_context = config.list_kube_config_contexts() return active_context - def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: + def _schedule( + self, + group: str, + version: str, + namespace: str, + plural: str, + resource: Dict[str, object], + ) -> str: from kubernetes.client.rest import ApiException - cfg = dryrun_info._cfg - assert cfg is not None, f"{dryrun_info} missing cfg" - namespace = cfg.get("namespace") or "default" - - images_to_push = dryrun_info.request.images_to_push - self.push_images(images_to_push) - - resource = dryrun_info.request.resource try: resp = self._custom_objects_api().create_namespaced_custom_object( - group="batch.volcano.sh", - version="v1alpha1", + group=group, + version=version, namespace=namespace, - plural="jobs", + plural=plural, body=resource, ) except ApiException as e: @@ -634,33 +624,6 @@ def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: return f'{namespace}:{resp["metadata"]["name"]}' - def _submit_dryrun( - self, app: AppDef, cfg: KubernetesOpts - ) -> AppDryRunInfo[KubernetesJob]: - queue = cfg.get("queue") - if not isinstance(queue, str): - raise TypeError(f"config value 'queue' must be a string, got {queue}") - - # map any local images to the remote image - images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) - - service_account = cfg.get("service_account") - assert service_account is None or isinstance( - service_account, str - ), "service_account must be a str" - - priority_class = cfg.get("priority_class") - assert priority_class is None or isinstance( - priority_class, str - ), "priority_class must be a str" - - resource = app_to_resource(app, queue, service_account, priority_class) - req = KubernetesJob( - resource=resource, - images_to_push=images_to_push, - ) - return AppDryRunInfo(req, repr) - def _validate(self, app: AppDef, scheduler: str) -> None: # Skip validation step pass @@ -804,6 +767,90 @@ def list(self) -> List[ListAppResponse]: ] +@dataclass +class KubernetesJob(BaseKubernetesJob): + images_to_push: Dict[str, Tuple[str, str]] + + +class KubernetesScheduler( + DockerWorkspaceMixin, BaseKubernetesScheduler[KubernetesOpts] +): + """ + KubernetesScheduler is a TorchX scheduling interface to Kubernetes + using the DockerWorkspaceMixin. + + + **Compatibility** + + .. compatibility:: + type: scheduler + features: + cancel: true + logs: true + distributed: true + describe: | + Partial support. KubernetesScheduler will return job and replica + status but does not provide the complete original AppSpec. + workspaces: true + mounts: true + elasticity: Requires Volcano >1.6 + """ + + def __init__( + self, + session_name: str, + client: Optional["ApiClient"] = None, + docker_client: Optional["DockerClient"] = None, + ) -> None: + # NOTE: make sure any new init options are supported in create_scheduler(...) + super().__init__("kubernetes", session_name, docker_client=docker_client) + + self._client = client + + def schedule(self, dryrun_info: AppDryRunInfo[KubernetesJob]) -> str: + cfg = dryrun_info._cfg + assert cfg is not None, f"{dryrun_info} missing cfg" + namespace = cfg.get("namespace") or "default" + + resource = dryrun_info.request.resource + + images_to_push = dryrun_info.request.images_to_push + self.push_images(images_to_push) + return self._schedule( + group="batch.volcano.sh", + version="v1alpha1", + namespace=str(namespace), + plural="jobs", + resource=resource, + ) + + def _submit_dryrun(self, app: AppDef, cfg: KO) -> AppDryRunInfo[KubernetesJob]: + queue = cfg.get("queue") + if not isinstance(queue, str): + raise TypeError(f"config value 'queue' must be a string, got {queue}") + + # map any local images to the remote image + images_to_push = self.dryrun_push_images(app, cast(Mapping[str, CfgVal], cfg)) + + service_account = cfg.get("service_account") + assert service_account is None or isinstance( + service_account, str + ), "service_account must be a str" + + priority_class = cfg.get("priority_class") + assert priority_class is None or isinstance( + priority_class, str + ), "priority_class must be a str" + + resource = app_to_resource(app, queue, service_account, priority_class) + + req = KubernetesJob( + resource=resource, + images_to_push=images_to_push, + ) + return AppDryRunInfo(req, repr) + + def create_scheduler( session_name: str, client: Optional["ApiClient"] = None,