diff --git a/src/codeflare_sdk/cluster/auth.py b/src/codeflare_sdk/cluster/auth.py index eb739136b..14fee8d2d 100644 --- a/src/codeflare_sdk/cluster/auth.py +++ b/src/codeflare_sdk/cluster/auth.py @@ -25,6 +25,8 @@ import urllib3 from ..utils.kube_api_helpers import _kube_api_error_handling +from typing import Optional + global api_client api_client = None global config_path @@ -183,12 +185,13 @@ def config_check() -> str: raise PermissionError( "Action not permitted, have you put in correct/up-to-date auth credentials?" ) + api_client = config.new_client_from_config() if config_path != None and api_client == None: return config_path -def api_config_handler() -> str: +def api_config_handler() -> Optional[client.ApiClient]: """ This function is used to load the api client if the user has logged in """ diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 5d00cdae8..268b1426c 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -21,12 +21,20 @@ from time import sleep from typing import List, Optional, Tuple, Dict +import openshift as oc +from kubernetes import config from ray.job_submission import JobSubmissionClient +import urllib3 from .auth import config_check, api_config_handler from ..utils import pretty_print from ..utils.generate_yaml import generate_appwrapper from ..utils.kube_api_helpers import _kube_api_error_handling +from ..utils.openshift_oauth import ( + create_openshift_oauth_objects, + delete_openshift_oauth_objects, + download_tls_cert, +) from .config import ClusterConfiguration from .model import ( AppWrapper, @@ -40,6 +48,8 @@ import os import requests +from kubernetes import config + class Cluster: """ @@ -61,6 +71,38 @@ def __init__(self, config: ClusterConfiguration): self.config = config self.app_wrapper_yaml = self.create_app_wrapper() self.app_wrapper_name = self.app_wrapper_yaml.split(".")[0] + self._client = None + + @property + def _client_headers(self): + return { + "Authorization": api_config_handler().configuration.get_api_key_with_prefix( + "authorization" + ) + } + + @property + def _client_verify_tls(self): + return not self.config.openshift_oauth + + @property + def client(self): + if self._client: + return self._client + if self.config.openshift_oauth: + print( + api_config_handler().configuration.get_api_key_with_prefix( + "authorization" + ) + ) + self._client = JobSubmissionClient( + self.cluster_dashboard_uri(), + headers=self._client_headers, + verify=self._client_verify_tls, + ) + else: + self._client = JobSubmissionClient(self.cluster_dashboard_uri()) + return self._client def evaluate_dispatch_priority(self): priority_class = self.config.dispatch_priority @@ -147,6 +189,7 @@ def create_app_wrapper(self): image_pull_secrets=image_pull_secrets, dispatch_priority=dispatch_priority, priority_val=priority_val, + openshift_oauth=self.config.openshift_oauth, ) # creates a new cluster with the provided or default spec @@ -156,6 +199,11 @@ def up(self): the MCAD queue. """ namespace = self.config.namespace + if self.config.openshift_oauth: + create_openshift_oauth_objects( + cluster_name=self.config.name, namespace=namespace + ) + try: config_check() api_instance = client.CustomObjectsApi(api_config_handler()) @@ -190,6 +238,11 @@ def down(self): except Exception as e: # pragma: no cover return _kube_api_error_handling(e) + if self.config.openshift_oauth: + delete_openshift_oauth_objects( + cluster_name=self.config.name, namespace=namespace + ) + def status( self, print_to_console: bool = True ) -> Tuple[CodeFlareClusterStatus, bool]: @@ -258,7 +311,16 @@ def status( return status, ready def is_dashboard_ready(self) -> bool: - response = requests.get(self.cluster_dashboard_uri(), timeout=5) + try: + response = requests.get( + self.cluster_dashboard_uri(), + headers=self._client_headers, + timeout=5, + verify=self._client_verify_tls, + ) + except requests.exceptions.SSLError: + # SSL exception occurs when oauth ingress has been created but cluster is not up + return False if response.status_code == 200: return True else: @@ -330,7 +392,13 @@ def cluster_dashboard_uri(self) -> str: return _kube_api_error_handling(e) for route in routes["items"]: - if route["metadata"]["name"] == f"ray-dashboard-{self.config.name}": + if route["metadata"][ + "name" + ] == f"ray-dashboard-{self.config.name}" or route["metadata"][ + "name" + ].startswith( + f"{self.config.name}-ingress" + ): protocol = "https" if route["spec"].get("tls") else "http" return f"{protocol}://{route['spec']['host']}" return "Dashboard route not available yet, have you run cluster.up()?" @@ -339,30 +407,24 @@ def list_jobs(self) -> List: """ This method accesses the head ray node in your cluster and lists the running jobs. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.list_jobs() + return self.client.list_jobs() def job_status(self, job_id: str) -> str: """ This method accesses the head ray node in your cluster and returns the job status for the provided job id. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.get_job_status(job_id) + return self.client.get_job_status(job_id) def job_logs(self, job_id: str) -> str: """ This method accesses the head ray node in your cluster and returns the logs for the provided job id. """ - dashboard_route = self.cluster_dashboard_uri() - client = JobSubmissionClient(dashboard_route) - return client.get_job_logs(job_id) + return self.client.get_job_logs(job_id) def torchx_config( self, working_dir: str = None, requirements: str = None ) -> Dict[str, str]: - dashboard_address = f"{self.cluster_dashboard_uri().lstrip('http://')}" + dashboard_address = urllib3.util.parse_url(self.cluster_dashboard_uri()).host to_return = { "cluster_name": self.config.name, "dashboard_address": dashboard_address, @@ -474,8 +536,8 @@ def get_current_namespace(): # pragma: no cover def get_cluster(cluster_name: str, namespace: str = "default"): try: - config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + config.load_kube_config() + api_instance = client.CustomObjectsApi() rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1alpha1", @@ -496,7 +558,7 @@ def get_cluster(cluster_name: str, namespace: str = "default"): # private methods def _get_ingress_domain(): try: - config_check() + config.load_kube_config() api_client = client.CustomObjectsApi(api_config_handler()) ingress = api_client.get_cluster_custom_object( "config.openshift.io", "v1", "ingresses", "cluster" @@ -591,7 +653,7 @@ def _get_app_wrappers( def _map_to_ray_cluster(rc) -> Optional[RayCluster]: - if "status" in rc and "state" in rc["status"]: + if "state" in rc["status"]: status = RayClusterStatus(rc["status"]["state"].lower()) else: status = RayClusterStatus.UNKNOWN @@ -606,7 +668,13 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: ) ray_route = None for route in routes["items"]: - if route["metadata"]["name"] == f"ray-dashboard-{rc['metadata']['name']}": + if route["metadata"][ + "name" + ] == f"ray-dashboard-{rc['metadata']['name']}" or route["metadata"][ + "name" + ].startswith( + f"{rc['metadata']['name']}-ingress" + ): protocol = "https" if route["spec"].get("tls") else "http" ray_route = f"{protocol}://{route['spec']['host']}" diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index bde3f4ca0..288541a90 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -51,3 +51,4 @@ class ClusterConfiguration: local_interactive: bool = False image_pull_secrets: list = field(default_factory=list) dispatch_priority: str = None + openshift_oauth: bool = False # NOTE: to use the user must have permission to create a RoleBinding for system:auth-delegator diff --git a/src/codeflare_sdk/job/jobs.py b/src/codeflare_sdk/job/jobs.py index b9bb9cdc1..10b392d32 100644 --- a/src/codeflare_sdk/job/jobs.py +++ b/src/codeflare_sdk/job/jobs.py @@ -18,15 +18,20 @@ from pathlib import Path from torchx.components.dist import ddp -from torchx.runner import get_runner +from torchx.runner import get_runner, Runner +from torchx.schedulers.ray_scheduler import RayScheduler from torchx.specs import AppHandle, parse_app_handle, AppDryRunInfo +from ray.job_submission import JobSubmissionClient + +import openshift as oc + if TYPE_CHECKING: from ..cluster.cluster import Cluster from ..cluster.cluster import get_current_namespace +from ..utils.openshift_oauth import download_tls_cert all_jobs: List["Job"] = [] -torchx_runner = get_runner() class JobDefinition(metaclass=abc.ABCMeta): @@ -92,30 +97,37 @@ def __init__( def _dry_run(self, cluster: "Cluster"): j = f"{cluster.config.num_workers}x{max(cluster.config.num_gpus, 1)}" # # of proc. = # of gpus - return torchx_runner.dryrun( - app=ddp( - *self.script_args, - script=self.script, - m=self.m, - name=self.name, - h=self.h, - cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus, - gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus, - memMB=self.memMB - if self.memMB is not None - else cluster.config.max_memory * 1024, - j=self.j if self.j is not None else j, - env=self.env, - max_retries=self.max_retries, - rdzv_port=self.rdzv_port, - rdzv_backend=self.rdzv_backend - if self.rdzv_backend is not None - else "static", - mounts=self.mounts, + runner = get_runner(ray_client=cluster.client) + runner._scheduler_instances["ray"] = RayScheduler( + session_name=runner._name, ray_client=cluster.client + ) + return ( + runner.dryrun( + app=ddp( + *self.script_args, + script=self.script, + m=self.m, + name=self.name, + h=self.h, + cpu=self.cpu if self.cpu is not None else cluster.config.max_cpus, + gpu=self.gpu if self.gpu is not None else cluster.config.num_gpus, + memMB=self.memMB + if self.memMB is not None + else cluster.config.max_memory * 1024, + j=self.j if self.j is not None else j, + env=self.env, + max_retries=self.max_retries, + rdzv_port=self.rdzv_port, + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "static", + mounts=self.mounts, + ), + scheduler=cluster.torchx_scheduler, + cfg=cluster.torchx_config(**self.scheduler_args), + workspace=self.workspace, ), - scheduler=cluster.torchx_scheduler, - cfg=cluster.torchx_config(**self.scheduler_args), - workspace=self.workspace, + runner, ) def _missing_spec(self, spec: str): @@ -125,41 +137,47 @@ def _dry_run_no_cluster(self): if self.scheduler_args is not None: if self.scheduler_args.get("namespace") is None: self.scheduler_args["namespace"] = get_current_namespace() - return torchx_runner.dryrun( - app=ddp( - *self.script_args, - script=self.script, - m=self.m, - name=self.name if self.name is not None else self._missing_spec("name"), - h=self.h, - cpu=self.cpu - if self.cpu is not None - else self._missing_spec("cpu (# cpus per worker)"), - gpu=self.gpu - if self.gpu is not None - else self._missing_spec("gpu (# gpus per worker)"), - memMB=self.memMB - if self.memMB is not None - else self._missing_spec("memMB (memory in MB)"), - j=self.j - if self.j is not None - else self._missing_spec( - "j (`workers`x`procs`)" - ), # # of proc. = # of gpus, - env=self.env, # should this still exist? - max_retries=self.max_retries, - rdzv_port=self.rdzv_port, # should this still exist? - rdzv_backend=self.rdzv_backend - if self.rdzv_backend is not None - else "c10d", - mounts=self.mounts, - image=self.image - if self.image is not None - else self._missing_spec("image"), + runner = get_runner() + return ( + runner.dryrun( + app=ddp( + *self.script_args, + script=self.script, + m=self.m, + name=self.name + if self.name is not None + else self._missing_spec("name"), + h=self.h, + cpu=self.cpu + if self.cpu is not None + else self._missing_spec("cpu (# cpus per worker)"), + gpu=self.gpu + if self.gpu is not None + else self._missing_spec("gpu (# gpus per worker)"), + memMB=self.memMB + if self.memMB is not None + else self._missing_spec("memMB (memory in MB)"), + j=self.j + if self.j is not None + else self._missing_spec( + "j (`workers`x`procs`)" + ), # # of proc. = # of gpus, + env=self.env, # should this still exist? + max_retries=self.max_retries, + rdzv_port=self.rdzv_port, # should this still exist? + rdzv_backend=self.rdzv_backend + if self.rdzv_backend is not None + else "c10d", + mounts=self.mounts, + image=self.image + if self.image is not None + else self._missing_spec("image"), + ), + scheduler="kubernetes_mcad", + cfg=self.scheduler_args, + workspace="", ), - scheduler="kubernetes_mcad", - cfg=self.scheduler_args, - workspace="", + runner, ) def submit(self, cluster: "Cluster" = None) -> "Job": @@ -171,18 +189,20 @@ def __init__(self, job_definition: "DDPJobDefinition", cluster: "Cluster" = None self.job_definition = job_definition self.cluster = cluster if self.cluster: - self._app_handle = torchx_runner.schedule(job_definition._dry_run(cluster)) + definition, runner = job_definition._dry_run(cluster) + self._app_handle = runner.schedule(definition) + self._runner = runner else: - self._app_handle = torchx_runner.schedule( - job_definition._dry_run_no_cluster() - ) + definition, runner = job_definition._dry_run_no_cluster() + self._app_handle = runner.schedule(definition) + self._runner = runner all_jobs.append(self) def status(self) -> str: - return torchx_runner.status(self._app_handle) + return self._runner.status(self._app_handle) def logs(self) -> str: - return "".join(torchx_runner.log_lines(self._app_handle, None)) + return "".join(self._runner.log_lines(self._app_handle, None)) def cancel(self): - torchx_runner.cancel(self._app_handle) + self._runner.cancel(self._app_handle) diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 95e1c5ecb..a4a2da825 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -24,6 +24,11 @@ from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling from ..cluster.auth import api_config_handler, config_check +from os import urandom +from base64 import b64encode +from urllib3.util import parse_url + +from kubernetes import client, config def read_template(template): @@ -46,13 +51,17 @@ def gen_names(name): def update_dashboard_route(route_item, cluster_name, namespace): metadata = route_item.get("generictemplate", {}).get("metadata") - metadata["name"] = f"ray-dashboard-{cluster_name}" + metadata["name"] = gen_dashboard_route_name(cluster_name) metadata["namespace"] = namespace metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc" spec = route_item.get("generictemplate", {}).get("spec") spec["to"]["name"] = f"{cluster_name}-head-svc" +def gen_dashboard_route_name(cluster_name): + return f"ray-dashboard-{cluster_name}" + + # ToDo: refactor the update_x_route() functions def update_rayclient_route(route_item, cluster_name, namespace): metadata = route_item.get("generictemplate", {}).get("metadata") @@ -369,6 +378,83 @@ def write_user_appwrapper(user_yaml, output_file_name): print(f"Written to: {output_file_name}") +def enable_openshift_oauth(user_yaml, cluster_name, namespace): + config_check() + k8_client = api_config_handler() + tls_mount_location = "/etc/tls/private" + oauth_port = 8443 + oauth_sa_name = f"{cluster_name}-oauth-proxy" + tls_secret_name = f"{cluster_name}-proxy-tls-secret" + tls_volume_name = "proxy-tls-secret" + port_name = "oauth-proxy" + _, _, host, _, _, _, _ = parse_url(k8_client.configuration.host) + host = host.replace( + "api.", f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps." + ) + oauth_sidecar = _create_oauth_sidecar_object( + namespace, + tls_mount_location, + oauth_port, + oauth_sa_name, + tls_volume_name, + port_name, + ) + tls_secret_volume = client.V1Volume( + name=tls_volume_name, + secret=client.V1SecretVolumeSource(secret_name=tls_secret_name), + ) + # allows for setting value of Cluster object when initializing object from an existing AppWrapper on cluster + user_yaml["metadata"]["annotations"] = user_yaml["metadata"].get("annotations", {}) + user_yaml["metadata"]["annotations"][ + "codeflare-sdk-use-oauth" + ] = "true" # if the user gets an + ray_headgroup_pod = user_yaml["spec"]["resources"]["GenericItems"][0][ + "generictemplate" + ]["spec"]["headGroupSpec"]["template"]["spec"] + user_yaml["spec"]["resources"]["GenericItems"].pop(1) + ray_headgroup_pod["serviceAccount"] = oauth_sa_name + ray_headgroup_pod["volumes"] = ray_headgroup_pod.get("volumes", []) + ray_headgroup_pod["volumes"].append( + k8_client.sanitize_for_serialization(tls_secret_volume) + ) + ray_headgroup_pod["containers"].append( + k8_client.sanitize_for_serialization(oauth_sidecar) + ) + # add volume to headnode + # add sidecar container to ray object + + +def _create_oauth_sidecar_object( + namespace: str, + tls_mount_location: str, + oauth_port: int, + oauth_sa_name: str, + tls_volume_name: str, + port_name: str, +) -> client.V1Container: + return client.V1Container( + args=[ + f"--https-address=:{oauth_port}", + "--provider=openshift", + f"--openshift-service-account={oauth_sa_name}", + "--upstream=http://localhost:8265", + f"--tls-cert={tls_mount_location}/tls.crt", + f"--tls-key={tls_mount_location}/tls.key", + f"--cookie-secret={b64encode(urandom(64)).decode('utf-8')}", # create random string for encrypting cookie + f'--openshift-delegate-urls={{"/":{{"resource":"pods","namespace":"{namespace}","verb":"get"}}}}', + ], + image="registry.redhat.io/openshift4/ose-oauth-proxy@sha256:1ea6a01bf3e63cdcf125c6064cbd4a4a270deaf0f157b3eabb78f60556840366", + name="oauth-proxy", + ports=[client.V1ContainerPort(container_port=oauth_port, name=port_name)], + resources=client.V1ResourceRequirements(limits=None, requests=None), + volume_mounts=[ + client.V1VolumeMount( + mount_path=tls_mount_location, name=tls_volume_name, read_only=True + ) + ], + ) + + def generate_appwrapper( name: str, namespace: str, @@ -390,6 +476,7 @@ def generate_appwrapper( image_pull_secrets: list, dispatch_priority: str, priority_val: int, + openshift_oauth: bool, ): user_yaml = read_template(template) appwrapper_name, cluster_name = gen_names(name) @@ -433,6 +520,10 @@ def generate_appwrapper( enable_local_interactive(resources, cluster_name, namespace) else: disable_raycluster_tls(resources["resources"]) + + if openshift_oauth: + enable_openshift_oauth(user_yaml, cluster_name, namespace) + outfile = appwrapper_name + ".yaml" write_user_appwrapper(user_yaml, outfile) return outfile diff --git a/src/codeflare_sdk/utils/openshift_oauth.py b/src/codeflare_sdk/utils/openshift_oauth.py new file mode 100644 index 000000000..3fd6e008a --- /dev/null +++ b/src/codeflare_sdk/utils/openshift_oauth.py @@ -0,0 +1,229 @@ +from urllib3.util import parse_url +from .generate_yaml import gen_dashboard_route_name +from base64 import b64decode + +from ..cluster.auth import config_check, api_config_handler + +from kubernetes import client + + +def create_openshift_oauth_objects(cluster_name, namespace): + config_check() + api_client = api_config_handler() + oauth_port = 8443 + oauth_sa_name = f"{cluster_name}-oauth-proxy" + tls_secret_name = _gen_tls_secret_name(cluster_name) + service_name = f"{cluster_name}-oauth" + port_name = "oauth-proxy" + host = parse_url(api_client.configuration.host).host + + # replace "^api" with the expected host + host = f"{gen_dashboard_route_name(cluster_name)}-{namespace}.apps" + host.lstrip( + "api" + ) + + _create_or_replace_oauth_sa(namespace, oauth_sa_name, host) + _create_or_replace_oauth_service_obj( + cluster_name, namespace, oauth_port, tls_secret_name, service_name, port_name + ) + _create_or_replace_oauth_ingress_object( + cluster_name, namespace, service_name, port_name, host + ) + _create_or_replace_oauth_rb(cluster_name, namespace, oauth_sa_name) + + +def _create_or_replace_oauth_sa(namespace, oauth_sa_name, host): + api_client = api_config_handler() + oauth_sa = client.V1ServiceAccount( + api_version="v1", + kind="ServiceAccount", + metadata=client.V1ObjectMeta( + name=oauth_sa_name, + namespace=namespace, + annotations={ + "serviceaccounts.openshift.io/oauth-redirecturi.first": f"https://{host}" + }, + ), + ) + try: + client.CoreV1Api(api_client).create_namespaced_service_account( + namespace=namespace, body=oauth_sa + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.CoreV1Api(api_client).replace_namespaced_service_account( + namespace=namespace, + body=oauth_sa, + name=oauth_sa_name, + ) + else: + raise e + + +def _create_or_replace_oauth_rb(cluster_name, namespace, oauth_sa_name): + api_client = api_config_handler() + oauth_crb = client.V1ClusterRoleBinding( + api_version="rbac.authorization.k8s.io/v1", + kind="ClusterRoleBinding", + metadata=client.V1ObjectMeta(name=f"{cluster_name}-rb"), + role_ref=client.V1RoleRef( + api_group="rbac.authorization.k8s.io", + kind="ClusterRole", + name="system:auth-delegator", + ), + subjects=[ + client.V1Subject( + kind="ServiceAccount", name=oauth_sa_name, namespace=namespace + ) + ], + ) + try: + client.RbacAuthorizationV1Api(api_client).create_cluster_role_binding( + body=oauth_crb + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.RbacAuthorizationV1Api(api_client).replace_cluster_role_binding( + body=oauth_crb, name=f"{cluster_name}-rb" + ) + else: + raise e + + +def _gen_tls_secret_name(cluster_name): + return f"{cluster_name}-proxy-tls-secret" + + +def delete_openshift_oauth_objects(cluster_name, namespace): + # NOTE: it might be worth adding error handling here, but shouldn't be necessary because cluster.down(...) checks + # for an existing cluster before calling this => the objects should never be deleted twice + api_client = api_config_handler() + oauth_sa_name = f"{cluster_name}-oauth-proxy" + service_name = f"{cluster_name}-oauth" + client.CoreV1Api(api_client).delete_namespaced_service_account( + name=oauth_sa_name, namespace=namespace + ) + client.CoreV1Api(api_client).delete_namespaced_service( + name=service_name, namespace=namespace + ) + client.NetworkingV1Api(api_client).delete_namespaced_ingress( + name=f"{cluster_name}-ingress", namespace=namespace + ) + client.RbacAuthorizationV1Api(api_client).delete_cluster_role_binding( + name=f"{cluster_name}-rb" + ) + + +def download_tls_cert(cluster_name, namespace, output_file): + api_client = api_config_handler() + b64_tls_cert = ( + client.CoreV1Api(api_client) + .read_namespaced_secret( + name=_gen_tls_secret_name(cluster_name=cluster_name), namespace=namespace + ) + .data["tls.crt"] + ) + with open(output_file, "w+") as f: + f.write(b64decode(b64_tls_cert).decode("ascii")) + + +def _create_or_replace_oauth_service_obj( + cluster_name: str, + namespace: str, + oauth_port: int, + tls_secret_name: str, + service_name: str, + port_name: str, +) -> client.V1Service: + api_client = api_config_handler() + oauth_service = client.V1Service( + api_version="v1", + kind="Service", + metadata=client.V1ObjectMeta( + annotations={ + "service.beta.openshift.io/serving-cert-secret-name": tls_secret_name + }, + name=service_name, + namespace=namespace, + ), + spec=client.V1ServiceSpec( + ports=[ + client.V1ServicePort( + name=port_name, + protocol="TCP", + port=443, + target_port=oauth_port, + ) + ], + selector={ + "app.kubernetes.io/created-by": "kuberay-operator", + "app.kubernetes.io/name": "kuberay", + "ray.io/cluster": cluster_name, + "ray.io/identifier": f"{cluster_name}-head", + "ray.io/node-type": "head", + }, + ), + ) + try: + client.CoreV1Api(api_client).create_namespaced_service( + namespace=namespace, body=oauth_service + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.CoreV1Api(api_client).replace_namespaced_service( + namespace=namespace, body=oauth_service, name=service_name + ) + else: + raise e + + +def _create_or_replace_oauth_ingress_object( + cluster_name: str, + namespace: str, + service_name: str, + port_name: str, + host: str, +) -> client.V1Ingress: + api_client = api_config_handler() + ingress = client.V1Ingress( + api_version="networking.k8s.io/v1", + kind="Ingress", + metadata=client.V1ObjectMeta( + annotations={"route.openshift.io/termination": "passthrough"}, + name=f"{cluster_name}-ingress", + namespace=namespace, + ), + spec=client.V1IngressSpec( + rules=[ + client.V1IngressRule( + host=host, + http=client.V1HTTPIngressRuleValue( + paths=[ + client.V1HTTPIngressPath( + backend=client.V1IngressBackend( + service=client.V1IngressServiceBackend( + name=service_name, + port=client.V1ServiceBackendPort( + name=port_name + ), + ) + ), + path_type="ImplementationSpecific", + ) + ] + ), + ) + ] + ), + ) + try: + client.NetworkingV1Api(api_client).create_namespaced_ingress( + namespace=namespace, body=ingress + ) + except client.ApiException as e: + if e.reason == "Conflict": + client.NetworkingV1Api(api_client).replace_namespaced_ingress( + namespace=namespace, body=ingress, name=f"{cluster_name}-ingress" + ) + else: + raise e diff --git a/tests/unit_test.py b/tests/unit_test.py index 78925226a..35ee3db03 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -1848,7 +1848,7 @@ def test_DDPJobDefinition_dry_run(mocker): ) ddp = createTestDDP() cluster = createClusterWithConfig() - ddp_job = ddp._dry_run(cluster) + ddp_job, _ = ddp._dry_run(cluster) assert type(ddp_job) == AppDryRunInfo assert ddp_job._fmt is not None assert type(ddp_job.request) == RayJob @@ -1932,7 +1932,7 @@ def test_DDPJobDefinition_dry_run_no_resource_args(mocker): rdzv_port=29500, scheduler_args={"requirements": "test"}, ) - ddp_job = ddp._dry_run(cluster) + ddp_job, _ = ddp._dry_run(cluster) assert ddp_job._app.roles[0].resource.cpu == cluster.config.max_cpus assert ddp_job._app.roles[0].resource.gpu == cluster.config.num_gpus