diff --git a/src/codeflare_sdk/cluster/auth.py b/src/codeflare_sdk/cluster/auth.py index c39fe1d4a..6c02219d2 100644 --- a/src/codeflare_sdk/cluster/auth.py +++ b/src/codeflare_sdk/cluster/auth.py @@ -119,25 +119,14 @@ def login(self) -> str: configuration.host = self.server configuration.api_key["authorization"] = self.token + api_client = client.ApiClient(configuration) if not self.skip_tls: - if self.ca_cert_path is None: - configuration.ssl_ca_cert = None - elif os.path.isfile(self.ca_cert_path): - print( - f"Authenticated with certificate located at {self.ca_cert_path}" - ) - configuration.ssl_ca_cert = self.ca_cert_path - else: - raise FileNotFoundError( - f"Certificate file not found at {self.ca_cert_path}" - ) - configuration.verify_ssl = True + _client_with_cert(api_client, self.ca_cert_path) else: urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning) print("Insecure request warnings have been disabled") configuration.verify_ssl = False - api_client = client.ApiClient(configuration) client.AuthenticationApi(api_client).get_api_group() config_path = None return "Logged into %s" % self.server @@ -211,6 +200,41 @@ def config_check() -> str: return config_path +def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None): + client.configuration.verify_ssl = True + cert_path = _gen_ca_cert_path(ca_cert_path) + if cert_path is None: + client.configuration.ssl_ca_cert = None + elif os.path.isfile(cert_path): + print(f"Authenticated with certificate located at {cert_path}") + client.configuration.ssl_ca_cert = ca_cert_path + else: + raise FileNotFoundError(f"Certificate file not found at {cert_path}") + + +def _gen_ca_cert_path(ca_cert_path: str): + """Gets the path to the default CA certificate file either through env config or default path""" + if ca_cert_path is not None: + return ca_cert_path + elif "CF_SDK_CA_CERT_PATH" in os.environ: + print(f"Using {os.environ.get('CF_SDK_CA_CERT_PATH')}") + return os.environ.get("CF_SDK_CA_CERT_PATH") + elif os.path.exists(WORKBENCH_CA_CERT_PATH): + print("Default path exists, using default path") + return WORKBENCH_CA_CERT_PATH + else: + return None + + +def get_api_client() -> client.ApiClient: + "This function should load the api client with defaults" + if api_client != None: + return api_client + to_return = client.ApiClient() + _client_with_cert(to_return) + return to_return + + def api_config_handler() -> Optional[client.ApiClient]: """ This function is used to load the api client if the user has logged in diff --git a/src/codeflare_sdk/cluster/awload.py b/src/codeflare_sdk/cluster/awload.py index 7455b2161..1ead59146 100644 --- a/src/codeflare_sdk/cluster/awload.py +++ b/src/codeflare_sdk/cluster/awload.py @@ -24,7 +24,7 @@ from kubernetes import client, config from ..utils.kube_api_helpers import _kube_api_error_handling -from .auth import config_check, api_config_handler +from .auth import config_check, get_api_client class AWManager: @@ -59,7 +59,7 @@ def submit(self) -> None: """ try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) api_instance.create_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -84,7 +84,7 @@ def remove(self) -> None: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) api_instance.delete_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 7c652a186..79c353116 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -25,7 +25,7 @@ from kubernetes import config from ray.job_submission import JobSubmissionClient -from .auth import config_check, api_config_handler +from .auth import config_check, get_api_client from ..utils import pretty_print from ..utils.generate_yaml import ( generate_appwrapper, @@ -80,7 +80,7 @@ def __init__(self, config: ClusterConfiguration): @property def _client_headers(self): - k8_client = api_config_handler() or client.ApiClient() + k8_client = get_api_client() return { "Authorization": k8_client.configuration.get_api_key_with_prefix( "authorization" @@ -95,7 +95,7 @@ def _client_verify_tls(self): @property def job_client(self): - k8client = api_config_handler() or client.ApiClient() + k8client = get_api_client() if self._job_submission_client: return self._job_submission_client if is_openshift_cluster(): @@ -141,7 +141,7 @@ def up(self): try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) if self.config.appwrapper: if self.config.write_to_file: with open(self.app_wrapper_yaml) as f: @@ -172,7 +172,7 @@ def up(self): return _kube_api_error_handling(e) def _throw_for_no_raycluster(self): - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) try: api_instance.list_namespaced_custom_object( group="ray.io", @@ -199,7 +199,7 @@ def down(self): self._throw_for_no_raycluster() try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) if self.config.appwrapper: api_instance.delete_namespaced_custom_object( group="workload.codeflare.dev", @@ -358,7 +358,7 @@ def cluster_dashboard_uri(self) -> str: config_check() if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", @@ -380,7 +380,7 @@ def cluster_dashboard_uri(self) -> str: return f"{protocol}://{route['spec']['host']}" else: try: - api_instance = client.NetworkingV1Api(api_config_handler()) + api_instance = client.NetworkingV1Api(get_api_client()) ingresses = api_instance.list_namespaced_ingress(self.config.namespace) except Exception as e: # pragma no cover return _kube_api_error_handling(e) @@ -579,9 +579,6 @@ def get_current_namespace(): # pragma: no cover return active_context except Exception as e: print("Unable to find current namespace") - - if api_config_handler() != None: - return None print("trying to gather from current context") try: _, active_context = config.list_kube_config_contexts(config_check()) @@ -601,7 +598,7 @@ def get_cluster( ): try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -656,7 +653,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA def _check_aw_exists(name: str, namespace: str) -> bool: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -683,7 +680,7 @@ def _get_ingress_domain(self): # pragma: no cover if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", @@ -702,7 +699,7 @@ def _get_ingress_domain(self): # pragma: no cover domain = route["spec"]["host"] else: try: - api_client = client.NetworkingV1Api(api_config_handler()) + api_client = client.NetworkingV1Api(get_api_client()) ingresses = api_client.list_namespaced_ingress(namespace) except Exception as e: # pragma: no cover return _kube_api_error_handling(e) @@ -716,7 +713,7 @@ def _get_ingress_domain(self): # pragma: no cover def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -735,7 +732,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]: def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]: try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -757,7 +754,7 @@ def _get_ray_clusters( list_of_clusters = [] try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) rcs = api_instance.list_namespaced_custom_object( group="ray.io", version="v1", @@ -786,7 +783,7 @@ def _get_app_wrappers( try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) aws = api_instance.list_namespaced_custom_object( group="workload.codeflare.dev", version="v1beta2", @@ -815,7 +812,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: dashboard_url = None if is_openshift_cluster(): try: - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) routes = api_instance.list_namespaced_custom_object( group="route.openshift.io", version="v1", @@ -834,7 +831,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]: dashboard_url = f"{protocol}://{route['spec']['host']}" else: try: - api_instance = client.NetworkingV1Api(api_config_handler()) + api_instance = client.NetworkingV1Api(get_api_client()) ingresses = api_instance.list_namespaced_ingress( rc["metadata"]["namespace"] ) diff --git a/src/codeflare_sdk/utils/generate_cert.py b/src/codeflare_sdk/utils/generate_cert.py index 5de56882b..f3dc80e94 100644 --- a/src/codeflare_sdk/utils/generate_cert.py +++ b/src/codeflare_sdk/utils/generate_cert.py @@ -19,7 +19,7 @@ from cryptography import x509 from cryptography.x509.oid import NameOID import datetime -from ..cluster.auth import config_check, api_config_handler +from ..cluster.auth import config_check, get_api_client from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling @@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30): # oc get secret ca-secret- -o template='{{index .data "ca.key"}}' # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt config_check() - v1 = client.CoreV1Api(api_config_handler()) + v1 = client.CoreV1Api(get_api_client()) # Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret- secret_name = get_secret_name(cluster_name, namespace, v1) diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index c4e1755d8..7a17e0103 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -27,7 +27,7 @@ import uuid from kubernetes import client, config from .kube_api_helpers import _kube_api_error_handling -from ..cluster.auth import api_config_handler, config_check +from ..cluster.auth import get_api_client, config_check from os import urandom from base64 import b64encode from urllib3.util import parse_url @@ -57,7 +57,7 @@ def gen_names(name): def is_openshift_cluster(): try: config_check() - for api in client.ApisApi(api_config_handler()).get_api_versions().groups: + for api in client.ApisApi(get_api_client()).get_api_versions().groups: for v in api.versions: if "route.openshift.io/v1" in v.group_version: return True @@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str): # If the local queue is set, use it. Otherwise, try to use the default queue. try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) local_queues = api_instance.list_namespaced_custom_object( group="kueue.x-k8s.io", version="v1beta1", @@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str): # get all local queues in the namespace try: config_check() - api_instance = client.CustomObjectsApi(api_config_handler()) + api_instance = client.CustomObjectsApi(get_api_client()) local_queues = api_instance.list_namespaced_custom_object( group="kueue.x-k8s.io", version="v1beta1",