Skip to content

Commit

Permalink
wrap api client to add defaults
Browse files Browse the repository at this point in the history
Signed-off-by: Kevin <[email protected]>
  • Loading branch information
KPostOffice committed Sep 16, 2024
1 parent fb59ba6 commit fe75256
Show file tree
Hide file tree
Showing 5 changed files with 60 additions and 41 deletions.
45 changes: 32 additions & 13 deletions src/codeflare_sdk/cluster/auth.py
Original file line number Diff line number Diff line change
Expand Up @@ -119,25 +119,14 @@ def login(self) -> str:
configuration.host = self.server
configuration.api_key["authorization"] = self.token

api_client = client.ApiClient(configuration)
if not self.skip_tls:
if self.ca_cert_path is None:
configuration.ssl_ca_cert = None
elif os.path.isfile(self.ca_cert_path):
print(
f"Authenticated with certificate located at {self.ca_cert_path}"
)
configuration.ssl_ca_cert = self.ca_cert_path
else:
raise FileNotFoundError(
f"Certificate file not found at {self.ca_cert_path}"
)
configuration.verify_ssl = True
_client_with_cert(api_client, self.ca_cert_path)
else:
urllib3.disable_warnings(urllib3.exceptions.InsecureRequestWarning)
print("Insecure request warnings have been disabled")
configuration.verify_ssl = False

api_client = client.ApiClient(configuration)
client.AuthenticationApi(api_client).get_api_group()
config_path = None
return "Logged into %s" % self.server
Expand Down Expand Up @@ -211,6 +200,36 @@ def config_check() -> str:
return config_path


def _client_with_cert(client: client.ApiClient, ca_cert_path: Optional[str] = None):
cert_path = _gen_ca_cert_path(ca_cert_path)
if os.path.isfile(cert_path):
print(f"Authenticated with certificate located at {ca_cert_path}")
client.configuration.ssl_ca_cert = ca_cert_path
else:
raise FileNotFoundError(f"Certificate file not found at {ca_cert_path}")


def _gen_ca_cert_path(ca_cert_path: str):
"""Gets the path to the default CA certificate file either through env config or default path"""
if ca_cert_path is not None:
return ca_cert_path
elif "CF_SDK_CA_CERT_PATH" in os.environ:
return os.environ.get("CF_SDK_CA_CERT_PATH")
elif os.path.exists(WORKBENCH_CA_CERT_PATH):
return WORKBENCH_CA_CERT_PATH
else:
return None


def get_api_client() -> client.ApiClient:
"This function should load the api client with defaults"
if api_client != None:
return api_client
to_return = client.ApiClient()
_client_with_cert(to_return)
return to_return


def api_config_handler() -> Optional[client.ApiClient]:
"""
This function is used to load the api client if the user has logged in
Expand Down
6 changes: 3 additions & 3 deletions src/codeflare_sdk/cluster/awload.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@

from kubernetes import client, config
from ..utils.kube_api_helpers import _kube_api_error_handling
from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client


class AWManager:
Expand Down Expand Up @@ -59,7 +59,7 @@ def submit(self) -> None:
"""
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.create_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -84,7 +84,7 @@ def remove(self) -> None:

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down
38 changes: 19 additions & 19 deletions src/codeflare_sdk/cluster/cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@
from kubernetes import config
from ray.job_submission import JobSubmissionClient

from .auth import config_check, api_config_handler
from .auth import config_check, get_api_client
from ..utils import pretty_print
from ..utils.generate_yaml import (
generate_appwrapper,
Expand Down Expand Up @@ -74,7 +74,7 @@ def __init__(self, config: ClusterConfiguration):

@property
def _client_headers(self):
k8_client = api_config_handler() or client.ApiClient()
k8_client = get_api_client()
return {
"Authorization": k8_client.configuration.get_api_key_with_prefix(
"authorization"
Expand All @@ -89,7 +89,7 @@ def _client_verify_tls(self):

@property
def job_client(self):
k8client = api_config_handler() or client.ApiClient()
k8client = get_api_client()
if self._job_submission_client:
return self._job_submission_client
if is_openshift_cluster():
Expand Down Expand Up @@ -135,7 +135,7 @@ def up(self):

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
if self.config.write_to_file:
with open(self.app_wrapper_yaml) as f:
Expand All @@ -162,7 +162,7 @@ def up(self):
return _kube_api_error_handling(e)

def _throw_for_no_raycluster(self):
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
try:
api_instance.list_namespaced_custom_object(
group="ray.io",
Expand All @@ -189,7 +189,7 @@ def down(self):
self._throw_for_no_raycluster()
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
if self.config.appwrapper:
api_instance.delete_namespaced_custom_object(
group="workload.codeflare.dev",
Expand Down Expand Up @@ -344,7 +344,7 @@ def cluster_dashboard_uri(self) -> str:
config_check()
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -366,7 +366,7 @@ def cluster_dashboard_uri(self) -> str:
return f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(self.config.namespace)
except Exception as e: # pragma no cover
return _kube_api_error_handling(e)
Expand Down Expand Up @@ -546,7 +546,7 @@ def list_all_queued(


def get_current_namespace(): # pragma: no cover
if api_config_handler() != None:
if get_api_client() != None:
if os.path.isfile("/var/run/secrets/kubernetes.io/serviceaccount/namespace"):
try:
file = open(
Expand Down Expand Up @@ -591,7 +591,7 @@ def get_cluster(
):
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -646,7 +646,7 @@ def _create_resources(yamls, namespace: str, api_instance: client.CustomObjectsA
def _check_aw_exists(name: str, namespace: str) -> bool:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -673,7 +673,7 @@ def _get_ingress_domain(self): # pragma: no cover

if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())

routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
Expand All @@ -692,7 +692,7 @@ def _get_ingress_domain(self): # pragma: no cover
domain = route["spec"]["host"]
else:
try:
api_client = client.NetworkingV1Api(api_config_handler())
api_client = client.NetworkingV1Api(get_api_client())
ingresses = api_client.list_namespaced_ingress(namespace)
except Exception as e: # pragma: no cover
return _kube_api_error_handling(e)
Expand All @@ -706,7 +706,7 @@ def _get_ingress_domain(self): # pragma: no cover
def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand All @@ -725,7 +725,7 @@ def _app_wrapper_status(name, namespace="default") -> Optional[AppWrapper]:
def _ray_cluster_status(name, namespace="default") -> Optional[RayCluster]:
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand All @@ -747,7 +747,7 @@ def _get_ray_clusters(
list_of_clusters = []
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
rcs = api_instance.list_namespaced_custom_object(
group="ray.io",
version="v1",
Expand Down Expand Up @@ -776,7 +776,7 @@ def _get_app_wrappers(

try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
aws = api_instance.list_namespaced_custom_object(
group="workload.codeflare.dev",
version="v1beta2",
Expand Down Expand Up @@ -805,7 +805,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = None
if is_openshift_cluster():
try:
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
routes = api_instance.list_namespaced_custom_object(
group="route.openshift.io",
version="v1",
Expand All @@ -824,7 +824,7 @@ def _map_to_ray_cluster(rc) -> Optional[RayCluster]:
dashboard_url = f"{protocol}://{route['spec']['host']}"
else:
try:
api_instance = client.NetworkingV1Api(api_config_handler())
api_instance = client.NetworkingV1Api(get_api_client())
ingresses = api_instance.list_namespaced_ingress(
rc["metadata"]["namespace"]
)
Expand Down
4 changes: 2 additions & 2 deletions src/codeflare_sdk/utils/generate_cert.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@
from cryptography import x509
from cryptography.x509.oid import NameOID
import datetime
from ..cluster.auth import config_check, api_config_handler
from ..cluster.auth import config_check, get_api_client
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling

Expand Down Expand Up @@ -103,7 +103,7 @@ def generate_tls_cert(cluster_name, namespace, days=30):
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.key"}}'
# oc get secret ca-secret-<cluster-name> -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
config_check()
v1 = client.CoreV1Api(api_config_handler())
v1 = client.CoreV1Api(get_api_client())

# Secrets have a suffix appended to the end so we must list them and gather the secret that includes cluster_name-ca-secret-
secret_name = get_secret_name(cluster_name, namespace, v1)
Expand Down
8 changes: 4 additions & 4 deletions src/codeflare_sdk/utils/generate_yaml.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@
import uuid
from kubernetes import client, config
from .kube_api_helpers import _kube_api_error_handling
from ..cluster.auth import api_config_handler, config_check
from ..cluster.auth import get_api_client, config_check
from os import urandom
from base64 import b64encode
from urllib3.util import parse_url
Expand Down Expand Up @@ -57,7 +57,7 @@ def gen_names(name):
def is_openshift_cluster():
try:
config_check()
for api in client.ApisApi(api_config_handler()).get_api_versions().groups:
for api in client.ApisApi(get_api_client()).get_api_versions().groups:
for v in api.versions:
if "route.openshift.io/v1" in v.group_version:
return True
Expand Down Expand Up @@ -235,7 +235,7 @@ def get_default_kueue_name(namespace: str):
# If the local queue is set, use it. Otherwise, try to use the default queue.
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand All @@ -261,7 +261,7 @@ def local_queue_exists(namespace: str, local_queue_name: str):
# get all local queues in the namespace
try:
config_check()
api_instance = client.CustomObjectsApi(api_config_handler())
api_instance = client.CustomObjectsApi(get_api_client())
local_queues = api_instance.list_namespaced_custom_object(
group="kueue.x-k8s.io",
version="v1beta1",
Expand Down

0 comments on commit fe75256

Please sign in to comment.