diff --git a/demo-notebooks/interactive/local_interactive.ipynb b/demo-notebooks/interactive/local_interactive.ipynb
new file mode 100644
index 000000000..88a6ccd58
--- /dev/null
+++ b/demo-notebooks/interactive/local_interactive.ipynb
@@ -0,0 +1,358 @@
+{
+ "cells": [
+ {
+ "cell_type": "code",
+ "execution_count": 1,
+ "id": "9a44568b-61ef-41c7-8ad1-9a3b128f03a7",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "# Import pieces from codeflare-sdk\n",
+ "from codeflare_sdk.cluster.cluster import Cluster, ClusterConfiguration\n",
+ "from codeflare_sdk.cluster.auth import TokenAuthentication"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "2cc66278",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Create authentication object and log in to desired user account (if not already authenticated)\n",
+ "auth = TokenAuthentication(\n",
+ " token = \"XXXX\",\n",
+ " server = \"XXXX\",\n",
+ " skip_tls = False\n",
+ ")\n",
+ "auth.login()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 2,
+ "id": "4364ac2e-dd10-4d30-ba66-12708daefb3f",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Written to: hfgputest-1.yaml\n"
+ ]
+ }
+ ],
+ "source": [
+ "# Create our cluster and submit appwrapper\n",
+ "namespace = \"default\"\n",
+ "cluster_name = \"hfgputest-1\"\n",
+ "local_interactive = True\n",
+ "\n",
+ "cluster = Cluster(ClusterConfiguration(local_interactive=local_interactive, namespace=namespace, name=cluster_name, min_worker=1, max_worker=1, min_cpus=1, max_cpus=1, min_memory=4, max_memory=4, gpu=0, instascale=False, machine_types=[\"m5.xlarge\", \"p3.8xlarge\"]))"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 3,
+ "id": "69968140-15e6-482f-9529-82b0cd19524b",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [],
+ "source": [
+ "cluster.up()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 4,
+ "id": "e20f9982-f671-460b-8c22-3d62e101fed9",
+ "metadata": {
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "Waiting for requested resources to be set up...\n",
+ "Requested cluster up and running!\n"
+ ]
+ }
+ ],
+ "source": [
+ "cluster.wait_ready()"
+ ]
+ },
+ {
+ "attachments": {},
+ "cell_type": "markdown",
+ "id": "12eef53c",
+ "metadata": {},
+ "source": [
+ "### Connect via the rayclient route"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 11,
+ "id": "cf1b749e-2335-42c2-b673-26768ec9895d",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stdout",
+ "output_type": "stream",
+ "text": [
+ "rayclient-hfgputest-1-default.apps.tedbig412.cp.fyre.ibm.com\n"
+ ]
+ }
+ ],
+ "source": [
+ "import openshift as oc\n",
+ "from codeflare_sdk.utils import generate_cert\n",
+ "\n",
+ "if local_interactive:\n",
+ " generate_cert.generate_tls_cert(cluster_name, namespace)\n",
+ " generate_cert.export_env(cluster_name, namespace)\n",
+ "\n",
+ "with oc.project(namespace):\n",
+ " routes=oc.selector(\"route\").objects()\n",
+ " rayclient_url=\"\"\n",
+ " for r in routes:\n",
+ " if \"rayclient\" in r.name():\n",
+ " rayclient_url=r.model.spec.host\n",
+ "print(rayclient_url)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 12,
+ "id": "9483bb98-33b3-4beb-9b15-163d7e76c1d7",
+ "metadata": {
+ "scrolled": true,
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-05-31 14:12:37,816\tINFO client_builder.py:251 -- Passing the following kwargs to ray.init() on the server: logging_level\n",
+ "2023-05-31 14:12:37,820\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.IDLE\n",
+ "2023-05-31 14:12:38,034\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.CONNECTING\n",
+ "2023-05-31 14:12:38,246\tDEBUG worker.py:378 -- client gRPC channel state change: ChannelConnectivity.READY\n",
+ "2023-05-31 14:12:38,290\tDEBUG worker.py:807 -- Pinging server.\n",
+ "2023-05-31 14:12:40,521\tDEBUG worker.py:640 -- Retaining 00ffffffffffffffffffffffffffffffffffffff0100000001000000\n",
+ "2023-05-31 14:12:40,523\tDEBUG worker.py:564 -- Scheduling task get_dashboard_url 0 b'\\x00\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\x01\\x00\\x00\\x00\\x01\\x00\\x00\\x00'\n",
+ "2023-05-31 14:12:40,535\tDEBUG worker.py:640 -- Retaining c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000\n",
+ "2023-05-31 14:12:41,379\tDEBUG worker.py:636 -- Releasing c8ef45ccd0112571ffffffffffffffffffffffff0100000001000000\n"
+ ]
+ },
+ {
+ "data": {
+ "text/html": [
+ "
\n",
+ "
\n",
+ "
Ray
\n",
+ "
\n",
+ "
\n",
+ " \n",
+ " Python version: | \n",
+ " 3.8.13 | \n",
+ "
\n",
+ " \n",
+ " Ray version: | \n",
+ " 2.1.0 | \n",
+ "
\n",
+ " \n",
+ " Dashboard: | \n",
+ " http://10.254.12.141:8265 | \n",
+ "
\n",
+ "\n",
+ "
\n",
+ "
\n",
+ "
\n"
+ ],
+ "text/plain": [
+ "ClientContext(dashboard_url='10.254.12.141:8265', python_version='3.8.13', ray_version='2.1.0', ray_commit='23f34d948dae8de9b168667ab27e6cf940b3ae85', protocol_version='2022-10-05', _num_clients=1, _context_to_restore=)"
+ ]
+ },
+ "execution_count": 12,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "import ray\n",
+ "\n",
+ "ray.shutdown()\n",
+ "ray.init(address=f\"ray://{rayclient_url}\", logging_level=\"DEBUG\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 13,
+ "id": "3436eb4a-217c-4109-a3c3-309fda7e2442",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "import math\n",
+ "import ray\n",
+ "\n",
+ "@ray.remote\n",
+ "def heavy_calculation_part(num_iterations):\n",
+ " result = 0.0\n",
+ " for i in range(num_iterations):\n",
+ " for j in range(num_iterations):\n",
+ " for k in range(num_iterations):\n",
+ " result += math.sin(i) * math.cos(j) * math.tan(k)\n",
+ " return result\n",
+ "@ray.remote\n",
+ "def heavy_calculation(num_iterations):\n",
+ " results = ray.get([heavy_calculation_part.remote(num_iterations//30) for _ in range(30)])\n",
+ " return sum(results)\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 14,
+ "id": "5cca1874-2be3-4631-ae48-9adfa45e3af3",
+ "metadata": {
+ "scrolled": true,
+ "tags": []
+ },
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-05-31 14:13:29,868\tDEBUG worker.py:640 -- Retaining 00ffffffffffffffffffffffffffffffffffffff0100000002000000\n",
+ "2023-05-31 14:13:29,870\tDEBUG worker.py:564 -- Scheduling task heavy_calculation 0 b'\\x00\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\xff\\x01\\x00\\x00\\x00\\x02\\x00\\x00\\x00'\n"
+ ]
+ }
+ ],
+ "source": [
+ "ref = heavy_calculation.remote(3000)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 15,
+ "id": "01172c29-e8bf-41ef-8db5-eccb07906111",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-05-31 14:13:32,643\tDEBUG worker.py:640 -- Retaining 16310a0f0a45af5cffffffffffffffffffffffff0100000001000000\n",
+ "2023-05-31 14:13:34,677\tDEBUG worker.py:439 -- Internal retry for get [ClientObjectRef(16310a0f0a45af5cffffffffffffffffffffffff0100000001000000)]\n"
+ ]
+ },
+ {
+ "data": {
+ "text/plain": [
+ "1789.4644387076714"
+ ]
+ },
+ "execution_count": 15,
+ "metadata": {},
+ "output_type": "execute_result"
+ }
+ ],
+ "source": [
+ "ray.get(ref)"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 16,
+ "id": "9e79b547-a457-4232-b77d-19147067b972",
+ "metadata": {},
+ "outputs": [
+ {
+ "name": "stderr",
+ "output_type": "stream",
+ "text": [
+ "2023-05-31 14:13:37,659\tDEBUG dataclient.py:287 -- Got unawaited response connection_cleanup {\n",
+ "}\n",
+ "\n",
+ "2023-05-31 14:13:38,681\tDEBUG dataclient.py:278 -- Shutting down data channel.\n"
+ ]
+ }
+ ],
+ "source": [
+ "ray.cancel(ref)\n",
+ "ray.shutdown()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": 17,
+ "id": "2c198f1f-68bf-43ff-a148-02b5cb000ff2",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "cluster.down()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "6879e471-a69f-4c74-9cec-a195cdead47c",
+ "metadata": {},
+ "outputs": [],
+ "source": []
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "display_name": "Python 3 (ipykernel)",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.13"
+ },
+ "vscode": {
+ "interpreter": {
+ "hash": "f9f85f796d01129d0dd105a088854619f454435301f6ffec2fea96ecbd9be4ac"
+ }
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/pyproject.toml b/pyproject.toml
index d9e63835e..79b90cb37 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -26,6 +26,7 @@ rich = "^12.5"
ray = {version = "2.1.0", extras = ["default"]}
kubernetes = ">= 25.3.0, < 27"
codeflare-torchx = "0.6.0.dev0"
+cryptography = "40.0.2"
[tool.poetry.group.docs]
optional = true
diff --git a/src/codeflare_sdk/cluster/cluster.py b/src/codeflare_sdk/cluster/cluster.py
index 31974b00f..0e0e73c06 100644
--- a/src/codeflare_sdk/cluster/cluster.py
+++ b/src/codeflare_sdk/cluster/cluster.py
@@ -84,6 +84,7 @@ def create_app_wrapper(self):
instascale = self.config.instascale
instance_types = self.config.machine_types
env = self.config.envs
+ local_interactive = self.config.local_interactive
return generate_appwrapper(
name=name,
namespace=namespace,
@@ -98,6 +99,7 @@ def create_app_wrapper(self):
instascale=instascale,
instance_types=instance_types,
env=env,
+ local_interactive=local_interactive,
)
# creates a new cluster with the provided or default spec
diff --git a/src/codeflare_sdk/cluster/config.py b/src/codeflare_sdk/cluster/config.py
index 25f129256..25392db75 100644
--- a/src/codeflare_sdk/cluster/config.py
+++ b/src/codeflare_sdk/cluster/config.py
@@ -48,3 +48,4 @@ class ClusterConfiguration:
instascale: bool = False
envs: dict = field(default_factory=dict)
image: str = "ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103"
+ local_interactive: bool = False
diff --git a/src/codeflare_sdk/templates/base-template.yaml b/src/codeflare_sdk/templates/base-template.yaml
index c99fd105d..f408df2e2 100644
--- a/src/codeflare_sdk/templates/base-template.yaml
+++ b/src/codeflare_sdk/templates/base-template.yaml
@@ -119,6 +119,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: "0"
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
name: ray-head
image: rayproject/ray:latest
imagePullPolicy: Always
@@ -142,6 +150,37 @@ spec:
cpu: 2
memory: "8G"
nvidia.com/gpu: 0
+ volumeMounts:
+ - name: ca-vol
+ mountPath: "/home/ray/workspace/ca"
+ readOnly: true
+ - name: server-cert
+ mountPath: "/home/ray/workspace/tls"
+ readOnly: true
+ initContainers:
+ - command:
+ - sh
+ - -c
+ - cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)\nDNS.5 = rayclient-deployment-name-$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).server-name">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext
+ image: rayproject/ray:2.5.0
+ name: create-cert
+ # securityContext:
+ # runAsUser: 1000
+ # runAsGroup: 1000
+ volumeMounts:
+ - name: ca-vol
+ mountPath: "/home/ray/workspace/ca"
+ readOnly: true
+ - name: server-cert
+ mountPath: "/home/ray/workspace/tls"
+ readOnly: false
+ volumes:
+ - name: ca-vol
+ secret:
+ secretName: ca-secret-deployment-name
+ optional: false
+ - name: server-cert
+ emptyDir: {}
workerGroupSpecs:
# the pod replicas in this group typed worker
- replicas: 3
@@ -187,6 +226,22 @@ spec:
- name: init-myservice
image: busybox:1.28
command: ['sh', '-c', "until nslookup $RAY_IP.$(cat /var/run/secrets/kubernetes.io/serviceaccount/namespace).svc.cluster.local; do echo waiting for myservice; sleep 2; done"]
+ - name: create-cert
+ image: rayproject/ray:2.5.0
+ command:
+ - sh
+ - -c
+ - cd /home/ray/workspace/tls && openssl req -nodes -newkey rsa:2048 -keyout server.key -out server.csr -subj '/CN=ray-head' && printf "authorityKeyIdentifier=keyid,issuer\nbasicConstraints=CA:FALSE\nsubjectAltName = @alt_names\n[alt_names]\nDNS.1 = 127.0.0.1\nDNS.2 = localhost\nDNS.3 = ${FQ_RAY_IP}\nDNS.4 = $(awk 'END{print $1}' /etc/hosts)">./domain.ext && cp /home/ray/workspace/ca/* . && openssl x509 -req -CA ca.crt -CAkey ca.key -in server.csr -out server.crt -days 365 -CAcreateserial -extfile domain.ext
+ # securityContext:
+ # runAsUser: 1000
+ # runAsGroup: 1000
+ volumeMounts:
+ - name: ca-vol
+ mountPath: "/home/ray/workspace/ca"
+ readOnly: true
+ - name: server-cert
+ mountPath: "/home/ray/workspace/tls"
+ readOnly: false
containers:
- name: machine-learning # must consist of lower case alphanumeric characters or '-', and must start and end with an alphanumeric character (e.g. 'my-name', or '123-abc'
image: rayproject/ray:latest
@@ -195,6 +250,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: "0"
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
# environment variables to set in the container.Optional.
# Refer to https://kubernetes.io/docs/tasks/inject-data-application/define-environment-variable-container/
lifecycle:
@@ -210,6 +273,20 @@ spec:
cpu: "2"
memory: "12G"
nvidia.com/gpu: "1"
+ volumeMounts:
+ - name: ca-vol
+ mountPath: "/home/ray/workspace/ca"
+ readOnly: true
+ - name: server-cert
+ mountPath: "/home/ray/workspace/tls"
+ readOnly: true
+ volumes:
+ - name: ca-vol
+ secret:
+ secretName: ca-secret-deployment-name
+ optional: false
+ - name: server-cert
+ emptyDir: {}
- replica: 1
generictemplate:
kind: Route
@@ -226,3 +303,33 @@ spec:
name: deployment-name-head-svc
port:
targetPort: dashboard
+ - replicas: 1
+ generictemplate:
+ apiVersion: route.openshift.io/v1
+ kind: Route
+ metadata:
+ name: rayclient-deployment-name
+ namespace: default
+ labels:
+ # allows me to return name of service that Ray operator creates
+ odh-ray-cluster-service: deployment-name-head-svc
+ spec:
+ port:
+ targetPort: client
+ tls:
+ termination: passthrough
+ to:
+ kind: Service
+ name: deployment-name-head-svc
+ - replicas: 1
+ generictemplate:
+ apiVersion: v1
+ data:
+ ca.crt: generated_crt
+ ca.key: generated_key
+ kind: Secret
+ metadata:
+ name: ca-secret-deployment-name
+ labels:
+ # allows me to return name of service that Ray operator creates
+ odh-ray-cluster-service: deployment-name-head-svc
diff --git a/src/codeflare_sdk/utils/generate_cert.py b/src/codeflare_sdk/utils/generate_cert.py
new file mode 100644
index 000000000..2d73621b8
--- /dev/null
+++ b/src/codeflare_sdk/utils/generate_cert.py
@@ -0,0 +1,161 @@
+# Copyright 2022 IBM, Red Hat
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import base64
+import os
+from cryptography.hazmat.primitives import serialization, hashes
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography import x509
+from cryptography.x509.oid import NameOID
+import datetime
+from kubernetes import client, config
+
+
+def generate_ca_cert(days: int = 30):
+ # Generate base64 encoded ca.key and ca.cert
+ # Similar to:
+ # openssl req -x509 -nodes -newkey rsa:2048 -keyout ca.key -days 1826 -out ca.crt -subj '/CN=root-ca'
+ # base64 -i ca.crt -i ca.key
+
+ private_key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ )
+
+ key = base64.b64encode(
+ private_key.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.PKCS8,
+ serialization.NoEncryption(),
+ )
+ ).decode("utf-8")
+
+ # Generate Certificate
+ one_day = datetime.timedelta(1, 0, 0)
+ public_key = private_key.public_key()
+ builder = (
+ x509.CertificateBuilder()
+ .subject_name(
+ x509.Name(
+ [
+ x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"),
+ ]
+ )
+ )
+ .issuer_name(
+ x509.Name(
+ [
+ x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"),
+ ]
+ )
+ )
+ .not_valid_before(datetime.datetime.today() - one_day)
+ .not_valid_after(datetime.datetime.today() + (one_day * days))
+ .serial_number(x509.random_serial_number())
+ .public_key(public_key)
+ )
+ certificate = base64.b64encode(
+ builder.sign(private_key=private_key, algorithm=hashes.SHA256()).public_bytes(
+ serialization.Encoding.PEM
+ )
+ ).decode("utf-8")
+ return key, certificate
+
+
+def generate_tls_cert(cluster_name, namespace, days=30):
+ # Create a folder tls-- and store three files: ca.crt, tls.crt, and tls.key
+ tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}")
+ if not os.path.exists(tls_dir):
+ os.makedirs(tls_dir)
+
+ # Similar to:
+ # oc get secret ca-secret- -o template='{{index .data "ca.key"}}'
+ # oc get secret ca-secret- -o template='{{index .data "ca.crt"}}'|base64 -d > ${TLSDIR}/ca.crt
+ config.load_kube_config()
+ v1 = client.CoreV1Api()
+ secret = v1.read_namespaced_secret(f"ca-secret-{cluster_name}", namespace).data
+ ca_cert = secret.get("ca.crt")
+ ca_key = secret.get("ca.key")
+
+ with open(os.path.join(tls_dir, "ca.crt"), "w") as f:
+ f.write(base64.b64decode(ca_cert).decode("utf-8"))
+
+ # Generate tls.key and signed tls.cert locally for ray client
+ # Similar to running these commands:
+ # openssl req -nodes -newkey rsa:2048 -keyout ${TLSDIR}/tls.key -out ${TLSDIR}/tls.csr -subj '/CN=local'
+ # cat <${TLSDIR}/domain.ext
+ # authorityKeyIdentifier=keyid,issuer
+ # basicConstraints=CA:FALSE
+ # subjectAltName = @alt_names
+ # [alt_names]
+ # DNS.1 = 127.0.0.1
+ # DNS.2 = localhost
+ # EOF
+ # openssl x509 -req -CA ${TLSDIR}/ca.crt -CAkey ${TLSDIR}/ca.key -in ${TLSDIR}/tls.csr -out ${TLSDIR}/tls.crt -days 365 -CAcreateserial -extfile ${TLSDIR}/domain.ext
+ key = rsa.generate_private_key(
+ public_exponent=65537,
+ key_size=2048,
+ )
+
+ tls_key = key.private_bytes(
+ serialization.Encoding.PEM,
+ serialization.PrivateFormat.PKCS8,
+ serialization.NoEncryption(),
+ )
+ with open(os.path.join(tls_dir, "tls.key"), "w") as f:
+ f.write(tls_key.decode("utf-8"))
+
+ one_day = datetime.timedelta(1, 0, 0)
+ tls_cert = (
+ x509.CertificateBuilder()
+ .issuer_name(
+ x509.Name(
+ [
+ x509.NameAttribute(NameOID.COMMON_NAME, "root-ca"),
+ ]
+ )
+ )
+ .subject_name(
+ x509.Name(
+ [
+ x509.NameAttribute(NameOID.COMMON_NAME, "local"),
+ ]
+ )
+ )
+ .public_key(key.public_key())
+ .not_valid_before(datetime.datetime.today() - one_day)
+ .not_valid_after(datetime.datetime.today() + (one_day * days))
+ .serial_number(x509.random_serial_number())
+ .add_extension(
+ x509.SubjectAlternativeName(
+ [x509.DNSName("localhost"), x509.DNSName("127.0.0.1")]
+ ),
+ False,
+ )
+ .sign(
+ serialization.load_pem_private_key(base64.b64decode(ca_key), None),
+ hashes.SHA256(),
+ )
+ )
+
+ with open(os.path.join(tls_dir, "tls.crt"), "w") as f:
+ f.write(tls_cert.public_bytes(serialization.Encoding.PEM).decode("utf-8"))
+
+
+def export_env(cluster_name, namespace):
+ tls_dir = os.path.join(os.getcwd(), f"tls-{cluster_name}-{namespace}")
+ os.environ["RAY_USE_TLS"] = "1"
+ os.environ["RAY_TLS_SERVER_CERT"] = os.path.join(tls_dir, "tls.crt")
+ os.environ["RAY_TLS_SERVER_KEY"] = os.path.join(tls_dir, "tls.key")
+ os.environ["RAY_TLS_CA_CERT"] = os.path.join(tls_dir, "ca.crt")
diff --git a/src/codeflare_sdk/utils/generate_yaml.py b/src/codeflare_sdk/utils/generate_yaml.py
index 36757a2d6..fffc4fa9a 100755
--- a/src/codeflare_sdk/utils/generate_yaml.py
+++ b/src/codeflare_sdk/utils/generate_yaml.py
@@ -21,6 +21,7 @@
import sys
import argparse
import uuid
+import openshift as oc
def read_template(template):
@@ -50,6 +51,16 @@ def update_dashboard_route(route_item, cluster_name, namespace):
spec["to"]["name"] = f"{cluster_name}-head-svc"
+# ToDo: refactor the update_x_route() functions
+def update_rayclient_route(route_item, cluster_name, namespace):
+ metadata = route_item.get("generictemplate", {}).get("metadata")
+ metadata["name"] = f"rayclient-{cluster_name}"
+ metadata["namespace"] = namespace
+ metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc"
+ spec = route_item.get("generictemplate", {}).get("spec")
+ spec["to"]["name"] = f"{cluster_name}-head-svc"
+
+
def update_names(yaml, item, appwrapper_name, cluster_name, namespace):
metadata = yaml.get("metadata")
metadata["name"] = appwrapper_name
@@ -191,6 +202,78 @@ def update_nodes(
update_resources(spec, min_cpu, max_cpu, min_memory, max_memory, gpu)
+def update_ca_secret(ca_secret_item, cluster_name, namespace):
+ from . import generate_cert
+
+ metadata = ca_secret_item.get("generictemplate", {}).get("metadata")
+ metadata["name"] = f"ca-secret-{cluster_name}"
+ metadata["namespace"] = namespace
+ metadata["labels"]["odh-ray-cluster-service"] = f"{cluster_name}-head-svc"
+ data = ca_secret_item.get("generictemplate", {}).get("data")
+ data["ca.key"], data["ca.crt"] = generate_cert.generate_ca_cert(365)
+
+
+def enable_local_interactive(resources, cluster_name, namespace):
+ rayclient_route_item = resources["resources"].get("GenericItems")[2]
+ ca_secret_item = resources["resources"].get("GenericItems")[3]
+ item = resources["resources"].get("GenericItems")[0]
+ update_rayclient_route(rayclient_route_item, cluster_name, namespace)
+ update_ca_secret(ca_secret_item, cluster_name, namespace)
+ # update_ca_secret_volumes
+ item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]["volumes"][0][
+ "secret"
+ ]["secretName"] = f"ca-secret-{cluster_name}"
+ item["generictemplate"]["spec"]["workerGroupSpecs"][0]["template"]["spec"][
+ "volumes"
+ ][0]["secret"]["secretName"] = f"ca-secret-{cluster_name}"
+ # update_tls_env
+ item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"]["containers"][
+ 0
+ ]["env"][1]["value"] = "1"
+ item["generictemplate"]["spec"]["workerGroupSpecs"][0]["template"]["spec"][
+ "containers"
+ ][0]["env"][1]["value"] = "1"
+ # update_init_container
+ command = item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"][
+ "initContainers"
+ ][0].get("command")[2]
+
+ command = command.replace("deployment-name", cluster_name)
+
+ server_name = (
+ oc.whoami("--show-server").split(":")[1].split("//")[1].replace("api", "apps")
+ )
+
+ command = command.replace("server-name", server_name)
+
+ item["generictemplate"]["spec"]["headGroupSpec"]["template"]["spec"][
+ "initContainers"
+ ][0].get("command")[2] = command
+
+
+def disable_raycluster_tls(resources):
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][
+ "template"
+ ]["spec"]["volumes"]
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][
+ "template"
+ ]["spec"]["containers"][0]["volumeMounts"]
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["headGroupSpec"][
+ "template"
+ ]["spec"]["initContainers"]
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][
+ "template"
+ ]["spec"]["volumes"]
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][
+ "template"
+ ]["spec"]["containers"][0]["volumeMounts"]
+ del resources["GenericItems"][0]["generictemplate"]["spec"]["workerGroupSpecs"][0][
+ "template"
+ ]["spec"]["initContainers"][1]
+ del resources["GenericItems"][3] # rayclient route
+ del resources["GenericItems"][2] # ca-secret
+
+
def write_user_appwrapper(user_yaml, output_file_name):
with open(output_file_name, "w") as outfile:
yaml.dump(user_yaml, outfile, default_flow_style=False)
@@ -211,6 +294,7 @@ def generate_appwrapper(
instascale: bool,
instance_types: list,
env,
+ local_interactive: bool,
):
user_yaml = read_template(template)
appwrapper_name, cluster_name = gen_names(name)
@@ -236,6 +320,10 @@ def generate_appwrapper(
env,
)
update_dashboard_route(route_item, cluster_name, namespace)
+ if local_interactive:
+ enable_local_interactive(resources, cluster_name, namespace)
+ else:
+ disable_raycluster_tls(resources["resources"])
outfile = appwrapper_name + ".yaml"
write_user_appwrapper(user_yaml, outfile)
return outfile
@@ -315,6 +403,12 @@ def main(): # pragma: no cover
default="default",
help="Set the kubernetes namespace you want to deploy your cluster to. Default. If left blank, uses the 'default' namespace",
)
+ parser.add_argument(
+ "--local-interactive",
+ required=False,
+ default=False,
+ help="Enable local interactive mode",
+ )
args = parser.parse_args()
name = args.name
@@ -329,6 +423,7 @@ def main(): # pragma: no cover
instascale = args.instascale
instance_types = args.instance_types
namespace = args.namespace
+ local_interactive = args.local_interactive
env = {}
outfile = generate_appwrapper(
@@ -344,6 +439,7 @@ def main(): # pragma: no cover
image,
instascale,
instance_types,
+ local_interactive,
env,
)
return outfile
diff --git a/tests/test-case-cmd.yaml b/tests/test-case-cmd.yaml
index 450ec9668..d82096e8f 100644
--- a/tests/test-case-cmd.yaml
+++ b/tests/test-case-cmd.yaml
@@ -62,6 +62,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: '0'
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
image: rayproject/ray:latest
imagePullPolicy: Always
lifecycle:
@@ -110,6 +118,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: '0'
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
image: rayproject/ray:latest
lifecycle:
preStop:
diff --git a/tests/test-case.yaml b/tests/test-case.yaml
index 133a22229..0d85428df 100644
--- a/tests/test-case.yaml
+++ b/tests/test-case.yaml
@@ -73,6 +73,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: '0'
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
image: ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103
imagePullPolicy: Always
lifecycle:
@@ -130,6 +138,14 @@ spec:
valueFrom:
fieldRef:
fieldPath: status.podIP
+ - name: RAY_USE_TLS
+ value: '0'
+ - name: RAY_TLS_SERVER_CERT
+ value: /home/ray/workspace/tls/server.crt
+ - name: RAY_TLS_SERVER_KEY
+ value: /home/ray/workspace/tls/server.key
+ - name: RAY_TLS_CA_CERT
+ value: /home/ray/workspace/tls/ca.crt
image: ghcr.io/foundation-model-stack/base:ray2.1.0-py38-gpu-pytorch1.12.0cu116-20221213-193103
lifecycle:
preStop:
diff --git a/tests/unit_test.py b/tests/unit_test.py
index 47c0d43a8..ead3521c8 100644
--- a/tests/unit_test.py
+++ b/tests/unit_test.py
@@ -21,6 +21,7 @@
parent = Path(__file__).resolve().parents[1]
sys.path.append(str(parent) + "/src")
+from kubernetes import client
from codeflare_sdk.cluster.awload import AWManager
from codeflare_sdk.cluster.cluster import (
Cluster,
@@ -54,6 +55,12 @@
DDPJob,
torchx_runner,
)
+from codeflare_sdk.utils.generate_cert import (
+ generate_ca_cert,
+ generate_tls_cert,
+ export_env,
+)
+
import openshift
from openshift import OpenShiftPythonException
from openshift.selector import Selector
@@ -1980,6 +1987,88 @@ def test_AWManager_submit_remove(mocker, capsys):
assert testaw.submitted == False
+from cryptography.x509 import load_pem_x509_certificate
+import base64
+from cryptography.hazmat.primitives.serialization import (
+ load_pem_private_key,
+ Encoding,
+ PublicFormat,
+)
+
+
+def test_generate_ca_cert():
+ """
+ test the function codeflare_sdk.utils.generate_ca_cert generates the correct outputs
+ """
+ key, certificate = generate_ca_cert()
+ cert = load_pem_x509_certificate(base64.b64decode(certificate))
+ private_pub_key_bytes = (
+ load_pem_private_key(base64.b64decode(key), password=None)
+ .public_key()
+ .public_bytes(Encoding.PEM, PublicFormat.SubjectPublicKeyInfo)
+ )
+ cert_pub_key_bytes = cert.public_key().public_bytes(
+ Encoding.PEM, PublicFormat.SubjectPublicKeyInfo
+ )
+ assert type(key) == str
+ assert type(certificate) == str
+ # Veirfy ca.cert is self signed
+ assert cert.verify_directly_issued_by(cert) == None
+ # Verify cert has the public key bytes from the private key
+ assert cert_pub_key_bytes == private_pub_key_bytes
+
+
+def secret_ca_retreival(secret_name, namespace):
+ ca_private_key_bytes, ca_cert = generate_ca_cert()
+ data = {"ca.crt": ca_cert, "ca.key": ca_private_key_bytes}
+ assert secret_name == "ca-secret-cluster"
+ assert namespace == "namespace"
+ return client.models.V1Secret(data=data)
+
+
+def test_generate_tls_cert(mocker):
+ """
+ test the function codeflare_sdk.utils.generate_ca_cert generates the correct outputs
+ """
+ mocker.patch("kubernetes.config.load_kube_config", return_value="ignore")
+ mocker.patch(
+ "kubernetes.client.CoreV1Api.read_namespaced_secret",
+ side_effect=secret_ca_retreival,
+ )
+
+ generate_tls_cert("cluster", "namespace")
+ assert os.path.exists("tls-cluster-namespace")
+ assert os.path.exists(os.path.join("tls-cluster-namespace", "ca.crt"))
+ assert os.path.exists(os.path.join("tls-cluster-namespace", "tls.crt"))
+ assert os.path.exists(os.path.join("tls-cluster-namespace", "tls.key"))
+
+ # verify the that the signed tls.crt is issued by the ca_cert (root cert)
+ with open(os.path.join("tls-cluster-namespace", "tls.crt"), "r") as f:
+ tls_cert = load_pem_x509_certificate(f.read().encode("utf-8"))
+ with open(os.path.join("tls-cluster-namespace", "ca.crt"), "r") as f:
+ root_cert = load_pem_x509_certificate(f.read().encode("utf-8"))
+ assert tls_cert.verify_directly_issued_by(root_cert) == None
+
+
+def test_export_env():
+ """
+ test the function codeflare_sdk.utils.export_ev generates the correct outputs
+ """
+ tls_dir = "cluster"
+ ns = "namespace"
+ export_env(tls_dir, ns)
+ assert os.environ["RAY_USE_TLS"] == "1"
+ assert os.environ["RAY_TLS_SERVER_CERT"] == os.path.join(
+ os.getcwd(), f"tls-{tls_dir}-{ns}", "tls.crt"
+ )
+ assert os.environ["RAY_TLS_SERVER_KEY"] == os.path.join(
+ os.getcwd(), f"tls-{tls_dir}-{ns}", "tls.key"
+ )
+ assert os.environ["RAY_TLS_CA_CERT"] == os.path.join(
+ os.getcwd(), f"tls-{tls_dir}-{ns}", "ca.crt"
+ )
+
+
# Make sure to keep this function and the following function at the end of the file
def test_cmd_line_generation():
os.system(