
Commit 88fe07d
Refactoring example notebooks to handle the new Cloud SQL vector store
german-grandas committed Sep 11, 2024
1 parent f1bf05a commit 88fe07d
Showing 2 changed files with 82 additions and 170 deletions.
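
The refactor drops the hand-rolled pg8000/SQLAlchemy connection pool, the raw CREATE EXTENSION/CREATE TABLE/INSERT statements, and the in-actor SentenceTransformer inference in favor of the langchain_google_cloud_sql_pg API. The following is a minimal sketch of the pattern the notebooks move to, assembled from the fragments in the diffs below; the environment variables, table names, and sample query are the notebook's own, while the standalone arrangement here is illustrative rather than the exact cell layout:

```python
# Minimal sketch of the new Cloud SQL vector store pattern (assumes the
# notebook's environment variables and a reachable Cloud SQL instance).
import os

import torch
from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings
from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore
from google.cloud.sql.connector import IPTypes

SENTENCE_TRANSFORMER_MODEL = "intfloat/multilingual-e5-small"
VECTOR_EMBEDDINGS_TABLE_NAME = "netflix_reviews_db"
VECTOR_DIMENSION = 384

# Connect through PostgresEngine instead of a hand-built pg8000/SQLAlchemy pool.
engine = PostgresEngine.from_instance(
    project_id=os.environ["GCP_PROJECT_ID"],
    region=os.environ["CLOUDSQL_INSTANCE_REGION"],
    instance=os.environ["CLOUDSQL_INSTANCE"],
    database="pgvector-database",
    user=os.environ.get("DB_USERNAME", "postgres"),
    password=os.environ.get("DB_PASS", "postgres"),
    ip_type=IPTypes.PRIVATE,
)

# Create the pgvector-backed table instead of issuing raw CREATE TABLE SQL.
engine.init_vectorstore_table(
    VECTOR_EMBEDDINGS_TABLE_NAME,
    vector_size=VECTOR_DIMENSION,
    overwrite_existing=True,
)

# Embeddings are computed by an embedding service handed to the vector store,
# rather than by a separate SentenceTransformer pass inside the Ray actors.
device = "cuda" if torch.cuda.is_available() else "cpu"
embeddings_service = HuggingFaceEmbeddings(
    model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs={"device": device}
)
vector_store = PostgresVectorStore.create_sync(
    engine=engine,
    embedding_service=embeddings_service,
    table_name=VECTOR_EMBEDDINGS_TABLE_NAME,
)

# Inserts and queries go through the store's API instead of hand-written SQL.
vector_store.add_texts(["Squid Game is a South Korean survival drama series."])
query_vector = embeddings_service.embed_query("List the cast of squid game")
for i, document in enumerate(vector_store.similarity_search_by_vector(query_vector, k=4)):
    print(f"Result #{i + 1}\n{document.page_content}\n" + "-" * 80)
```
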
@@ -61,14 +61,10 @@
"source": [
"import os\n",
"import uuid\n",
"import ray\n",
"\n",
"from typing import List\n",
"import torch\n",
"from datasets import load_dataset_builder, load_dataset, Dataset\n",
"from huggingface_hub import snapshot_download\n",
"from sentence_transformers import SentenceTransformer\n",
"import ray\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n",
"\n",
"from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n",
"from google.cloud.sql.connector import IPTypes\n",
@@ -135,11 +131,7 @@
"metadata": {},
"outputs": [],
"source": [
"SHARED_DATA_BASEPATH='/data/rag/st'\n",
"SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n",
"SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n",
"SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n",
"SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n",
"\n",
"# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n",
"SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n",
@@ -153,28 +145,6 @@
"ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function"
]
},
{
"cell_type": "markdown",
"id": "3dc5bc85-dc3b-4622-99a2-f9fc269e753b",
"metadata": {},
"source": [
"Now we will download the sentence transformer model to our GCS bucket:"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "b7a676be-56c6-4c76-8041-9ad05361dd3b",
"metadata": {},
"outputs": [],
"source": [
"# prepare the persistent shared directory to store artifacts needed for the ray workers\n",
"os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n",
"\n",
"# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n",
"snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)"
]
},
{
"cell_type": "markdown",
"id": "f7304035-21a4-4017-bce9-aba7e9f81c90",
@@ -197,13 +167,11 @@
" def __init__(self):\n",
" self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n",
"\n",
" def __call__(self, text_batch: List[str]):\n",
" def __call__(self, text_batch):\n",
" text = text_batch[\"item\"]\n",
" # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n",
" chunks = []\n",
" for data in text:\n",
" splits = self.splitter.split_text(data)\n",
" # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n",
" chunks.extend(splits)\n",
"\n",
" return {'results':chunks}"
@@ -272,7 +240,7 @@
"}])\n",
"print(ds_batch.schema)\n",
"\n",
"# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n",
"# Distributed map batches to create chunks out of each row.\n",
"ds_splitted = ds_batch.map_batches(\n",
" Splitter,\n",
" compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n",
@@ -301,7 +269,7 @@
"def ray_data_task(ds_splitted):\n",
" results = []\n",
" for row in ds_splitted.iter_rows():\n",
" data_text = row[\"results\"].page_content\n",
" data_text = row[\"results\"]\n",
" data_id = str(uuid.uuid4()) \n",
"\n",
" results.append((data_id, data_text))\n",
@@ -334,9 +302,7 @@
" print(\"device cuda found\")\n",
" device=\"cuda\"\n",
" \n",
"print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n",
"\n",
"embeddings_service = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n",
"embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs=dict(device=device))\n",
"vector_store = PostgresVectorStore.create_sync(\n",
" engine=engine,\n",
" embedding_service=embeddings_service,\n",
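
Both notebooks keep the Ray-based chunking but drop the in-actor embedding step; only raw text chunks leave the Ray pipeline, and the vector store embeds them on insert. Before the second notebook's diff, here is a rough sketch of that shared flow, pieced together from the hunks above and below; the chunking constants and the row-to-"item" mapping are illustrative, and `vector_store` is the store created as in the setup sketch earlier:

```python
# Sketch of the Ray chunking and ingestion flow the refactored notebooks share
# (assumes `vector_store` from the setup sketch and the mounted dataset path).
import ray
from langchain.text_splitter import RecursiveCharacterTextSplitter

SHARED_DATASET_BASE_PATH = "/data/netflix-shows/"
REVIEWS_FILE_NAME = "netflix_titles.csv"
CHUNK_SIZE = 500        # illustrative values; the notebook defines its own
CHUNK_OVERLAP = 10
BATCH_SIZE = 100
ACTOR_POOL_SIZE = 1


class Splitter:
    def __init__(self):
        self.splitter = RecursiveCharacterTextSplitter(
            chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len
        )

    def __call__(self, text_batch):
        # Split every "item" string in the batch into overlapping text chunks.
        chunks = []
        for data in text_batch["item"]:
            chunks.extend(self.splitter.split_text(data))
        return {"results": chunks}


def to_item(batch):
    # Collapse each CSV row's columns into one "item" string; the notebook builds
    # a similar combined field (which columns it joins is its own choice).
    n = len(next(iter(batch.values())))
    return {"item": [" ".join(str(batch[c][i]) for c in batch) for i in range(n)]}


ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)
ds_batch = ray_ds.map_batches(to_item)

# Distributed chunking with an actor pool, as in the notebooks.
ds_splitted = ds_batch.map_batches(
    Splitter,
    compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),
    batch_size=BATCH_SIZE,
)

# Embedding now happens inside the vector store, so the loop only hands over text.
for output in ds_splitted.iter_rows():
    vector_store.add_texts([output["results"]])
```
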
206 changes: 76 additions & 130 deletions applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb
@@ -30,7 +30,7 @@
"metadata": {},
"outputs": [],
"source": [
"!pip install ray[default]==2.9.3 kaggle==1.6.6"
"!pip install ray[default]==2.9.3 kaggle==1.6.6 langchain-google-cloud-sql-pg"
]
},
{
@@ -73,57 +73,62 @@
"\n",
"import os\n",
"import uuid\n",
"\n",
"import ray\n",
"from langchain.document_loaders import ArxivLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from sentence_transformers import SentenceTransformer\n",
"from typing import List\n",
"import torch\n",
"from datasets import load_dataset_builder, load_dataset, Dataset\n",
"from huggingface_hub import snapshot_download\n",
"from google.cloud.sql.connector import Connector, IPTypes\n",
"import sqlalchemy\n",
"from langchain_community.embeddings.huggingface import HuggingFaceEmbeddings\n",
"\n",
"from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n",
"from google.cloud.sql.connector import IPTypes\n",
"\n",
"# initialize parameters\n",
"INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n",
"INSTANCE_CONNECTION_NAME = os.environ.get(\"CLOUDSQL_INSTANCE_CONNECTION_NAME\")\n",
"print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n",
"DB_NAME = \"pgvector-database\"\n",
"\n",
"db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
"DB_USER = db_username_file.read()\n",
"db_username_file.close()\n",
"\n",
"db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
"DB_PASS = db_password_file.read()\n",
"db_password_file.close()\n",
"\n",
"# initialize Connector object\n",
"connector = Connector()\n",
"\n",
"# function to return the database connection object\n",
"def getconn():\n",
" conn = connector.connect(\n",
" INSTANCE_CONNECTION_NAME,\n",
" \"pg8000\",\n",
"cloud_variables = INSTANCE_CONNECTION_NAME.split(\":\")\n",
"\n",
"GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\", cloud_variables[0])\n",
"GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\", cloud_variables[1])\n",
"GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\", cloud_variables[2])\n",
"\n",
"DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n",
"VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n",
"CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n",
"\n",
"VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n",
"\n",
"try:\n",
" db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
" DB_USER = db_username_file.read()\n",
" db_username_file.close()\n",
"\n",
" db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
" DB_PASS = db_password_file.read()\n",
" db_password_file.close()\n",
"except:\n",
" DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n",
" DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n",
"\n",
"engine = PostgresEngine.from_instance(\n",
" project_id=GCP_PROJECT_ID,\n",
" region=GCP_CLOUD_SQL_REGION,\n",
" instance=GCP_CLOUD_SQL_INSTANCE,\n",
" database=DB_NAME,\n",
" user=DB_USER,\n",
" password=DB_PASS,\n",
" db=DB_NAME,\n",
" ip_type=IPTypes.PRIVATE\n",
" ip_type=IPTypes.PRIVATE,\n",
")\n",
"\n",
"try:\n",
" engine.init_vectorstore_table(\n",
" VECTOR_EMBEDDINGS_TABLE_NAME,\n",
" vector_size=VECTOR_DIMENSION,\n",
" overwrite_existing=True,\n",
" )\n",
" return conn\n",
"except Exception as err:\n",
" print(f\"Error: {err}\")\n",
"\n",
"# create connection pool with 'creator' argument to our connection object function\n",
"pool = sqlalchemy.create_engine(\n",
" \"postgresql+pg8000://\",\n",
" creator=getconn,\n",
")\n",
"\n",
"SHARED_DATA_BASEPATH='/data/rag/st'\n",
"SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings\n",
"SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name\n",
"SENTENCE_TRANSFORMER_MODEL_SNAPSHOT=\"ffdcc22a9a5c973ef0470385cef91e1ecb461d9f\" # specific snapshot of the model to use\n",
"SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time\n",
"\n",
"# the dataset has been pre-dowloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.\n",
"SHARED_DATASET_BASE_PATH=\"/data/netflix-shows/\"\n",
"REVIEWS_FILE_NAME=\"netflix_titles.csv\"\n",
@@ -135,40 +140,18 @@
"DIMENSION = 384 # Embeddings size\n",
"ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function\n",
"\n",
"class Embed:\n",
"class Splitter:\n",
" def __init__(self):\n",
" print(\"torch cuda version\", torch.version.cuda)\n",
" device=\"cpu\"\n",
" if torch.cuda.is_available():\n",
" print(\"device cuda found\")\n",
" device=\"cuda\"\n",
"\n",
" print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n",
" self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n",
" self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n",
"\n",
" def __call__(self, text_batch: List[str]):\n",
" def __call__(self, text_batch):\n",
" text = text_batch[\"item\"]\n",
" # print(\"type(text)=\", type(text), \"type(text_batch)=\", type(text_batch))\n",
" chunks = []\n",
" for data in text:\n",
" splits = self.splitter.split_text(data)\n",
" # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n",
" chunks.extend(splits)\n",
"\n",
" embeddings = self.transformer.encode(\n",
" chunks,\n",
" batch_size=BATCH_SIZE\n",
" ).tolist()\n",
" print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n",
" return {'results':list(zip(chunks, embeddings))}\n",
"\n",
"\n",
"# prepare the persistent shared directory to store artifacts needed for the ray workers\n",
"os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)\n",
"\n",
"# One time download of the sentence transformer model to a shared persistent storage available to the ray workers\n",
"snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)\n",
" return {'results':chunks}\n",
"\n",
"# Process the dataset first, wrap the csv file contents into a Ray dataset\n",
"ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)\n",
@@ -184,81 +167,44 @@
"}])\n",
"print(ds_batch.schema)\n",
"\n",
"# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n",
"ds_embed = ds_batch.map_batches(\n",
" Embed,\n",
"# Distributed map batches to create chunks out of each row.\n",
"ds_splitted = ds_batch.map_batches(\n",
" Splitter,\n",
" compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n",
" batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n",
" num_gpus=1, # 1 GPU for each actor.\n",
" # num_cpus=1,\n",
")\n",
"\n",
"# Use this block for debug purpose to inspect the embeddings and raw text\n",
"# print(\"Embeddings ray dataset\", ds_embed.schema)\n",
"# for output in ds_embed.iter_rows():\n",
"# # restrict the text string to be less than 65535\n",
"# data_text = output[\"results\"][0][:65535]\n",
"# # vector data pass in needs to be a string \n",
"# data_emb = \",\".join(map(str, output[\"results\"][1]))\n",
"# data_emb = \"[\" + data_emb + \"]\"\n",
"# print (\"raw text:\", data_text, \", emdeddings:\", data_emb)\n",
"\n",
"# print(\"Embeddings ray dataset\", ds_embed.schema)\n",
"\n",
"data_text = \"\"\n",
"data_emb = \"\"\n",
"\n",
"with pool.connect() as db_conn:\n",
" db_conn.execute(\n",
" sqlalchemy.text(\n",
" \"CREATE EXTENSION IF NOT EXISTS vector;\"\n",
" )\n",
" )\n",
" db_conn.commit()\n",
"print(\"torch cuda version\", torch.version.cuda)\n",
"device=\"cpu\"\n",
"if torch.cuda.is_available():\n",
" print(\"device cuda found\")\n",
" device=\"cuda\"\n",
" \n",
"embeddings_service = HuggingFaceEmbeddings(model_name=SENTENCE_TRANSFORMER_MODEL, model_kwargs=dict(device=device))\n",
"vector_store = PostgresVectorStore.create_sync(\n",
" engine=engine,\n",
" embedding_service=embeddings_service,\n",
" table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n",
")\n",
"\n",
" create_table_query = \"CREATE TABLE IF NOT EXISTS \" + TABLE_NAME + \" ( id VARCHAR(255) NOT NULL, text TEXT NOT NULL, text_embedding vector(384) NOT NULL, PRIMARY KEY (id));\"\n",
" db_conn.execute(\n",
" sqlalchemy.text(create_table_query)\n",
" )\n",
" # commit transaction (SQLAlchemy v2.X.X is commit as you go)\n",
" db_conn.commit()\n",
" print(\"Created table=\", TABLE_NAME)\n",
" \n",
" query_text = \"INSERT INTO \" + TABLE_NAME + \" (id, text, text_embedding) VALUES (:id, :text, :text_embedding)\"\n",
" insert_stmt = sqlalchemy.text(query_text)\n",
" for output in ds_embed.iter_rows():\n",
" # print (\"type of embeddings\", type(output[\"results\"][1]), \"len embeddings\", len(output[\"results\"][1]))\n",
" # restrict the text string to be less than 65535\n",
" data_text = output[\"results\"][0][:65535]\n",
" # vector data pass in needs to be a string \n",
" data_emb = \",\".join(map(str, output[\"results\"][1]))\n",
" data_emb = \"[\" + data_emb + \"]\"\n",
" # print(\"text_embedding is \", data_emb)\n",
"for output in ds_splitted.iter_rows():\n",
" id = uuid.uuid4()\n",
" db_conn.execute(insert_stmt, parameters={\"id\": id, \"text\": data_text, \"text_embedding\": data_emb})\n",
" splits = output[\"results\"]\n",
" vector_store.add_texts(splits, id)\n",
"\n",
" # batch commit transactions\n",
" db_conn.commit()\n",
"\n",
" # query and fetch table\n",
" query_text = \"SELECT * FROM \" + TABLE_NAME\n",
" results = db_conn.execute(sqlalchemy.text(query_text)).fetchall()\n",
" # for row in results:\n",
" # print(row)\n",
"#Validate results\n",
"query = \"List the cast of squid game\"\n",
"query_vector = embeddings_service.embed_query(query)\n",
"docs = vector_store.similarity_search_by_vector(query_vector, k=4)\n",
"\n",
" # verify results\n",
" transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n",
" query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n",
" query_emb = transformer.encode(query_text).tolist()\n",
" query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n",
" query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n",
" db_conn.commit()\n",
" print(\"print query_results, the 1st one is the hit\")\n",
" for row in query_results:\n",
" print(row)\n",
"\n",
"# cleanup connector object\n",
"connector.close()\n",
"for i, document in enumerate(docs):\n",
" print(f\"Result #{i+1}\")\n",
" print(document.page_content)\n",
" print(\"-\" * 100)\n",
" \n",
"print (\"end job\")"
]
},