updating notebook to use the PostgresVectorStore instead of the custom vector store
german-grandas committed Sep 10, 2024
1 parent 74b6e9d commit e94cab0
Showing 1 changed file with 93 additions and 104 deletions.
197 changes: 93 additions & 104 deletions applications/rag/example_notebooks/rag-kaggle-ray-sql-interactive.ipynb
@@ -32,6 +32,16 @@
"!unzip -o ~/data/netflix-shows.zip -d /data/netflix-shows"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "c421c932",
"metadata": {},
"outputs": [],
"source": [
"!pip install langchain-google-cloud-sql-pg"
]
},
{
"cell_type": "markdown",
"id": "c7ff518d-f4d2-481b-b408-2c2507565611",
@@ -52,50 +62,58 @@
"import os\n",
"import uuid\n",
"import ray\n",
"from langchain.document_loaders import ArxivLoader\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"from sentence_transformers import SentenceTransformer\n",
"\n",
"from typing import List\n",
"import torch\n",
"from datasets import load_dataset_builder, load_dataset, Dataset\n",
"from huggingface_hub import snapshot_download\n",
"from google.cloud.sql.connector import Connector, IPTypes\n",
"import sqlalchemy\n",
"\n",
"# initialize parameters\n",
"\n",
"INSTANCE_CONNECTION_NAME = os.environ[\"CLOUDSQL_INSTANCE_CONNECTION_NAME\"]\n",
"print(f\"Your instance connection name is: {INSTANCE_CONNECTION_NAME}\")\n",
"DB_NAME = \"pgvector-database\"\n",
"\n",
"db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
"DB_USER = db_username_file.read()\n",
"db_username_file.close()\n",
"\n",
"db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
"DB_PASS = db_password_file.read()\n",
"db_password_file.close()\n",
"from sentence_transformers import SentenceTransformer\n",
"from langchain.text_splitter import RecursiveCharacterTextSplitter\n",
"\n",
"# initialize Connector object\n",
"connector = Connector()\n",
"from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore\n",
"from google.cloud.sql.connector import IPTypes\n",
"\n",
"# function to return the database connection object\n",
"def getconn():\n",
" conn = connector.connect(\n",
" INSTANCE_CONNECTION_NAME,\n",
" \"pg8000\",\n",
"# initialize parameters\n",
"GCP_PROJECT_ID = os.environ.get(\"GCP_PROJECT_ID\")\n",
"GCP_CLOUD_SQL_REGION = os.environ.get(\"CLOUDSQL_INSTANCE_REGION\")\n",
"GCP_CLOUD_SQL_INSTANCE = os.environ.get(\"CLOUDSQL_INSTANCE\")\n",
"\n",
"DB_NAME = os.environ.get(\"INSTANCE_CONNECTION_NAME\", \"pgvector-database\")\n",
"VECTOR_EMBEDDINGS_TABLE_NAME = os.environ.get(\"EMBEDDINGS_TABLE_NAME\", \"netflix_reviews_db\")\n",
"CHAT_HISTORY_TABLE_NAME = os.environ.get(\"CHAT_HISTORY_TABLE_NAME\", \"message_store\")\n",
"\n",
"VECTOR_DIMENSION = os.environ.get(\"VECTOR_DIMENSION\", 384)\n",
"\n",
"try:\n",
" db_username_file = open(\"/etc/secret-volume/username\", \"r\")\n",
" DB_USER = db_username_file.read()\n",
" db_username_file.close()\n",
"\n",
" db_password_file = open(\"/etc/secret-volume/password\", \"r\")\n",
" DB_PASS = db_password_file.read()\n",
" db_password_file.close()\n",
"except:\n",
" DB_USER = os.environ.get(\"DB_USERNAME\", \"postgres\")\n",
" DB_PASS = os.environ.get(\"DB_PASS\", \"postgres\")\n",
"\n",
"engine = PostgresEngine.from_instance(\n",
" project_id=GCP_PROJECT_ID,\n",
" region=GCP_CLOUD_SQL_REGION,\n",
" instance=GCP_CLOUD_SQL_INSTANCE,\n",
" database=DB_NAME,\n",
" user=DB_USER,\n",
" password=DB_PASS,\n",
" db=DB_NAME,\n",
" ip_type=IPTypes.PRIVATE\n",
" ip_type=IPTypes.PRIVATE,\n",
")\n",
"\n",
"try:\n",
" engine.init_vectorstore_table(\n",
" VECTOR_EMBEDDINGS_TABLE_NAME,\n",
" vector_size=VECTOR_DIMENSION,\n",
" overwrite_existing=True,\n",
" )\n",
" return conn\n",
"\n",
"# create connection pool with 'creator' argument to our connection object function\n",
"pool = sqlalchemy.create_engine(\n",
" \"postgresql+pg8000://\",\n",
" creator=getconn,\n",
")"
"except Exception as err:\n",
" print(f\"Error: {err}\")"
]
},
{
@@ -158,9 +176,10 @@
"id": "f7304035-21a4-4017-bce9-aba7e9f81c90",
"metadata": {},
"source": [
"## Generating Vector Embeddings\n",
"## Generating Documents splits\n",
"\n",
"We are ready to begin. Let's first create some code for generating the vector embeddings:"
"We are ready to begin. Let's first create some code for generating the dataset splits:\n",
"\n"
]
},
{
@@ -170,16 +189,8 @@
"metadata": {},
"outputs": [],
"source": [
"class Embed:\n",
"class Splitter:\n",
" def __init__(self):\n",
" print(\"torch cuda version\", torch.version.cuda)\n",
" device=\"cpu\"\n",
" if torch.cuda.is_available():\n",
" print(\"device cuda found\")\n",
" device=\"cuda\"\n",
"\n",
" print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n",
" self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n",
" self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)\n",
"\n",
" def __call__(self, text_batch: List[str]):\n",
@@ -191,12 +202,7 @@
" # print(\"len(data)\", len(data), \"len(splits)=\", len(splits))\n",
" chunks.extend(splits)\n",
"\n",
" embeddings = self.transformer.encode(\n",
" chunks,\n",
" batch_size=BATCH_SIZE\n",
" ).tolist()\n",
" print(\"len(chunks)=\", len(chunks), \", len(emb)=\", len(embeddings))\n",
" return {'results':list(zip(chunks, embeddings))}"
" return {'results':chunks}"
]
},
{
@@ -227,6 +233,7 @@
" \"datasets==2.18.0\",\n",
" \"torch==2.0.1\",\n",
" \"huggingface_hub==0.21.3\",\n",
" \"langchain-google-cloud-sql-pg\"\n",
" ]\n",
" }\n",
")"
@@ -262,8 +269,8 @@
"print(ds_batch.schema)\n",
"\n",
"# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer\n",
"ds_embed = ds_batch.map_batches(\n",
" Embed,\n",
"ds_splitted = ds_batch.map_batches(\n",
" Splitter,\n",
" compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),\n",
" batch_size=BATCH_SIZE, # Large batch size to maximize GPU utilization.\n",
" num_gpus=1, # 1 GPU for each actor.\n",
@@ -287,17 +294,17 @@
"outputs": [],
"source": [
"@ray.remote\n",
"def ray_data_task(ds_embed):\n",
"def ray_data_task(ds_splitted):\n",
" results = []\n",
" for row in ds_embed.iter_rows():\n",
" data_text = row[\"results\"][0][:65535]\n",
" data_emb = row[\"results\"][1]\n",
" for row in ds_splitted.iter_rows():\n",
" data_text = row[\"results\"]\n",
" data_id = str(uuid.uuid4()) \n",
"\n",
" results.append((data_text, data_emb))\n",
" results.append((data_id, data_text))\n",
" \n",
" return results\n",
" \n",
"results = ray.get(ray_data_task.remote(ds_embed))"
"results = ray.get(ray_data_task.remote(ds_splitted))"
]
},
{
Expand All @@ -317,36 +324,25 @@
"metadata": {},
"outputs": [],
"source": [
"from sqlalchemy.ext.declarative import declarative_base\n",
"from sqlalchemy import Column, String, Text, text\n",
"from sqlalchemy.orm import scoped_session, sessionmaker, mapped_column\n",
"from pgvector.sqlalchemy import Vector\n",
"\n",
"\n",
"Base = declarative_base()\n",
"DBSession = scoped_session(sessionmaker())\n",
"\n",
"class TextEmbedding(Base):\n",
" __tablename__ = TABLE_NAME\n",
" id = Column(String(255), primary_key=True)\n",
" text = Column(Text)\n",
" text_embedding = mapped_column(Vector(384))\n",
"\n",
"with pool.connect() as conn:\n",
" conn.execute(text(\"CREATE EXTENSION IF NOT EXISTS vector\"))\n",
" conn.commit() \n",
"print(\"torch cuda version\", torch.version.cuda)\n",
"device=\"cpu\"\n",
"if torch.cuda.is_available():\n",
" print(\"device cuda found\")\n",
" device=\"cuda\"\n",
" \n",
"DBSession.configure(bind=pool, autoflush=False, expire_on_commit=False)\n",
"Base.metadata.drop_all(pool)\n",
"Base.metadata.create_all(pool)\n",
"\n",
"rows = []\n",
"for r in results:\n",
" id = uuid.uuid4() \n",
" rows.append(TextEmbedding(id=id, text=r[0], text_embedding=r[1]))\n",
"\n",
"DBSession.bulk_save_objects(rows)\n",
"DBSession.commit()"
"print (\"reading sentence transformer model from cache path:\", SENTENCE_TRANSFORMER_MODEL_PATH)\n",
"\n",
"embeddings_service = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)\n",
"vector_store = PostgresVectorStore.create_sync(\n",
" engine=engine,\n",
" embedding_service=embeddings_service,\n",
" table_name=VECTOR_EMBEDDINGS_TABLE_NAME,\n",
")\n",
"\n",
"for result in results:\n",
" id = result[0]\n",
" splits = result[1]\n",
" vector_store.add_texts(splits, id)"
]
},
{
Expand All @@ -364,21 +360,14 @@
"metadata": {},
"outputs": [],
"source": [
"with pool.connect() as db_conn:\n",
" # verify results\n",
" transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)\n",
" query_text = \"During my holiday in Marmaris we ate here to fit the food. It's really good\" \n",
" query_emb = transformer.encode(query_text).tolist()\n",
" query_request = \"SELECT id, text, text_embedding, 1 - ('[\" + \",\".join(map(str, query_emb)) + \"]' <=> text_embedding) AS cosine_similarity FROM \" + TABLE_NAME + \" ORDER BY cosine_similarity DESC LIMIT 5;\" \n",
" query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()\n",
" db_conn.commit()\n",
" \n",
" print(\"print query_results, the 1st one is the hit\")\n",
" for row in query_results:\n",
" print(row)\n",
"\n",
"# cleanup connector object\n",
"connector.close()"
"query = \"List the cast of squid game\"\n",
"query_vector = embeddings_service.embed_query(query)\n",
"docs = vector_store.similarity_search_by_vector(query_vector, k=4)\n",
"\n",
"for i, document in enumerate(docs):\n",
" print(f\"Result #{i+1}\")\n",
" print(document.page_content)\n",
" print(\"-\" * 100)"
]
}
],
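For reference, here is a minimal, self-contained Python sketch of the PostgresVectorStore flow this commit switches to: create a PostgresEngine for the Cloud SQL instance, initialize the vector table, wrap an embedding model as a LangChain embedding service, add text chunks, and query by similarity. It is an illustration rather than the notebook's exact code: the HuggingFaceEmbeddings wrapper (from the separate langchain-huggingface package), the model name, and the chunk-size values are assumptions, while the database name, table name, and 384-dimension default are taken from the diff above.

# Minimal sketch of the PostgresVectorStore flow introduced by this commit.
# Assumptions: langchain-google-cloud-sql-pg and langchain-huggingface are installed,
# and the Cloud SQL instance details come from the same environment variables the
# notebook reads. The notebook itself passes a raw SentenceTransformer as the
# embedding service and loads it from a local cache path.
import os

from langchain_google_cloud_sql_pg import PostgresEngine, PostgresVectorStore
from langchain_huggingface import HuggingFaceEmbeddings
from langchain.text_splitter import RecursiveCharacterTextSplitter

PROJECT_ID = os.environ["GCP_PROJECT_ID"]
REGION = os.environ["CLOUDSQL_INSTANCE_REGION"]
INSTANCE = os.environ["CLOUDSQL_INSTANCE"]
DB_NAME = "pgvector-database"
TABLE_NAME = "netflix_reviews_db"
VECTOR_DIMENSION = 384  # matches the notebook's default embedding size

# Connect to Cloud SQL for PostgreSQL and (re)create the embeddings table.
engine = PostgresEngine.from_instance(
    project_id=PROJECT_ID,
    region=REGION,
    instance=INSTANCE,
    database=DB_NAME,
    user=os.environ.get("DB_USERNAME", "postgres"),
    password=os.environ.get("DB_PASS", "postgres"),
)
engine.init_vectorstore_table(
    TABLE_NAME,
    vector_size=VECTOR_DIMENSION,
    overwrite_existing=True,
)

# Embedding service: a LangChain wrapper around a sentence-transformers model
# (hypothetical model name; any 384-dimension sentence-transformer would do).
embedding_service = HuggingFaceEmbeddings(model_name="sentence-transformers/all-MiniLM-L6-v2")

vector_store = PostgresVectorStore.create_sync(
    engine=engine,
    embedding_service=embedding_service,
    table_name=TABLE_NAME,
)

# Split a document into chunks and store them; the vector store computes and
# persists the embeddings itself, so there is no manual encode()/INSERT step.
splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=10, length_function=len)
chunks = splitter.split_text("Squid Game is a South Korean survival drama series ...")
vector_store.add_texts(chunks)

# Query back by semantic similarity.
for doc in vector_store.similarity_search("List the cast of Squid Game", k=4):
    print(doc.page_content)

Compared with the previous revision of the notebook, this removes the hand-rolled pg8000 connection pool, the SQLAlchemy TextEmbedding model, and the manual SentenceTransformer.encode step; the vector store handles table schema, embedding, and similarity search.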
