Add tests for RAG prompt (#511)
Signed-off-by: Andrew Sy Kim <[email protected]>
andrewsykim authored and ryanaoleary committed Apr 3, 2024
1 parent 364cd2f commit 10284be
Showing 7 changed files with 9,098 additions and 10 deletions.
@@ -48,7 +48,7 @@
"outputs": [],
"source": [
"# Create a directory to package the contents that need to be downloaded in ray worker\n",
"! mkdir -p test"
"! mkdir -p rag-app"
]
},
{
@@ -58,8 +58,8 @@
"metadata": {},
"outputs": [],
"source": [
"%%writefile test/test.py\n",
"# Comment out the above line if you want to see notebook print out, but the line is required for the actual ray job (the test.py is downloaded by the ray workers)\n",
"%%writefile rag-app/job.py\n",
"# Comment out the above line if you want to see notebook print out, but the line is required for the actual ray job (the job.py is downloaded by the ray workers)\n",
"\n",
"import os\n",
"import uuid\n",
@@ -276,10 +276,10 @@
"\n",
"start_time = time.time()\n",
"job_id = client.submit_job(\n",
" entrypoint=\"python test.py\",\n",
" entrypoint=\"python job.py\",\n",
" # Path to the local directory that contains the entrypoint file.\n",
" runtime_env={\n",
" \"working_dir\": \"/home/jovyan/test\", # upload the local working directory to ray workers\n",
" \"working_dir\": \"/home/jovyan/rag-app\", # upload the local working directory to ray workers\n",
" }\n",
")\n",
"\n",
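For context, the entrypoint/working_dir pair above uses Ray's Jobs API: the local directory is uploaded as the job's runtime environment and becomes the working directory on the Ray workers. A minimal standalone sketch of the same pattern (the dashboard address and local path are illustrative assumptions, not values from this commit):

from ray.job_submission import JobSubmissionClient

client = JobSubmissionClient("http://127.0.0.1:8265")  # assumed Ray dashboard address
job_id = client.submit_job(
    entrypoint="python job.py",
    # The directory is uploaded to the cluster and used as the workers' working directory.
    runtime_env={"working_dir": "./rag-app"},
)
print(client.get_job_status(job_id))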
8,810 changes: 8,810 additions & 0 deletions applications/rag/tests/netflix_titles.csv

Large diffs are not rendered by default.

192 changes: 192 additions & 0 deletions applications/rag/tests/ray_job.py
@@ -0,0 +1,192 @@
# This script is copied from the applications/rag/example_notebooks/rag-kaggle-ray-sql-latest.ipynb for testing purposes
# TODO: remove this script and execute the notebook directly with nbconvert.
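# (One possible way to do that, shown only as an illustration and not part of this
#  script's logic: execute the notebook headlessly, e.g.
#  `jupyter nbconvert --to notebook --execute rag-kaggle-ray-sql-latest.ipynb`.)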

import os
import uuid
import ray
from langchain.document_loaders import ArxivLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from sentence_transformers import SentenceTransformer
from typing import List
import torch
from datasets import load_dataset_builder, load_dataset, Dataset
from huggingface_hub import snapshot_download
from google.cloud.sql.connector import Connector, IPTypes
import sqlalchemy

# initialize parameters
INSTANCE_CONNECTION_NAME = os.environ["CLOUDSQL_INSTANCE_CONNECTION_NAME"]
print(f"Your instance connection name is: {INSTANCE_CONNECTION_NAME}")
DB_NAME = "pgvector-database"

db_username_file = open("/etc/secret-volume/username", "r")
DB_USER = db_username_file.read()
db_username_file.close()

db_password_file = open("/etc/secret-volume/password", "r")
DB_PASS = db_password_file.read()
db_password_file.close()

# initialize Connector object
connector = Connector()

# function to return the database connection object
def getconn():
    conn = connector.connect(
        INSTANCE_CONNECTION_NAME,
        "pg8000",
        user=DB_USER,
        password=DB_PASS,
        db=DB_NAME,
        ip_type=IPTypes.PRIVATE
    )
    return conn

# create connection pool with 'creator' argument to our connection object function
pool = sqlalchemy.create_engine(
    "postgresql+pg8000://",
    creator=getconn,
)

SHARED_DATA_BASEPATH='/data/rag/st'
SENTENCE_TRANSFORMER_MODEL = 'intfloat/multilingual-e5-small' # Transformer to use for converting text chunks to vector embeddings
SENTENCE_TRANSFORMER_MODEL_PATH_NAME='models--intfloat--multilingual-e5-small' # the downloaded model path takes this form for a given model name
SENTENCE_TRANSFORMER_MODEL_SNAPSHOT="ffdcc22a9a5c973ef0470385cef91e1ecb461d9f" # specific snapshot of the model to use
SENTENCE_TRANSFORMER_MODEL_PATH = SHARED_DATA_BASEPATH + '/' + SENTENCE_TRANSFORMER_MODEL_PATH_NAME + '/snapshots/' + SENTENCE_TRANSFORMER_MODEL_SNAPSHOT # the path where the model is downloaded one time

# the dataset has been pre-downloaded to the GCS bucket as part of the notebook in the cell above. Ray workers will find the dataset readily mounted.
SHARED_DATASET_BASE_PATH="/data/netflix-shows/"
REVIEWS_FILE_NAME="netflix_titles.csv"

BATCH_SIZE = 100
CHUNK_SIZE = 1000 # text chunk sizes which will be converted to vector embeddings
CHUNK_OVERLAP = 10
TABLE_NAME = 'netflix_reviews_db' # CloudSQL table name
DIMENSION = 384 # Embeddings size
ACTOR_POOL_SIZE = 1 # number of actors for the distributed map_batches function

class Embed:
    def __init__(self):
        print("torch cuda version", torch.version.cuda)
        device = "cpu"
        if torch.cuda.is_available():
            print("device cuda found")
            device = "cuda"

        print("reading sentence transformer model from cache path:", SENTENCE_TRANSFORMER_MODEL_PATH)
        self.transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL_PATH, device=device)
        self.splitter = RecursiveCharacterTextSplitter(chunk_size=CHUNK_SIZE, chunk_overlap=CHUNK_OVERLAP, length_function=len)

    def __call__(self, text_batch: List[str]):
        text = text_batch["item"]
        # print("type(text)=", type(text), "type(text_batch)=", type(text_batch))
        chunks = []
        for data in text:
            splits = self.splitter.split_text(data)
            # print("len(data)", len(data), "len(splits)=", len(splits))
            chunks.extend(splits)

        embeddings = self.transformer.encode(
            chunks,
            batch_size=BATCH_SIZE
        ).tolist()
        print("len(chunks)=", len(chunks), ", len(emb)=", len(embeddings))
        return {'results': list(zip(chunks, embeddings))}
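# Note: when a callable class like Embed is passed to map_batches (below), Ray Data
# runs it inside actors from the configured actor pool and invokes __call__ once per
# batch; each batch arrives as a dict keyed by column name (here, "item").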


# prepare the persistent shared directory to store artifacts needed for the ray workers
os.makedirs(SHARED_DATA_BASEPATH, exist_ok=True)

# One time download of the sentence transformer model to a shared persistent storage available to the ray workers
snapshot_download(repo_id=SENTENCE_TRANSFORMER_MODEL, revision=SENTENCE_TRANSFORMER_MODEL_SNAPSHOT, cache_dir=SHARED_DATA_BASEPATH)

# Process the dataset first, wrap the csv file contents into a Ray dataset
ray_ds = ray.data.read_csv(SHARED_DATASET_BASE_PATH + REVIEWS_FILE_NAME)
print(ray_ds.schema)

# Distributed flat map to extract the raw text fields.
ds_batch = ray_ds.flat_map(lambda row: [{
    'item': "This is a " + str(row["type"]) + " in " + str(row["country"]) + " called " + str(row["title"]) +
            " added at " + str(row["date_added"]) + " whose director is " + str(row["director"]) +
            " and with cast: " + str(row["cast"]) + " released at " + str(row["release_year"]) +
            ". Its rating is: " + str(row['rating']) + ". Its duration is " + str(row["duration"]) +
            ". Its description is " + str(row['description']) + "."
}])
print(ds_batch.schema)

# Distributed map batches to create chunks out of each row, and fetch the vector embeddings by running inference on the sentence transformer
ds_embed = ds_batch.map_batches(
    Embed,
    compute=ray.data.ActorPoolStrategy(size=ACTOR_POOL_SIZE),
    batch_size=BATCH_SIZE,  # Large batch size to maximize GPU utilization.
    num_gpus=1,             # 1 GPU for each actor.
    # num_cpus=1,
)
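# Note: Ray Datasets execute lazily by default, so the embedding stage above only
# runs once the rows are consumed (e.g. by iter_rows() in the insert loop below).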

# Use this block for debug purposes to inspect the embeddings and raw text
# print("Embeddings ray dataset", ds_embed.schema)
# for output in ds_embed.iter_rows():
#     # restrict the text string to be less than 65535
#     data_text = output["results"][0][:65535]
#     # vector data pass in needs to be a string
#     data_emb = ",".join(map(str, output["results"][1]))
#     data_emb = "[" + data_emb + "]"
#     print("raw text:", data_text, ", embeddings:", data_emb)

# print("Embeddings ray dataset", ds_embed.schema)

data_text = ""
data_emb = ""

with pool.connect() as db_conn:
    db_conn.execute(
        sqlalchemy.text(
            "CREATE EXTENSION IF NOT EXISTS vector;"
        )
    )
    db_conn.commit()

    create_table_query = "CREATE TABLE IF NOT EXISTS " + TABLE_NAME + " ( id VARCHAR(255) NOT NULL, text TEXT NOT NULL, text_embedding vector(384) NOT NULL, PRIMARY KEY (id));"
    db_conn.execute(
        sqlalchemy.text(create_table_query)
    )
    # commit transaction (SQLAlchemy v2.X.X is commit as you go)
    db_conn.commit()
    print("Created table=", TABLE_NAME)

query_text = "INSERT INTO " + TABLE_NAME + " (id, text, text_embedding) VALUES (:id, :text, :text_embedding)"
insert_stmt = sqlalchemy.text(query_text)
for output in ds_embed.iter_rows():
# print ("type of embeddings", type(output["results"][1]), "len embeddings", len(output["results"][1]))
# restrict the text string to be less than 65535
data_text = output["results"][0][:65535]
# vector data pass in needs to be a string
data_emb = ",".join(map(str, output["results"][1]))
data_emb = "[" + data_emb + "]"
# print("text_embedding is ", data_emb)
id = uuid.uuid4()
db_conn.execute(insert_stmt, parameters={"id": id, "text": data_text, "text_embedding": data_emb})

# batch commit transactions
db_conn.commit()

    # query and fetch table
    query_text = "SELECT * FROM " + TABLE_NAME
    results = db_conn.execute(sqlalchemy.text(query_text)).fetchall()
    # for row in results:
    #     print(row)

    # verify results
    transformer = SentenceTransformer(SENTENCE_TRANSFORMER_MODEL)
    query_text = "During my holiday in Marmaris we ate here to fit the food. It's really good"
    query_emb = transformer.encode(query_text).tolist()
    query_request = "SELECT id, text, text_embedding, 1 - ('[" + ",".join(map(str, query_emb)) + "]' <=> text_embedding) AS cosine_similarity FROM " + TABLE_NAME + " ORDER BY cosine_similarity DESC LIMIT 5;"
    query_results = db_conn.execute(sqlalchemy.text(query_request)).fetchall()
    db_conn.commit()
    print("print query_results, the 1st one is the hit")
    for row in query_results:
        print(row)

# cleanup connector object
connector.close()
print ("end job")
3 changes: 3 additions & 0 deletions applications/rag/tests/test_frontend.py
@@ -6,6 +6,9 @@ def test_frontend_up(rag_frontend_url):
    r.raise_for_status()
    print("Rag frontend is up.")

    assert "Submit your query below." in r.content.decode('utf-8')
    assert "Enable Filters" in r.content.decode('utf-8')

hub_url = "http://" + sys.argv[1]

test_frontend_up(hub_url)
65 changes: 65 additions & 0 deletions applications/rag/tests/test_rag.py
@@ -0,0 +1,65 @@
import json
import sys
import requests

def test_prompts(prompt_url):
    testcases = [
        {
            "prompt": "List the cast of Squid Game",
            "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..",
            "expected_substrings": ["Lee Jung-jae", "Park Hae-soo", "Wi Ha-jun", "Oh Young-soo", "Jung Ho-yeon", "Heo Sung-tae", "Kim Joo-ryoung", "Tripathi Anupam", "You Seong-joo", "Lee You-mi"],
        },
        {
            "prompt": "When was Squid Game released?",
            "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..",
            "expected_substrings": ["September 17, 2021"],
        },
        {
            "prompt": "What is the rating of Squid Game?",
            "expected_context": "This is a TV Show in called Squid Game added at September 17, 2021 whose director is and with cast: Lee Jung-jae, Park Hae-soo, Wi Ha-jun, Oh Young-soo, Jung Ho-yeon, Heo Sung-tae, Kim Joo-ryoung, Tripathi Anupam, You Seong-joo, Lee You-mi released at 2021. Its rating is: TV-MA. Its duration is 1 Season. Its description is Hundreds of cash-strapped players accept a strange invitation to compete in children's games. Inside, a tempting prize awaits — with deadly high stakes..",
            "expected_substrings": ["TV-MA"],
        },
        {
            "prompt": "List the cast of Avatar: The Last Airbender",
            "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..",
            "expected_substrings": ["Zach Tyler", "Mae Whitman", "Jack De Sena", "Dee Bradley Baker", "Dante Basco", "Jessie Flower", "Mako Iwamatsu"],
        },
        {
            "prompt": "When was Avatar: The Last Airbender added on Netflix?",
            "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..",
            "expected_substrings": ["May 15, 2020"],
        },
        {
            "prompt": "What is the rating of Avatar: The Last Airbender?",
            "expected_context": "This is a TV Show in United States called Avatar: The Last Airbender added at May 15, 2020 whose director is and with cast: Zach Tyler, Mae Whitman, Jack De Sena, Dee Bradley Baker, Dante Basco, Jessie Flower, Mako Iwamatsu released at 2007. Its rating is: TV-Y7. Its duration is 3 Seasons. Its description is Siblings Katara and Sokka wake young Aang from a long hibernation and learn he's an Avatar, whose air-bending powers can defeat the evil Fire Nation..",
            "expected_substrings": ["TV-Y7"],
        },
    ]

    for testcase in testcases:
        prompt = testcase["prompt"]
        expected_context = testcase["expected_context"]
        expected_substrings = testcase["expected_substrings"]

        print(f"Testing prompt: {prompt}")
        data = {"prompt": prompt}
        json_payload = json.dumps(data)

        headers = {'Content-Type': 'application/json'}
        response = requests.post(prompt_url, data=json_payload, headers=headers)
        response.raise_for_status()

        response = response.json()
        context = response['response']['context']
        text = response['response']['text']
        user_prompt = response['response']['user_prompt']

        print(f"Reply: {text}")

        assert user_prompt == prompt, f"unexpected user prompt: {user_prompt} != {prompt}"
        assert context == expected_context, f"unexpected context: {context} != {expected_context}"

        for substring in expected_substrings:
            assert substring in text, f"substring {substring} not in response:\n {text}"

test_prompts(sys.argv[1])
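For reference, test_prompts assumes the /prompt endpoint returns JSON roughly shaped as below; the field names come from the lookups above, while the values here are purely illustrative:

# Illustrative payload only; real values are produced by the RAG frontend.
example_response = {
    "response": {
        "user_prompt": "List the cast of Squid Game",
        "context": "This is a TV Show in called Squid Game added at September 17, 2021 ...",
        "text": "The cast of Squid Game includes Lee Jung-jae, Park Hae-soo, ...",
    }
}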
20 changes: 19 additions & 1 deletion cloudbuild.yaml
@@ -61,6 +61,7 @@ steps:
-var=project_id=$PROJECT_ID \
-var=network_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-$_AUTOPILOT_CLUSTER \
-var=subnetwork_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-$_AUTOPILOT_CLUSTER \
-var=subnetwork_region=$_REGION \
-var=cluster_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-cluster \
-var=autopilot_cluster=$_AUTOPILOT_CLUSTER \
-var=cluster_location=$_REGION \
@@ -146,6 +147,7 @@ steps:
-var-file=workloads-without-iap.example.tfvars \
-var=project_id=$PROJECT_ID \
-var=cluster_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-cluster \
-var=cluster_location=$_REGION \
-var=kubernetes_namespace=ml-$SHORT_SHA-$_BUILD_ID \
-var=workload_identity_service_account=jupyter-sa-$SHORT_SHA-$_BUILD_ID \
-var=gcs_bucket=gke-aieco-jupyter-$SHORT_SHA-$_BUILD_ID \
@@ -176,6 +178,7 @@ steps:
-var-file=workloads-without-iap.example.tfvars \
-var=project_id=$PROJECT_ID \
-var=cluster_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-cluster \
-var=cluster_location=$_REGION \
-var=kubernetes_namespace=ml-$SHORT_SHA-$_BUILD_ID \
-var=workload_identity_service_account=jupyter-sa-$SHORT_SHA-$_BUILD_ID \
-var=gcs_bucket=gke-aieco-jupyter-$SHORT_SHA-$_BUILD_ID \
@@ -210,6 +213,7 @@ steps:
-var=frontend_add_auth=false \
-var=project_id=$PROJECT_ID \
-var=cluster_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-cluster \
-var=cluster_location=$_REGION \
-var=kubernetes_namespace=rag-$SHORT_SHA-$_BUILD_ID \
-var=gcs_bucket=gke-aieco-rag-$SHORT_SHA-$_BUILD_ID \
-var=ray_service_account=ray-sa-$SHORT_SHA-$_BUILD_ID \
@@ -249,6 +253,14 @@ steps:
cd /workspace/applications/rag/tests
python3 test_frontend.py "127.0.0.1:8081"
echo "pass" > /workspace/rag_frontend_result.txt
# Upload locally stored netflix dataset to GCS bucket mounted as /data
gsutil cp ./netflix_titles.csv gs://gke-aieco-rag-$SHORT_SHA-$_BUILD_ID/netflix-shows/netflix_titles.csv
ray job submit --working-dir . --address=http://127.0.0.1:8265 -- python ray_job.py
python3 test_rag.py "http://127.0.0.1:8081/prompt"
echo "pass" > /workspace/rag_prompt_result.txt
allowFailure: true
waitFor: ['cleanup jupyterhub', 'cleanup ray cluster']

@@ -269,6 +281,7 @@ steps:
-var=frontend_add_auth=false \
-var=project_id=$PROJECT_ID \
-var=cluster_name=ml-$SHORT_SHA-$_PR_NUMBER-$_BUILD_ID-cluster \
-var=cluster_location=$_REGION \
-var=kubernetes_namespace=rag-$SHORT_SHA-$_BUILD_ID \
-var=gcs_bucket=gke-aieco-rag-$SHORT_SHA-$_BUILD_ID \
-var=ray_service_account=ray-sa-$SHORT_SHA-$_BUILD_ID \
@@ -350,10 +363,15 @@
exit 1
fi
if [[ $(cat /workspace/rag_prompt_result.txt) != "pass" ]]; then
echo "rag frontend test failed"
exit 1
fi
waitFor: ['cleanup gke cluster']

substitutions:
-  _REGION: us-central1
+  _REGION: us-east4
_USER_NAME: github
_AUTOPILOT_CLUSTER: "false"
_BUILD_ID: ${BUILD_ID:0:8}