Skip to content

Commit

Permalink
fix: test rag
Browse files Browse the repository at this point in the history
Change-Id: I21760d28521c0a95fd953f3b08fb0fbb50ca4be8
  • Loading branch information
Gen Lu committed Sep 18, 2024
1 parent 6e48fd5 commit 6648640
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 542 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@
"id": "c7ff518d-f4d2-481b-b408-2c2507565611",
"metadata": {},
"source": [
"## Creating the Database Connection\n",
"## Download the Data\n",
"\n",
"Let's now set up a connection to your CloudSQL database:"
"Let's now import required modules:"
]
},
{
Expand All @@ -60,42 +60,7 @@
"from datasets import load_dataset_builder, load_dataset, Dataset\n",
"from huggingface_hub import snapshot_download\n",
"from google.cloud.sql.connector import Connector, IPTypes\n",
"import sqlalchemy\n",
"\n",
"# initialize parameters\n",
"\n",
"INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n",
"print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n",
"DB_NAME = \"pgvector-database\"\n",
"\n",
"db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
"DB_USER = db_username_file.read()\n",
"db_username_file.close()\n",
"\n",
"db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
"DB_PASS = db_password_file.read()\n",
"db_password_file.close()\n",
"\n",
"# initialize Connector object\n",
"connector = Connector()\n",
"\n",
"# function to return the database connection object\n",
"def getconn():\n",
" conn = connector.connect(\n",
" INSTANCE_CONNECTION_NAME,\n",
" \"pg8000\",\n",
" user=DB_USER,\n",
" password=DB_PASS,\n",
" db=DB_NAME,\n",
" ip_type=IPTypes.PRIVATE\n",
" )\n",
" return conn\n",
"\n",
"# create connection pool with 'creator' argument to our connection object function\n",
"pool = sqlalchemy.create_engine(\n",
" \"postgresql+pg8000://\",\n",
" creator=getconn,\n",
")"
"import sqlalchemy"
]
},
{
Expand Down Expand Up @@ -322,6 +287,40 @@
"from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n",
"from pgvector.sqlalchemy import Vector\n",
"\n",
"# initialize parameters\n",
"\n",
"INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n",
"print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n",
"DB_NAME = \"pgvector-database\"\n",
"\n",
"db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
"DB_USER = db_username_file.read()\n",
"db_username_file.close()\n",
"\n",
"db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
"DB_PASS = db_password_file.read()\n",
"db_password_file.close()\n",
"\n",
"# initialize Connector object\n",
"connector = Connector()\n",
"\n",
"# function to return the database connection object\n",
"def getconn():\n",
" conn = connector.connect(\n",
" INSTANCE_CONNECTION_NAME,\n",
" \"pg8000\",\n",
" user=DB_USER,\n",
" password=DB_PASS,\n",
" db=DB_NAME,\n",
" ip_type=IPTypes.PRIVATE\n",
" )\n",
" return conn\n",
"\n",
"# create connection pool with 'creator' argument to our connection object function\n",
"pool = sqlalchemy.create_engine(\n",
" \"postgresql+pg8000://\",\n",
" creator=getconn,\n",
")\n",
"\n",
"Base = declarative_base()\n",
"DBSession = scoped_session(sessionmaker())\n",
Expand Down
131 changes: 3 additions & 128 deletions applications/rag/main.tf
Original file line number Diff line number Diff line change
Expand Up @@ -154,28 +154,17 @@ module "namespace" {

module "gcs" {
source = "../../modules/gcs"
count = var.create_gcs_bucket ? 1 : 0
count = 0
project_id = var.project_id
bucket_name = var.gcs_bucket
}

module "cloudsql" {
source = "../../modules/cloudsql"
providers = { kubernetes = kubernetes.rag }
project_id = var.project_id
instance_name = local.cloudsql_instance
namespace = local.kubernetes_namespace
region = local.cloudsql_instance_region
network_name = local.network_name
depends_on = [module.namespace]
bucket_name = "gke-aieco-rag-a236619-2ec62154"
}

module "jupyterhub" {
source = "../../modules/jupyter"
providers = { helm = helm.rag, kubernetes = kubernetes.rag }
namespace = local.kubernetes_namespace
project_id = var.project_id
gcs_bucket = var.gcs_bucket
gcs_bucket = "gke-aieco-rag-a236619-2ec62154"
add_auth = var.jupyter_add_auth
additional_labels = var.additional_labels

Expand Down Expand Up @@ -205,117 +194,3 @@ module "jupyterhub" {

depends_on = [module.namespace, module.gcs]
}

# Workload Identity setup for the Ray service account, using the upstream
# workload-identity submodule. Grants the roles listed in `roles`.
module "kuberay-workload-identity" {
  source = "terraform-google-modules/kubernetes-engine/google//modules/workload-identity"
  # Pinned to 30.0.0: the newer release (30.1.0) showed inconsistent behaviour
  # with workload identity service accounts.
  version = "30.0.0"

  providers  = { kubernetes = kubernetes.rag }
  depends_on = [module.namespace]

  project_id = var.project_id
  namespace  = local.kubernetes_namespace
  name       = local.ray_service_account

  # An existing GCP service account is reused unless create_ray_service_account is set.
  use_existing_gcp_sa             = !var.create_ray_service_account
  automount_service_account_token = true
  roles                           = ["roles/cloudsql.client", "roles/monitoring.viewer"]
}

# Instantiates the local kuberay-monitoring module in the RAG namespace.
module "kuberay-monitoring" {
  source    = "../../modules/kuberay-monitoring"
  providers = { helm = helm.rag, kubernetes = kubernetes.rag }

  # Applied only after the namespace and the workload-identity module.
  depends_on = [module.namespace, module.kuberay-workload-identity]

  project_id          = var.project_id
  namespace           = local.kubernetes_namespace
  create_namespace    = true
  autopilot_cluster   = local.enable_autopilot
  k8s_service_account = local.ray_service_account

  enable_grafana_on_ray_dashboard = var.enable_grafana_on_ray_dashboard
}

# Instantiates the local kuberay-cluster module with GPU support and a custom
# image enabled.
module "kuberay-cluster" {
  source     = "../../modules/kuberay-cluster"
  providers  = { helm = helm.rag, kubernetes = kubernetes.rag }
  depends_on = [module.gcs, module.kuberay-workload-identity]

  # Cluster placement and shape
  project_id             = var.project_id
  namespace              = local.kubernetes_namespace
  autopilot_cluster      = local.enable_autopilot
  enable_gpu             = true
  use_custom_image       = true
  disable_network_policy = var.disable_ray_cluster_network_policy
  google_service_account = local.ray_service_account
  additional_labels      = var.additional_labels

  # Storage and database wiring
  gcs_bucket             = var.gcs_bucket
  cloudsql_instance_name = local.cloudsql_instance
  db_region              = local.cloudsql_instance_region

  # Implicit dependency: these read outputs of module.cloudsql and
  # module.kuberay-monitoring.
  db_secret_name = module.cloudsql.db_secret_name
  grafana_host   = module.kuberay-monitoring.grafana_uri

  # IAP Auth parameters
  add_auth                 = var.ray_dashboard_add_auth
  create_brand             = var.create_brand
  support_email            = var.support_email
  client_id                = var.ray_dashboard_client_id
  client_secret            = var.ray_dashboard_client_secret
  domain                   = var.ray_dashboard_domain
  k8s_ingress_name         = var.ray_dashboard_k8s_ingress_name
  k8s_iap_secret_name      = var.ray_dashboard_k8s_iap_secret_name
  k8s_managed_cert_name    = var.ray_dashboard_k8s_managed_cert_name
  k8s_backend_config_name  = var.ray_dashboard_k8s_backend_config_name
  k8s_backend_service_port = var.ray_dashboard_k8s_backend_service_port

  # The allowlist variable is a comma-separated string; an empty string yields an empty list.
  members_allowlist = var.ray_dashboard_members_allowlist != "" ? split(",", var.ray_dashboard_members_allowlist) : []
}

# Instantiates the local inference-service module in the RAG namespace.
module "inference-server" {
  source     = "../../modules/inference-service"
  providers  = { kubernetes = kubernetes.rag }
  depends_on = [module.namespace]

  namespace         = local.kubernetes_namespace
  autopilot_cluster = local.enable_autopilot
  additional_labels = var.additional_labels
}

# Instantiates the ./frontend module, wired to outputs of the inference-server
# and cloudsql modules.
module "frontend" {
  source     = "./frontend"
  providers  = { helm = helm.rag, kubernetes = kubernetes.rag }
  depends_on = [module.namespace]

  # Project, namespace and service account
  project_id             = var.project_id
  namespace              = local.kubernetes_namespace
  create_service_account = var.create_rag_service_account
  google_service_account = local.rag_service_account
  additional_labels      = var.additional_labels

  # Backends: values read from other modules' outputs, locals and variables
  inference_service_endpoint    = module.inference-server.inference_service_endpoint
  cloudsql_instance             = module.cloudsql.instance
  cloudsql_instance_region      = local.cloudsql_instance_region
  db_secret_name                = module.cloudsql.db_secret_name
  dataset_embeddings_table_name = var.dataset_embeddings_table_name

  # IAP Auth parameters
  add_auth                 = var.frontend_add_auth
  create_brand             = var.create_brand
  support_email            = var.support_email
  client_id                = var.frontend_client_id
  client_secret            = var.frontend_client_secret
  domain                   = var.frontend_domain
  k8s_ingress_name         = var.frontend_k8s_ingress_name
  k8s_managed_cert_name    = var.frontend_k8s_managed_cert_name
  k8s_iap_secret_name      = var.frontend_k8s_iap_secret_name
  k8s_backend_config_name  = var.frontend_k8s_backend_config_name
  k8s_backend_service_name = var.frontend_k8s_backend_service_name
  k8s_backend_service_port = var.frontend_k8s_backend_service_port

  # The allowlist variable is a comma-separated string; an empty string yields an empty list.
  members_allowlist = var.frontend_members_allowlist != "" ? split(",", var.frontend_members_allowlist) : []
}

# Installs the local gmp-engine chart as the "gmp-apps" Helm release, with
# values read from podmonitoring.yaml in this module's directory.
resource "helm_release" "gmp-apps" {
  provider  = helm.rag
  name      = "gmp-apps"
  chart     = "../../charts/gmp-engine/"
  namespace = local.kubernetes_namespace

  # Timeout is increased to guarantee sufficient scale-up time for Autopilot nodes.
  timeout = 1200

  # file() already returns a string, so it is passed directly rather than
  # wrapped in a deprecated interpolation-only "${...}" expression.
  values = [
    file("${path.module}/podmonitoring.yaml")
  ]

  depends_on = [module.inference-server, module.frontend]
}

Loading

0 comments on commit 6648640

Please sign in to comment.