diff --git a/demo-notebooks/interactive/local_interactive.ipynb b/demo-notebooks/interactive/local_interactive.ipynb new file mode 100644 index 000000000..88a6ccd58 --- /dev/null +++ b/demo-notebooks/interactive/local_interactive.ipynb @@ -0,0 +1,358 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "9a44568b-61ef-41c7-8ad1-9a3b128f03a7", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "# Import pieces from codeflare-sdk\n", + "from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration\n", + "from codeflare_sdk.cluster.auth import TokenAuthentication" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "2cc66278", + "metadata": {}, + "outputs": [], + "source": [ + "# Create authentication object and log in to desired user account (if not already authenticated)\n", + "auth = TokenAuthentication(\n", + " token = \"XXXX\",\n", + " server = \"XXXX\",\n", + " skip_tls = False\n", + ")\n", + "auth.login()" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "4364ac2e-dd10-4d30-ba66-12708daefb3f", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Written to: hfgputest-1.yaml\n" + ] + } + ], + "source": [ + "# Create our cluster and submit appwrapper\n", + "namespace = \"default\"\n", + "cluster_name = \"hfgputest-1\"\n", + "local_interactive = True\n", + "\n", + "cluster = Cluster(ClusterConfiguration(local_interactive=local_interactive, namespace=namespace, name=cluster_name, min_worker=1, max_worker=1, min_cpus=1, max_cpus=1, min_memory=4, max_memory=4, gpu=0, instascale=False, machine_types=[\"m5.xlarge\", \"p3.8xlarge\"]))" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "69968140-15e6-482f-9529-82b0cd19524b", + "metadata": { + "tags": [] + }, + "outputs": [], + "source": [ + "cluster.up()" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "e20f9982-f671-460b-8c22-3d62e101fed9", + "metadata": { + "tags": [] + }, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Waiting for requested resources to be set up...\n", + "Requested cluster up and running!\n" + ] + } + ], + "source": [ + "cluster.wait_ready()" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "12eef53c", + "metadata": {}, + "source": [ + "### Connect via the rayclient route" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "cf1b749e-2335-42c2-b673-26768ec9895d", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "rayclient-hfgputest-1-default.apps.tedbig412.cp.fyre.ibm.com\n" + ] + } + ], + "source": [ + "import openshift as oc\n", + "from codeflare_sdk.utils import generate_cert\n", + "\n", + "if local_interactive:\n", + " generate_cert.generate_tls_cert(cluster_name, namespace)\n", + " generate_cert.export_env(cluster_name, namespace)\n", + "\n", + "with oc.project(namespace):\n", + " routes=oc.selector(\"route\").objects()\n", + " rayclient_url=\"\"\n", + " for r in routes:\n", + " if \"rayclient\" in r.name():\n", + " rayclient_url=r.model.spec.host\n", + "print(rayclient_url)" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "9483bb98-33b3-4beb-9b15-163d7e76c1d7", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-31 14:12:37,816\tINFO client_builder.py:251 -- Passing the following kwargs to ray.init() on the server: logging_level\n", + "2023-05-31 14:12:37,820\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.IDLE\n", + "2023-05-31 14:12:38,034\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.CONNECTING\n", + "2023-05-31 14:12:38,246\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.READY\n", + "2023-05-31 14:12:38,290\tDEBUG worker.py:807 -- Pinging server.\n", + "2023-05-31 14:12:40,521\tDEBUG worker.py:640 -- Retaining 00ffffffffffffffffffffffffffffffffffffff0100000001000000\n", + "2023-05-31 14:12:40,523\tDEBUG worker.py:564 -- Scheduling task get_dashboard_url 0 b'\\x00\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00'\n", + "2023-05-31 14:12:40,535\tDEBUG worker.py:640 -- Retaining c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000\n", + "2023-05-31 14:12:41,379\tDEBUG worker.py:636 -- Releasing c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000\n" + ] + }, + { + "data": { + "text/html": [ + "
\n", + "
\n", + "

Ray

\n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + " \n", + "\n", + "\n", + "
Python version:3.8.13
Ray version: 2.1.0
Dashboard:http://10.254.12.141:8265
\n", + "
\n", + "
\n" + ], + "text/plain": [ + "ClientContext(dashboard_url='10.254.12.141:8265', python_version='3.8.13', ray_version='2.1.0', ray_commit='23f34d948dae8de9b168667ab27e6cf940b3ae85', protocol_version='2022-10-05', _num_clients=1, _context_to_restore=)" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "import ray\n", + "\n", + "ray.shutdown()\n", + "ray.init(address=f\"ray://{rayclient_url}\", logging_level=\"DEBUG\")" + ] + }, + { + "cell_type": "code", + "execution_count": 13, + "id": "3436eb4a-217c-4109-a3c3-309fda7e2442", + "metadata": {}, + "outputs": [], + "source": [ + "import math\n", + "import ray\n", + "\n", + "@ray.remote\n", + "def heavy_calculation_part(num_iterations):\n", + " result = 0.0\n", + " for i in range(num_iterations):\n", + " for j in range(num_iterations):\n", + " for k in range(num_iterations):\n", + " result += math.sin(i) * math.cos(j) * math.tan(k)\n", + " return result\n", + "@ray.remote\n", + "def heavy_calculation(num_iterations):\n", + " results = ray.get([heavy_calculation_part.remote(num_iterations//30) for _ in range(30)])\n", + " return sum(results)\n" + ] + }, + { + "cell_type": "code", + "execution_count": 14, + "id": "5cca1874-2be3-4631-ae48-9adfa45e3af3", + "metadata": { + "scrolled": true, + "tags": [] + }, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-31 14:13:29,868\tDEBUG worker.py:640 -- Retaining 00ffffffffffffffffffffffffffffffffffffff0100000002000000\n", + "2023-05-31 14:13:29,870\tDEBUG worker.py:564 -- Scheduling task heavy_calculation 0 b'\\x00\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\x01\\x00\\x00\\x00\\x02\\x00\\x00\\x00'\n" + ] + } + ], + "source": [ + "ref = heavy_calculation.remote(3000)" + ] + }, + { + "cell_type": "code", + "execution_count": 15, + "id": "01172c29-e8bf-41ef-8db5-eccb07906111", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-31 14:13:32,643\tDEBUG worker.py:640 -- Retaining 16310a0f0a45af5cffffffffffffffffffffffff0100000001000000\n", + "2023-05-31 14:13:34,677\tDEBUG worker.py:439 -- Internal retry for get [ClientObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000)]\n" + ] + }, + { + "data": { + "text/plain": [ + "1789.4644387076714" + ] + }, + "execution_count": 15, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "ray.get(ref)" + ] + }, + { + "cell_type": "code", + "execution_count": 16, + "id": "9e79b547-a457-4232-b77d-19147067b972", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "2023-05-31 14:13:37,659\tDEBUG dataclient.py:287 -- Got unawaited response connection_cleanup {\n", + "}\n", + "\n", + "2023-05-31 14:13:38,681\tDEBUG dataclient.py:278 -- Shutting down data channel.\n" + ] + } + ], + "source": [ + "ray.cancel(ref)\n", + "ray.shutdown()" + ] + }, + { + "cell_type": "code", + "execution_count": 17, + "id": "2c198f1f-68bf-43ff-a148-02b5cb000ff2", + "metadata": {}, + "outputs": [], + "source": [ + "cluster.down()" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6879e471-a69f-4c74-9cec-a195cdead47c", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.13" + }, + "vscode": { + "interpreter": { + "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pyproject.toml b/pyproject.toml index d9e63835e..79b90cb37 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,6 +26,7 @@ rich = "^12.5" ray = {version = "2.1.0", extras = ["default"]} kubernetes = ">= 25.3.0, < 27" codeflare-torchx = "0.6.0.dev0" +cryptography = "40.0.2" [tool.poetry.group.docs] optional = true diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py index 31974b00f..0e0e73c06 100644 --- a/src/codeflare_sdk/cluster/cluster.py +++ b/src/codeflare_sdk/cluster/cluster.py @@ -84,6 +84,7 @@ def create_app_wrapper(self): instascale = self.config.instascale instance_types = self.config.machine_types env = self.config.envs + local_interactive = self.config.local_interactive return generate_appwrapper( name=name, namespace=namespace, @@ -98,6 +99,7 @@ def create_app_wrapper(self): instascale=instascale, instance_types=instance_types, env=env, + local_interactive=local_interactive, ) # creates a new cluster with the provided or default spec diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py index 25f129256..25392db75 100644 --- a/src/codeflare_sdk/cluster/config.py +++ b/src/codeflare_sdk/cluster/config.py @@ -48,3 +48,4 @@ class ClusterConfiguration: instascale: bool = False envs: dict = field(default_factory=dict) image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103" + local_interactive: bool = False diff --git a/src/codeflare_sdk/templates/base-template.yaml b/src/codeflare_sdk/templates/base-template.yaml index c99fd105d..f408df2e2 100644 --- a/src/codeflare_sdk/templates/base-template.yaml +++ b/src/codeflare_sdk/templates/base-template.yaml @@ -119,6 +119,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: "0" + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt name: ray-head image: rayproject/ray:latest imagePullPolicy: Always @@ -142,6 +150,37 @@ spec: cpu: 2 memory: "8G" nvidia.com/gpu: 0 + volumeMounts: + - name: ca-vol + mountPath: "/home/ray/workspace/ca" + readOnly: true + - name: server-cert + mountPath: "/home/ray/workspace/tls" + readOnly: true + initContainers: + - command: + - sh + - -c + - cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = rayclient-deployment-name-$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).server-name">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext + image: rayproject/ray:2.5.0 + name: create-cert + # securityContext: + # runAsUser: 1000 + # runAsGroup: 1000 + volumeMounts: + - name: ca-vol + mountPath: "/home/ray/workspace/ca" + readOnly: true + - name: server-cert + mountPath: "/home/ray/workspace/tls" + readOnly: false + volumes: + - name: ca-vol + secret: + secretName: ca-secret-deployment-name + optional: false + - name: server-cert + emptyDir: {} workerGroupSpecs: # the pod replicas in this group typed worker - replicas: 3 @@ -187,6 +226,22 @@ spec: - name: init-myservice image: busybox:1.28 command: ['sh', '-c', "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done"] + - name: create-cert + image: rayproject/ray:2.5.0 + command: + - sh + - -c + - cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext + # securityContext: + # runAsUser: 1000 + # runAsGroup: 1000 + volumeMounts: + - name: ca-vol + mountPath: "/home/ray/workspace/ca" + readOnly: true + - name: server-cert + mountPath: "/home/ray/workspace/tls" + readOnly: false containers: - name: machine-learning # must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc' image: rayproject/ray:latest @@ -195,6 +250,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: "0" + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt # environment variables to set in the container.Optional. # Refer to https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/ lifecycle: @@ -210,6 +273,20 @@ spec: cpu: "2" memory: "12G" nvidia.com/gpu: "1" + volumeMounts: + - name: ca-vol + mountPath: "/home/ray/workspace/ca" + readOnly: true + - name: server-cert + mountPath: "/home/ray/workspace/tls" + readOnly: true + volumes: + - name: ca-vol + secret: + secretName: ca-secret-deployment-name + optional: false + - name: server-cert + emptyDir: {} - replica: 1 generictemplate: kind: Route @@ -226,3 +303,33 @@ spec: name: deployment-name-head-svc port: targetPort: dashboard + - replicas: 1 + generictemplate: + apiVersion: route.openshift.io/v1 + kind: Route + metadata: + name: rayclient-deployment-name + namespace: default + labels: + # allows me to return name of service that Ray operator creates + odh-ray-cluster-service: deployment-name-head-svc + spec: + port: + targetPort: client + tls: + termination: passthrough + to: + kind: Service + name: deployment-name-head-svc + - replicas: 1 + generictemplate: + apiVersion: v1 + data: + ca.crt: generated_crt + ca.key: generated_key + kind: Secret + metadata: + name: ca-secret-deployment-name + labels: + # allows me to return name of service that Ray operator creates + odh-ray-cluster-service: deployment-name-head-svc diff --git a/src/codeflare_sdk/utils/generate_cert.py b/src/codeflare_sdk/utils/generate_cert.py new file mode 100644 index 000000000..2d73621b8 --- /dev/null +++ b/src/codeflare_sdk/utils/generate_cert.py @@ -0,0 +1,161 @@ +# Copyright 2022 IBM, Red Hat +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import base64 +import os +from cryptography.hazmat.primitives import serialization, hashes +from cryptography.hazmat.primitives.asymmetric import rsa +from cryptography import x509 +from cryptography.x509.oid import NameOID +import datetime +from kubernetes import client, config + + +def generate_ca_cert(days: int = 30): + # Generate base64 encoded ca.key and ca.cert + # Similar to: + # openssl req -x509 -nodes -newkey rsa:2048 -keyout ca.key -days 1826 -out ca.crt -subj '/CN=root-ca' + # base64 -i ca.crt -i ca.key + + private_key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + key = base64.b64encode( + private_key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ) + ).decode("utf-8") + + # Generate Certificate + one_day = datetime.timedelta(1, 0, 0) + public_key = private_key.public_key() + builder = ( + x509.CertificateBuilder() + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"), + ] + ) + ) + .issuer_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"), + ] + ) + ) + .not_valid_before(datetime.datetime.today() - one_day) + .not_valid_after(datetime.datetime.today() + (one_day * days)) + .serial_number(x509.random_serial_number()) + .public_key(public_key) + ) + certificate = base64.b64encode( + builder.sign(private_key=private_key, algorithm=hashes.SHA256()).public_bytes( + serialization.Encoding.PEM + ) + ).decode("utf-8") + return key, certificate + + +def generate_tls_cert(cluster_name, namespace, days=30): + # Create a folder tls-- and store three files: ca.crt, tls.crt, and tls.key + tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}") + if not os.path.exists(tls_dir): + os.makedirs(tls_dir) + + # Similar to: + # oc get secret ca-secret- -o template='{{index .data "ca.key"}}' + # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt + config.load_kube_config() + v1 = client.CoreV1Api() + secret = v1.read_namespaced_secret(f"ca-secret-{cluster_name}", namespace).data + ca_cert = secret.get("ca.crt") + ca_key = secret.get("ca.key") + + with open(os.path.join(tls_dir, "ca.crt"), "w") as f: + f.write(base64.b64decode(ca_cert).decode("utf-8")) + + # Generate tls.key and signed tls.cert locally for ray client + # Similar to running these commands: + # openssl req -nodes -newkey rsa:2048 -keyout ${TLSDIR}/tls.key -out ${TLSDIR}/tls.csr -subj '/CN=local' + # cat <${TLSDIR}/domain.ext + # authorityKeyIdentifier=keyid,issuer + # basicConstraints=CA:FALSE + # subjectAltName = @alt_names + # [alt_names] + # DNS.1 = 127.0.0.1 + # DNS.2 = localhost + # EOF + # openssl x509 -req -CA ${TLSDIR}/ca.crt -CAkey ${TLSDIR}/ca.key -in ${TLSDIR}/tls.csr -out ${TLSDIR}/tls.crt -days 365 -CAcreateserial -extfile ${TLSDIR}/domain.ext + key = rsa.generate_private_key( + public_exponent=65537, + key_size=2048, + ) + + tls_key = key.private_bytes( + serialization.Encoding.PEM, + serialization.PrivateFormat.PKCS8, + serialization.NoEncryption(), + ) + with open(os.path.join(tls_dir, "tls.key"), "w") as f: + f.write(tls_key.decode("utf-8")) + + one_day = datetime.timedelta(1, 0, 0) + tls_cert = ( + x509.CertificateBuilder() + .issuer_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"), + ] + ) + ) + .subject_name( + x509.Name( + [ + x509.NameAttribute(NameOID.COMMON_NAME, "local"), + ] + ) + ) + .public_key(key.public_key()) + .not_valid_before(datetime.datetime.today() - one_day) + .not_valid_after(datetime.datetime.today() + (one_day * days)) + .serial_number(x509.random_serial_number()) + .add_extension( + x509.SubjectAlternativeName( + [x509.DNSName("localhost"), x509.DNSName("127.0.0.1")] + ), + False, + ) + .sign( + serialization.load_pem_private_key(base64.b64decode(ca_key), None), + hashes.SHA256(), + ) + ) + + with open(os.path.join(tls_dir, "tls.crt"), "w") as f: + f.write(tls_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8")) + + +def export_env(cluster_name, namespace): + tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}") + os.environ["RAY_USE_TLS"] = "1" + os.environ["RAY_TLS_SERVER_CERT"] = os.path.join(tls_dir, "tls.crt") + os.environ["RAY_TLS_SERVER_KEY"] = os.path.join(tls_dir, "tls.key") + os.environ["RAY_TLS_CA_CERT"] = os.path.join(tls_dir, "ca.crt") diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py index 36757a2d6..fffc4fa9a 100755 --- a/src/codeflare_sdk/utils/generate_yaml.py +++ b/src/codeflare_sdk/utils/generate_yaml.py @@ -21,6 +21,7 @@ import sys import argparse import uuid +import openshift as oc def read_template(template): @@ -50,6 +51,16 @@ def update_dashboard_route(route_item, cluster_name, namespace): spec["to"]["name"] = f"{cluster_name}-head-svc" +# ToDo: refactor the update_x_route() functions +def update_rayclient_route(route_item, cluster_name, namespace): + metadata = route_item.get("generictemplate", {}).get("metadata") + metadata["name"] = f"rayclient-{cluster_name}" + metadata["namespace"] = namespace + metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc" + spec = route_item.get("generictemplate", {}).get("spec") + spec["to"]["name"] = f"{cluster_name}-head-svc" + + def update_names(yaml, item, appwrapper_name, cluster_name, namespace): metadata = yaml.get("metadata") metadata["name"] = appwrapper_name @@ -191,6 +202,78 @@ def update_nodes( update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu) +def update_ca_secret(ca_secret_item, cluster_name, namespace): + from . import generate_cert + + metadata = ca_secret_item.get("generictemplate", {}).get("metadata") + metadata["name"] = f"ca-secret-{cluster_name}" + metadata["namespace"] = namespace + metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc" + data = ca_secret_item.get("generictemplate", {}).get("data") + data["ca.key"], data["ca.crt"] = generate_cert.generate_ca_cert(365) + + +def enable_local_interactive(resources, cluster_name, namespace): + rayclient_route_item = resources["resources"].get("GenericItems")[2] + ca_secret_item = resources["resources"].get("GenericItems")[3] + item = resources["resources"].get("GenericItems")[0] + update_rayclient_route(rayclient_route_item, cluster_name, namespace) + update_ca_secret(ca_secret_item, cluster_name, namespace) + # update_ca_secret_volumes + item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]["volumes"][0][ + "secret" + ]["secretName"] = f"ca-secret-{cluster_name}" + item["generictemplate"]["spec"]["workerGroupSpecs"][0]["template"]["spec"][ + "volumes" + ][0]["secret"]["secretName"] = f"ca-secret-{cluster_name}" + # update_tls_env + item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]["containers"][ + 0 + ]["env"][1]["value"] = "1" + item["generictemplate"]["spec"]["workerGroupSpecs"][0]["template"]["spec"][ + "containers" + ][0]["env"][1]["value"] = "1" + # update_init_container + command = item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"][ + "initContainers" + ][0].get("command")[2] + + command = command.replace("deployment-name", cluster_name) + + server_name = ( + oc.whoami("--show-server").split(":")[1].split("//")[1].replace("api", "apps") + ) + + command = command.replace("server-name", server_name) + + item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"][ + "initContainers" + ][0].get("command")[2] = command + + +def disable_raycluster_tls(resources): + del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][ + "template" + ]["spec"]["volumes"] + del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][ + "template" + ]["spec"]["containers"][0]["volumeMounts"] + del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][ + "template" + ]["spec"]["initContainers"] + del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][ + "template" + ]["spec"]["volumes"] + del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][ + "template" + ]["spec"]["containers"][0]["volumeMounts"] + del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][ + "template" + ]["spec"]["initContainers"][1] + del resources["GenericItems"][3] # rayclient route + del resources["GenericItems"][2] # ca-secret + + def write_user_appwrapper(user_yaml, output_file_name): with open(output_file_name, "w") as outfile: yaml.dump(user_yaml, outfile, default_flow_style=False) @@ -211,6 +294,7 @@ def generate_appwrapper( instascale: bool, instance_types: list, env, + local_interactive: bool, ): user_yaml = read_template(template) appwrapper_name, cluster_name = gen_names(name) @@ -236,6 +320,10 @@ def generate_appwrapper( env, ) update_dashboard_route(route_item, cluster_name, namespace) + if local_interactive: + enable_local_interactive(resources, cluster_name, namespace) + else: + disable_raycluster_tls(resources["resources"]) outfile = appwrapper_name + ".yaml" write_user_appwrapper(user_yaml, outfile) return outfile @@ -315,6 +403,12 @@ def main(): # pragma: no cover default="default", help="Set the kubernetes namespace you want to deploy your cluster to. Default. If left blank, uses the 'default' namespace", ) + parser.add_argument( + "--local-interactive", + required=False, + default=False, + help="Enable local interactive mode", + ) args = parser.parse_args() name = args.name @@ -329,6 +423,7 @@ def main(): # pragma: no cover instascale = args.instascale instance_types = args.instance_types namespace = args.namespace + local_interactive = args.local_interactive env = {} outfile = generate_appwrapper( @@ -344,6 +439,7 @@ def main(): # pragma: no cover image, instascale, instance_types, + local_interactive, env, ) return outfile diff --git a/tests/test-case-cmd.yaml b/tests/test-case-cmd.yaml index 450ec9668..d82096e8f 100644 --- a/tests/test-case-cmd.yaml +++ b/tests/test-case-cmd.yaml @@ -62,6 +62,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: '0' + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt image: rayproject/ray:latest imagePullPolicy: Always lifecycle: @@ -110,6 +118,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: '0' + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt image: rayproject/ray:latest lifecycle: preStop: diff --git a/tests/test-case.yaml b/tests/test-case.yaml index 133a22229..0d85428df 100644 --- a/tests/test-case.yaml +++ b/tests/test-case.yaml @@ -73,6 +73,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: '0' + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt image: ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103 imagePullPolicy: Always lifecycle: @@ -130,6 +138,14 @@ spec: valueFrom: fieldRef: fieldPath: status.podIP + - name: RAY_USE_TLS + value: '0' + - name: RAY_TLS_SERVER_CERT + value: /home/ray/workspace/tls/server.crt + - name: RAY_TLS_SERVER_KEY + value: /home/ray/workspace/tls/server.key + - name: RAY_TLS_CA_CERT + value: /home/ray/workspace/tls/ca.crt image: ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103 lifecycle: preStop: diff --git a/tests/unit_test.py b/tests/unit_test.py index 47c0d43a8..ead3521c8 100644 --- a/tests/unit_test.py +++ b/tests/unit_test.py @@ -21,6 +21,7 @@ parent = Path(__file__).resolve().parents[1] sys.path.append(str(parent) + "/src") +from kubernetes import client from codeflare_sdk.cluster.awload import AWManager from codeflare_sdk.cluster.cluster import ( Cluster, @@ -54,6 +55,12 @@ DDPJob, torchx_runner, ) +from codeflare_sdk.utils.generate_cert import ( + generate_ca_cert, + generate_tls_cert, + export_env, +) + import openshift from openshift import OpenShiftPythonException from openshift.selector import Selector @@ -1980,6 +1987,88 @@ def test_AWManager_submit_remove(mocker, capsys): assert testaw.submitted == False +from cryptography.x509 import load_pem_x509_certificate +import base64 +from cryptography.hazmat.primitives.serialization import ( + load_pem_private_key, + Encoding, + PublicFormat, +) + + +def test_generate_ca_cert(): + """ + test the function codeflare_sdk.utils.generate_ca_cert generates the correct outputs + """ + key, certificate = generate_ca_cert() + cert = load_pem_x509_certificate(base64.b64decode(certificate)) + private_pub_key_bytes = ( + load_pem_private_key(base64.b64decode(key), password=None) + .public_key() + .public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo) + ) + cert_pub_key_bytes = cert.public_key().public_bytes( + Encoding.PEM, PublicFormat.SubjectPublicKeyInfo + ) + assert type(key) == str + assert type(certificate) == str + # Veirfy ca.cert is self signed + assert cert.verify_directly_issued_by(cert) == None + # Verify cert has the public key bytes from the private key + assert cert_pub_key_bytes == private_pub_key_bytes + + +def secret_ca_retreival(secret_name, namespace): + ca_private_key_bytes, ca_cert = generate_ca_cert() + data = {"ca.crt": ca_cert, "ca.key": ca_private_key_bytes} + assert secret_name == "ca-secret-cluster" + assert namespace == "namespace" + return client.models.V1Secret(data=data) + + +def test_generate_tls_cert(mocker): + """ + test the function codeflare_sdk.utils.generate_ca_cert generates the correct outputs + """ + mocker.patch("kubernetes.config.load_kube_config", return_value="ignore") + mocker.patch( + "kubernetes.client.CoreV1Api.read_namespaced_secret", + side_effect=secret_ca_retreival, + ) + + generate_tls_cert("cluster", "namespace") + assert os.path.exists("tls-cluster-namespace") + assert os.path.exists(os.path.join("tls-cluster-namespace", "ca.crt")) + assert os.path.exists(os.path.join("tls-cluster-namespace", "tls.crt")) + assert os.path.exists(os.path.join("tls-cluster-namespace", "tls.key")) + + # verify the that the signed tls.crt is issued by the ca_cert (root cert) + with open(os.path.join("tls-cluster-namespace", "tls.crt"), "r") as f: + tls_cert = load_pem_x509_certificate(f.read().encode("utf-8")) + with open(os.path.join("tls-cluster-namespace", "ca.crt"), "r") as f: + root_cert = load_pem_x509_certificate(f.read().encode("utf-8")) + assert tls_cert.verify_directly_issued_by(root_cert) == None + + +def test_export_env(): + """ + test the function codeflare_sdk.utils.export_ev generates the correct outputs + """ + tls_dir = "cluster" + ns = "namespace" + export_env(tls_dir, ns) + assert os.environ["RAY_USE_TLS"] == "1" + assert os.environ["RAY_TLS_SERVER_CERT"] == os.path.join( + os.getcwd(), f"tls-{tls_dir}-{ns}", "tls.crt" + ) + assert os.environ["RAY_TLS_SERVER_KEY"] == os.path.join( + os.getcwd(), f"tls-{tls_dir}-{ns}", "tls.key" + ) + assert os.environ["RAY_TLS_CA_CERT"] == os.path.join( + os.getcwd(), f"tls-{tls_dir}-{ns}", "ca.crt" + ) + + # Make sure to keep this function and the following function at the end of the file def test_cmd_line_generation(): os.system(