Skip to content

Commit

Permalink
wrap api client to add defaults
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice committed Sep 25, 2024
1 parent 1235fc8 commit 19c2abe
Show file tree
Hide file tree
Showing 6 changed files with 66 additions and 67 deletions.
61 changes: 31 additions & 30 deletions src/codeflare_sdk/cluster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,17 +93,7 @@ def __init__(
self.token = token
self.server = server
self.skip_tls = skip_tls
self.ca_cert_path = self._gen_ca_cert_path(ca_cert_path)

def _gen_ca_cert_path(self, ca_cert_path: str):
if ca_cert_path is not None:
return ca_cert_path
elif "CF_SDK_CA_CERT_PATH" in os.environ:
return os.environ.get("CF_SDK_CA_CERT_PATH")
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
return WORKBENCH_CA_CERT_PATH
else:
return None
self.ca_cert_path = _gen_ca_cert_path(ca_cert_path)

def login(self) -> str:
"""
Expand All @@ -119,25 +109,14 @@ def login(self) -> str:
configuration.host = self.server
configuration.api_key["authorization"] = self.token

api_client = client.ApiClient(configuration)
if not self.skip_tls:
if self.ca_cert_path is None:
configuration.ssl_ca_cert = None
elif os.path.isfile(self.ca_cert_path):
print(
f"Authenticated with certificate located at {self.ca_cert_path}"
)
configuration.ssl_ca_cert = self.ca_cert_path
else:
raise FileNotFoundError(
f"Certificate file not found at {self.ca_cert_path}"
)
configuration.verify_ssl = True
_client_with_cert(api_client, self.ca_cert_path)
else:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
print("Insecure request warnings have been disabled")
configuration.verify_ssl = False

api_client = client.ApiClient(configuration)
client.AuthenticationApi(api_client).get_api_group()
config_path = None
return "Logged into %s" % self.server
Expand Down Expand Up @@ -211,11 +190,33 @@ def config_check() -> str:
return config_path


def api_config_handler() -> Optional[client.ApiClient]:
"""
This function is used to load the api client if the user has logged in
"""
if api_client != None and config_path == None:
return api_client
def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
client.configuration.verify_ssl = True
cert_path = _gen_ca_cert_path(ca_cert_path)
if cert_path is None:
client.configuration.ssl_ca_cert = None
elif os.path.isfile(cert_path):
client.configuration.ssl_ca_cert = cert_path
else:
raise FileNotFoundError(f"Certificate file not found at {cert_path}")


def _gen_ca_cert_path(ca_cert_path: Optional[str]):
"""Gets the path to the default CA certificate file either through env config or default path"""
if ca_cert_path is not None:
return ca_cert_path
elif "CF_SDK_CA_CERT_PATH" in os.environ:
return os.environ.get("CF_SDK_CA_CERT_PATH")
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
return WORKBENCH_CA_CERT_PATH
else:
return None


def get_api_client() -> client.ApiClient:
"This function should load the api client with defaults"
if api_client != None:
return api_client
to_return = client.ApiClient()
_client_with_cert(to_return)
return to_return
6 changes: 3 additions & 3 deletions src/codeflare_sdk/cluster/awload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from kubernetes import client, config
from ..utils.kube_api_helpers import _kube_api_error_handling
from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client


class AWManager:
Expand Down Expand Up @@ -59,7 +59,7 @@ def submit(self) -> None:
"""
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.create_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -84,7 +84,7 @@ def remove(self) -> None:

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down
39 changes: 18 additions & 21 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from kubernetes import config
from ray.job_submission import JobSubmissionClient

from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client
from ..utils import pretty_print
from ..utils.generate_yaml import (
generate_appwrapper,
Expand Down Expand Up @@ -80,7 +80,7 @@ def __init__(self, config: ClusterConfiguration):

@property
def _client_headers(self):
k8_client = api_config_handler() or client.ApiClient()
k8_client = get_api_client()
return {
"Authorization": k8_client.configuration.get_api_key_with_prefix(
"authorization"
Expand All @@ -95,7 +95,7 @@ def _client_verify_tls(self):

@property
def job_client(self):
k8client = api_config_handler() or client.ApiClient()
k8client = get_api_client()
if self._job_submission_client:
return self._job_submission_client
if is_openshift_cluster():
Expand Down Expand Up @@ -141,7 +141,7 @@ def up(self):

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
if self.config.write_to_file:
with open(self.app_wrapper_yaml) as f:
Expand Down Expand Up @@ -172,7 +172,7 @@ def up(self):
return _kube_api_error_handling(e)

def _throw_for_no_raycluster(self):
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
try:
api_instance.list_namespaced_custom_object(
group="ray.io",
Expand All @@ -199,7 +199,7 @@ def down(self):
self._throw_for_no_raycluster()
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
Expand Down Expand Up @@ -358,7 +358,7 @@ def cluster_dashboard_uri(self) -> str:
config_check()
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -380,7 +380,7 @@ def cluster_dashboard_uri(self) -> str:
return f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
except Exception as e: # pragma no cover
return _kube_api_error_handling(e)
Expand Down Expand Up @@ -579,9 +579,6 @@ def get_current_namespace(): # pragma: no cover
return active_context
except Exception as e:
print("Unable to find current namespace")

if api_config_handler() != None:
return None
print("trying to gather from current context")
try:
_, active_context = config.list_kube_config_contexts(config_check())
Expand All @@ -601,7 +598,7 @@ def get_cluster(
):
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -656,7 +653,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
def _check_aw_exists(name: str, namespace: str) -> bool:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -683,7 +680,7 @@ def _get_ingress_domain(self): # pragma: no cover

if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())

routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
Expand All @@ -702,7 +699,7 @@ def _get_ingress_domain(self): # pragma: no cover
domain = route["spec"]["host"]
else:
try:
api_client = client.NetworkingV1Api(api_config_handler())
api_client = client.NetworkingV1Api(get_api_client())
ingresses = api_client.list_namespaced_ingress(namespace)
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)
Expand All @@ -716,7 +713,7 @@ def _get_ingress_domain(self): # pragma: no cover
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -735,7 +732,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand All @@ -757,7 +754,7 @@ def _get_ray_clusters(
list_of_clusters = []
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -786,7 +783,7 @@ def _get_app_wrappers(

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down Expand Up @@ -815,7 +812,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = None
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -834,7 +831,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(
rc["metadata"]["namespace"]
)
Expand Down
4 changes: 2 additions & 2 deletions src/codeflare_sdk/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cryptography import x509
from cryptography.x509.oid import NameOID
import datetime
from ..cluster.auth import config_check, api_config_handler
from ..cluster.auth import config_check, get_api_client
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling

Expand Down Expand Up @@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
config_check()
v1 = client.CoreV1Api(api_config_handler())
v1 = client.CoreV1Api(get_api_client())

# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
secret_name = get_secret_name(cluster_name, namespace, v1)
Expand Down
8 changes: 4 additions & 4 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import uuid
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling
from ..cluster.auth import api_config_handler, config_check
from ..cluster.auth import get_api_client, config_check
from os import urandom
from base64 import b64encode
from urllib3.util import parse_url
Expand Down Expand Up @@ -57,7 +57,7 @@ def gen_names(name):
def is_openshift_cluster():
try:
config_check()
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
for v in api.versions:
if "route.openshift.io/v1" in v.group_version:
return True
Expand Down Expand Up @@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand All @@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
# get all local queues in the namespace
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand Down
Loading

0 comments on commit 19c2abe

Please sign in to comment.