From 9d451380869271ca1ef2f847e4b0dd07e00278e2 Mon Sep 17 00:00:00 2001 From: Richard Liu <39319471+richardsliu@users.noreply.github.com> Date: Tue, 2 Apr 2024 00:51:37 +0000 Subject: [PATCH] Refactor Rag notebook (#504) * move mysql stuff to jupyter * new notebook * fix notebook * fix notebook, add markdown * use bulk insert * revert * change persist data * terraform fmt * remove sql params from notebook * default empty values * rename * parameterize notebook image * remove pip installs from notebook * use custom notebook image * terraform fmt * replace jupyter notebook tag * add notebook version to jupyterhub app * merge cells * add dummy value for secret volume * fix old notebook --- applications/jupyter/main.tf | 2 + .../rag-kaggle-ray-sql-interactive.ipynb | 406 ++++++++++++++++++ .../rag-kaggle-ray-sql-latest.ipynb | 16 +- applications/rag/main.tf | 7 + .../config-selfauth-autopilot.yaml | 15 +- .../jupyter_config/config-selfauth.yaml | 12 +- modules/jupyter/main.tf | 62 +-- modules/jupyter/variables.tf | 33 +- 8 files changed, 516 insertions(+), 37 deletions(-) create mode 100644 applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb diff --git a/applications/jupyter/main.tf b/applications/jupyter/main.tf index 3a9dd89a1..fc7b3d791 100644 --- a/applications/jupyter/main.tf +++ b/applications/jupyter/main.tf @@ -149,6 +149,8 @@ module "jupyterhub" { workload_identity_service_account = local.workload_identity_service_account gcs_bucket = var.gcs_bucket autopilot_cluster = local.enable_autopilot + notebook_image = "jupyter/tensorflow-notebook" + notebook_image_tag = "python-3.10" # IAP Auth parameters add_auth = var.add_auth diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb new file mode 100644 index 000000000..2b80e437e --- /dev/null +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb @@ -0,0 +1,406 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "5574f366-58e9-408b-aea4-1bf5b3351e4c", + "metadata": {}, + "source": [ + "# RAG-on-GKE Application\n", + "\n", + "This is a Python notebook for generating the vector embeddings used by the RAG on GKE application. For full information, please checkout the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).\n", + "\n", + "\n", + "## Setup Kaggle Credentials\n", + "\n", + "First we will setup your Kaggle credentials and use the Kaggle CLI to download the NetFlix shows dataset to the GCS bucket. Replace the following with your own settings from the Kaggle web page. Navigate to https://www.kaggle.com/settings/account and generate an API token to be used to setup the env variable. See https://www.kaggle.com/docs/api#authentication how to create one." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ffee2bec-804f-4e22-9ba0-8b1db5a5d7ec", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "os.environ['KAGGLE_USERNAME'] = \"\"\n", + "os.environ['KAGGLE_KEY'] = \"\"\n", + "\n", + "# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/persist-data\" path in the jupyter pod.\n", + "!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n", + "!mkdir /data/netflix-shows -p\n", + "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows" + ] + }, + { + "cell_type": "markdown", + "id": "c7ff518d-f4d2-481b-b408-2c2507565611", + "metadata": {}, + "source": [ + "## Creating the Database Connection\n", + "\n", + "Let's now set up a connection to your CloudSQL database:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "aff789e7-a32d-4dd7-afb8-d3a22c8f3cec", + "metadata": {}, + "outputs": [], + "source": [ + "import os\n", + "import uuid\n", + "import ray\n", + "from langchain.document_loaders import ArxivLoader\n", + "from langchain.text_splitter import RecursiveCharacterTextSplitter\n", + "from sentence_transformers import SentenceTransformer\n", + "from typing import List\n", + "import torch\n", + "from datasets import load_dataset_builder, load_dataset, Dataset\n", + "from huggingface_hub import snapshot_download\n", + "from google.cloud.sql.connector import Connector, IPTypes\n", + "import sqlalchemy\n", + "\n", + "# initialize parameters\n", + "\n", + "INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n", + "print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n", + "DB_NAME = \"pgvector-database\"\n", + "\n", + "db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n", + "DB_USER = db_username_file.read()\n", + "db_username_file.close()\n", + "\n", + "db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n", + "DB_PASS = db_password_file.read()\n", + "db_password_file.close()\n", + "\n", + "# initialize Connector object\n", + "connector = Connector()\n", + "\n", + "# function to return the database connection object\n", + "def getconn():\n", + " conn = connector.connect(\n", + " INSTANCE_CONNECTION_NAME,\n", + " \"pg8000\",\n", + " user=DB_USER,\n", + " password=DB_PASS,\n", + " db=DB_NAME,\n", + " ip_type=IPTypes.PRIVATE\n", + " )\n", + " return conn\n", + "\n", + "# create connection pool with 'creator' argument to our connection object function\n", + "pool = sqlalchemy.create_engine(\n", + " \"postgresql+pg8000://\",\n", + " creator=getconn,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "2156a6bd-1100-46c2-8ad6-22a923b3d6ac", + "metadata": {}, + "source": [ + "Next we'll setup some parameters for the dataset processing steps:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b24267d7-3fad-47f1-8aa7-2fbe21fe8fa1", + "metadata": {}, + "outputs": [], + "source": [ + "SHARED_DATA_BASEPATH='/data/rag/st'\n", + "SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n", + "SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n", + "SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n", + "SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n", + "\n", + "# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n", + "SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n", + "REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n", + "\n", + "BATCH_SIZE = 100\n", + "CHUNK_SIZE = 1000 # text chunk sizes which will be converted to vector embeddings\n", + "CHUNK_OVERLAP = 10\n", + "TABLE_NAME = 'netflix_reviews_db' # CloudSQL table name\n", + "DIMENSION = 384 # Embeddings size\n", + "ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function" + ] + }, + { + "cell_type": "markdown", + "id": "3dc5bc85-dc3b-4622-99a2-f9fc269e753b", + "metadata": {}, + "source": [ + "Now we will download the sentence transformer model to our GCS bucket:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b7a676be-56c6-4c76-8041-9ad05361dd3b", + "metadata": {}, + "outputs": [], + "source": [ + "# prepare the persistent shared directory to store artifacts needed for the ray workers\n", + "os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n", + "\n", + "# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n", + "snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)" + ] + }, + { + "cell_type": "markdown", + "id": "f7304035-21a4-4017-bce9-aba7e9f81c90", + "metadata": {}, + "source": [ + "## Generating Vector Embeddings\n", + "\n", + "We are ready to begin. Let's first create some code for generating the vector embeddings:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5bbb3750-7cd5-439f-a767-617cd5948e27", + "metadata": {}, + "outputs": [], + "source": [ + "class Embed:\n", + " def __init__(self):\n", + " print(\"torch cuda version\", torch.version.cuda)\n", + " device=\"cpu\"\n", + " if torch.cuda.is_available():\n", + " print(\"device cuda found\")\n", + " device=\"cuda\"\n", + "\n", + " print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n", + " self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n", + " self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n", + "\n", + " def __call__(self, text_batch: List[str]):\n", + " text = text_batch[\"item\"]\n", + " # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n", + " chunks = []\n", + " for data in text:\n", + " splits = self.splitter.split_text(data)\n", + " # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n", + " chunks.extend(splits)\n", + "\n", + " embeddings = self.transformer.encode(\n", + " chunks,\n", + " batch_size=BATCH_SIZE\n", + " ).tolist()\n", + " print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n", + " return {'results':list(zip(chunks, embeddings))}" + ] + }, + { + "cell_type": "markdown", + "id": "7263b9db-9504-4177-acd6-5e1aba2ee332", + "metadata": {}, + "source": [ + "Next we will initialize a Ray cluster to execute the remote task:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "ba9551ec-883e-4bde-9f12-f663aedc12e5", + "metadata": {}, + "outputs": [], + "source": [ + "import ray\n", + "\n", + "ray.init(\n", + " address=\"ray://ray-cluster-kuberay-head-svc:10001\",\n", + " runtime_env={\n", + " \"pip\": [ \n", + " \"langchain==0.1.10\",\n", + " \"transformers==4.38.1\",\n", + " \"sentence-transformers==2.5.1\",\n", + " \"pyarrow\",\n", + " \"datasets==2.18.0\",\n", + " \"torch==2.0.1\",\n", + " \"huggingface_hub==0.21.3\",\n", + " ]\n", + " }\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "9589048c-a0aa-4740-acde-5289cd4076f7", + "metadata": {}, + "source": [ + "Generate vector embeddings using our Embed class above:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a392975f-3743-4b2c-8673-087b5633637e", + "metadata": {}, + "outputs": [], + "source": [ + "# Process the dataset first, wrap the csv file contents into a Ray dataset\n", + "ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)\n", + "print(ray_ds.schema)\n", + "\n", + "# Distributed flat map to extract the raw text fields.\n", + "ds_batch = ray_ds.flat_map(lambda row: [{\n", + " 'item': \"This is a \" + str(row[\"type\"]) + \" in \" + str(row[\"country\"]) + \" called \" + str(row[\"title\"]) + \n", + " \" added at \" + str(row[\"date_added\"]) + \" whose director is \" + str(row[\"director\"]) + \n", + " \" and with cast: \" + str(row[\"cast\"]) + \" released at \" + str(row[\"release_year\"]) + \n", + " \". Its rating is: \" + str(row['rating']) + \". Its duration is \" + str(row[\"duration\"]) + \n", + " \". Its description is \" + str(row['description']) + \".\"\n", + "}])\n", + "print(ds_batch.schema)\n", + "\n", + "# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n", + "ds_embed = ds_batch.map_batches(\n", + " Embed,\n", + " compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n", + " batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n", + " num_gpus=1, # 1 GPU for each actor.\n", + " # num_cpus=1,\n", + ")" + ] + }, + { + "cell_type": "markdown", + "id": "4697ac28-9815-409c-95ec-6ecdb862bb74", + "metadata": {}, + "source": [ + "Retrieve the result data from Ray remote workers:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e0edbba2-8977-4afd-aaa2-2e6e3a298169", + "metadata": {}, + "outputs": [], + "source": [ + "@ray.remote\n", + "def ray_data_task(ds_embed):\n", + " results = []\n", + " for row in ds_embed.iter_rows():\n", + " data_text = row[\"results\"][0][:65535]\n", + " data_emb = row[\"results\"][1]\n", + "\n", + " results.append((data_text, data_emb))\n", + " \n", + " return results\n", + " \n", + "results = ray.get(ray_data_task.remote(ds_embed))" + ] + }, + { + "cell_type": "markdown", + "id": "5652832e-025d-4745-9fef-96615eea07e4", + "metadata": {}, + "source": [ + "## Writing Results Back to MySQL\n", + "\n", + "Now that we have our vector embeddings, we can write our results back to the MySQL database:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "07eb5ec7-c4f7-4312-b0ce-ea07160bef92", + "metadata": {}, + "outputs": [], + "source": [ + "from sqlalchemy.ext.declarative import declarative_base\n", + "from sqlalchemy import Column, String, Text, text\n", + "from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n", + "from pgvector.sqlalchemy import Vector\n", + "\n", + "\n", + "Base = declarative_base()\n", + "DBSession = scoped_session(sessionmaker())\n", + "\n", + "class TextEmbedding(Base):\n", + " __tablename__ = TABLE_NAME\n", + " id = Column(String(255), primary_key=True)\n", + " text = Column(Text)\n", + " text_embedding = mapped_column(Vector(384))\n", + "\n", + "with pool.connect() as conn:\n", + " conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n", + " conn.commit() \n", + " \n", + "DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)\n", + "Base.metadata.drop_all(pool)\n", + "Base.metadata.create_all(pool)\n", + "\n", + "rows = []\n", + "for r in results:\n", + " id = uuid.uuid4() \n", + " rows.append(TextEmbedding(id=id, text=r[0], text_embedding=r[1]))\n", + "\n", + "DBSession.bulk_save_objects(rows)\n", + "DBSession.commit()" + ] + }, + { + "cell_type": "markdown", + "id": "b4b19b1c-a83b-4a83-94a9-5edf5ae7016a", + "metadata": {}, + "source": [ + "Finally let's verify that our embeddings got stored in the database correctly:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "4cff4bbc-574d-4cc2-8c87-d0ff6d351626", + "metadata": {}, + "outputs": [], + "source": [ + "with pool.connect() as db_conn:\n", + " # verify results\n", + " transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n", + " query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n", + " query_emb = transformer.encode(query_text).tolist()\n", + " query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n", + " query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n", + " db_conn.commit()\n", + " \n", + " print(\"print query_results, the 1st one is the hit\")\n", + " for row in query_results:\n", + " print(row)\n", + "\n", + "# cleanup connector object\n", + "connector.close()" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb b/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb index 7775bacb9..726014d6d 100644 --- a/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb +++ b/applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb @@ -1,5 +1,15 @@ { "cells": [ + { + "cell_type": "markdown", + "id": "afb6fda4-ffde-4831-88a5-ae41144492b2", + "metadata": {}, + "source": [ + "# RAG-on-GKE Application\n", + "\n", + "This is a Python notebook for generating the vector embeddings used by the RAG on GKE application. For full information, please checkout the GitHub documentation [here](https://github.com/GoogleCloudPlatform/ai-on-gke/blob/main/applications/rag/README.md).\n" + ] + }, { "cell_type": "code", "execution_count": null, @@ -10,7 +20,7 @@ "# Replace these with your settings\n", "# Navigate to https://www.kaggle.com/settings/account and generate an API token to be used to setup the env variable. See https://www.kaggle.com/docs/api#authentication how to create one.\n", "KAGGLE_USERNAME = \"\"\n", - "KAGGLE_KEY = \"\"\n" + "KAGGLE_KEY = \"\"" ] }, { @@ -36,8 +46,8 @@ "\n", "# Download the zip file to local storage and then extract the desired contents directly to the GKE GCS CSI mounted bucket. The bucket is mounted at the \"/persist-data\" path in the jupyter pod.\n", "!kaggle datasets download -d shivamb/netflix-shows -p ~/data --force\n", - "!mkdir /persist-data/netflix-shows -p\n", - "!unzip -o ~/data/netflix-shows.zip -d /persist-data/netflix-shows" + "!mkdir /data/netflix-shows -p\n", + "!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows" ] }, { diff --git a/applications/rag/main.tf b/applications/rag/main.tf index 235ce6694..ccde3e6c7 100644 --- a/applications/rag/main.tf +++ b/applications/rag/main.tf @@ -192,6 +192,13 @@ module "jupyterhub" { autopilot_cluster = local.enable_autopilot workload_identity_service_account = local.jupyter_service_account + notebook_image = "us-central1-docker.pkg.dev/ai-on-gke/rag-on-gke/jupyter-notebook-image" + notebook_image_tag = "v1.1-rag" + + db_secret_name = module.cloudsql.db_secret_name + cloudsql_instance_name = local.cloudsql_instance + db_region = local.cloudsql_instance_region + # IAP Auth parameters create_brand = var.create_brand support_email = var.support_email diff --git a/modules/jupyter/jupyter_config/config-selfauth-autopilot.yaml b/modules/jupyter/jupyter_config/config-selfauth-autopilot.yaml index e8851f566..c9910c17d 100644 --- a/modules/jupyter/jupyter_config/config-selfauth-autopilot.yaml +++ b/modules/jupyter/jupyter_config/config-selfauth-autopilot.yaml @@ -76,8 +76,8 @@ singleuser: limit: 16G guarantee: 16G cpu: - limit: 4 - guarantee: 4 + limit: 8 + guarantee: 8 extraResource: limits: ephemeral-storage: ${ephemeral_storage} @@ -88,6 +88,7 @@ singleuser: extraEnv: # Used for GCSFuse to set the ephemeral storage as the home directory. If not set, it will show a permission error on the pod log when using GCSFuse. JUPYTER_ALLOW_INSECURE_WRITES: "true" + CLOUDSQL_INSTANCE_CONNECTION_NAME: ${cloudsql_instance_connection_name} extraLabels: ${indent(4, chomp(jsonencode(additional_labels)))} image: @@ -111,9 +112,17 @@ singleuser: volumeAttributes: bucketName: ${gcs_bucket} mountOptions: "implicit-dirs,uid=1000,gid=100" + - name: secret-volume + secret: + secretName: ${secret_name} + optional: true + extraVolumeMounts: - name: data-vol - mountPath: /persist-data + mountPath: /data + - name: secret-volume + mountPath: /etc/secret-volume + readOnly: true profileList: - display_name: "CPU (C3)" description: "Creates CPU (C3) VMs as the compute for notebook execution." diff --git a/modules/jupyter/jupyter_config/config-selfauth.yaml b/modules/jupyter/jupyter_config/config-selfauth.yaml index d69ac5747..4fdb055dc 100644 --- a/modules/jupyter/jupyter_config/config-selfauth.yaml +++ b/modules/jupyter/jupyter_config/config-selfauth.yaml @@ -75,7 +75,7 @@ singleuser: guarantee: 8G cpu: limit: 8 - guarantee: 4 + guarantee: 8 extraResource: limits: ephemeral-storage: ${ephemeral_storage} @@ -86,6 +86,7 @@ singleuser: extraEnv: # Used for GCSFuse to set the ephemeral storage as the home directory. If not set, it will show a permission error on the pod log when using GCSFuse. JUPYTER_ALLOW_INSECURE_WRITES: "true" + CLOUDSQL_INSTANCE_CONNECTION_NAME: ${cloudsql_instance_connection_name} extraLabels: ${indent(4, chomp(jsonencode(additional_labels)))} image: @@ -109,9 +110,16 @@ singleuser: volumeAttributes: bucketName: ${gcs_bucket} mountOptions: "implicit-dirs,uid=1000,gid=100" + - name: secret-volume + secret: + secretName: ${secret_name} + optional: true extraVolumeMounts: - name: data-vol - mountPath: /persist-data + mountPath: /data + - name: secret-volume + mountPath: /etc/secret-volume + readOnly: true # More info on kubespawner overrides: https://jupyterhub-kubespawner.readthedocs.io/en/latest/spawner.html#kubespawner.KubeSpawner # profile example: # - display_name: "Learning Data Science" diff --git a/modules/jupyter/main.tf b/modules/jupyter/main.tf index faa92d74f..9adfb41aa 100644 --- a/modules/jupyter/main.tf +++ b/modules/jupyter/main.tf @@ -17,6 +17,7 @@ data "google_project" "project" { } locals { + cloudsql_instance_connection_name = var.cloudsql_instance_name != "" ? format("%s:%s:%s", var.project_id, var.db_region, var.cloudsql_instance_name) : "" additional_labels = tomap({ for item in var.additional_labels : split("=", item)[0] => split("=", item)[1] @@ -109,36 +110,41 @@ resource "helm_release" "jupyterhub" { timeout = 600 values = var.autopilot_cluster ? [templatefile("${path.module}/jupyter_config/config-selfauth-autopilot.yaml", { - password = var.add_auth ? "dummy" : random_password.generated_password[0].result - project_id = var.project_id - project_number = data.google_project.project.number - namespace = var.namespace - additional_labels = local.additional_labels - backend_config = var.k8s_backend_config_name - service_name = var.k8s_backend_service_name - authenticator_class = var.add_auth ? "'gcpiapjwtauthenticator.GCPIAPAuthenticator'" : "dummy" - service_type = var.add_auth ? "NodePort" : "ClusterIP" - gcs_bucket = var.gcs_bucket - k8s_service_account = var.workload_identity_service_account - ephemeral_storage = var.ephemeral_storage - notebook_image = "jupyter/tensorflow-notebook" - notebook_image_tag = "python-3.10" + password = var.add_auth ? "dummy" : random_password.generated_password[0].result + project_id = var.project_id + project_number = data.google_project.project.number + namespace = var.namespace + additional_labels = local.additional_labels + backend_config = var.k8s_backend_config_name + service_name = var.k8s_backend_service_name + authenticator_class = var.add_auth ? "'gcpiapjwtauthenticator.GCPIAPAuthenticator'" : "dummy" + service_type = var.add_auth ? "NodePort" : "ClusterIP" + gcs_bucket = var.gcs_bucket + k8s_service_account = var.workload_identity_service_account + ephemeral_storage = var.ephemeral_storage + secret_name = var.db_secret_name + cloudsql_instance_connection_name = local.cloudsql_instance_connection_name + + notebook_image = var.notebook_image + notebook_image_tag = var.notebook_image_tag }) ] : [templatefile("${path.module}/jupyter_config/config-selfauth.yaml", { - password = var.add_auth ? "dummy" : random_password.generated_password[0].result - project_id = var.project_id - project_number = data.google_project.project.number - namespace = var.namespace - additional_labels = local.additional_labels - backend_config = var.k8s_backend_config_name - service_name = var.k8s_backend_service_name - authenticator_class = var.add_auth ? "'gcpiapjwtauthenticator.GCPIAPAuthenticator'" : "dummy" - service_type = var.add_auth ? "NodePort" : "ClusterIP" - gcs_bucket = var.gcs_bucket - k8s_service_account = var.workload_identity_service_account - ephemeral_storage = var.ephemeral_storage - notebook_image = "jupyter/tensorflow-notebook" - notebook_image_tag = "python-3.10" + password = var.add_auth ? "dummy" : random_password.generated_password[0].result + project_id = var.project_id + project_number = data.google_project.project.number + namespace = var.namespace + additional_labels = local.additional_labels + backend_config = var.k8s_backend_config_name + service_name = var.k8s_backend_service_name + authenticator_class = var.add_auth ? "'gcpiapjwtauthenticator.GCPIAPAuthenticator'" : "dummy" + service_type = var.add_auth ? "NodePort" : "ClusterIP" + gcs_bucket = var.gcs_bucket + k8s_service_account = var.workload_identity_service_account + ephemeral_storage = var.ephemeral_storage + secret_name = var.db_secret_name + cloudsql_instance_connection_name = local.cloudsql_instance_connection_name + notebook_image = var.notebook_image + notebook_image_tag = var.notebook_image_tag }) ] depends_on = [module.jupyterhub-workload-identity] diff --git a/modules/jupyter/variables.tf b/modules/jupyter/variables.tf index d03c0262d..cd2ebaca5 100644 --- a/modules/jupyter/variables.tf +++ b/modules/jupyter/variables.tf @@ -17,6 +17,18 @@ variable "namespace" { description = "Kubernetes namespace where resources are deployed" } +variable "notebook_image" { + type = string + description = "Jupyter notebook image name" + default = "jupyter/tensorflow-notebook" +} + +variable "notebook_image_tag" { + type = string + description = "Jupyter notebook image tag" + default = "python-3.10" +} + variable "members_allowlist" { type = list(string) default = [] @@ -134,4 +146,23 @@ variable "ephemeral_storage" { variable "autopilot_cluster" { type = bool -} \ No newline at end of file +} + +variable "db_region" { + type = string + description = "Cloud SQL instance region" + default = "" +} + +variable "db_secret_name" { + type = string + description = "CloudSQL user credentials" + default = "dummy_value" +} + +variable "cloudsql_instance_name" { + type = string + description = "Cloud SQL instance name" + default = "" +} +