From 13b2e252f27498c18c5e038d45b057e996d06bf8 Mon Sep 17 00:00:00 2001
From: Josh Reini
Date: Wed, 4 Sep 2024 12:44:26 -0400
Subject: [PATCH] Add files via upload

---
 recipes/trulens/app.py           |  52 ++++++++++++
 recipes/trulens/base.py          | 132 +++++++++++++++++++++++++++++++
 recipes/trulens/feedback.py      |  46 +++++++++++
 recipes/trulens/requirements.txt |   8 ++
 recipes/trulens/vector_store.py  |  61 ++++++++++++++
 5 files changed, 299 insertions(+)
 create mode 100644 recipes/trulens/app.py
 create mode 100644 recipes/trulens/base.py
 create mode 100644 recipes/trulens/feedback.py
 create mode 100644 recipes/trulens/requirements.txt
 create mode 100644 recipes/trulens/vector_store.py

diff --git a/recipes/trulens/app.py b/recipes/trulens/app.py
new file mode 100644
index 0000000..436a76a
--- /dev/null
+++ b/recipes/trulens/app.py
@@ -0,0 +1,52 @@
+__import__('pysqlite3')
+import sys
+sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+import streamlit as st
+import trulens.dashboard.streamlit as trulens_st
+from trulens.core import TruSession
+
+from base import rag, filtered_rag, tru_rag, filtered_tru_rag
+
+st.set_page_config(
+    page_title="Use TruLens in Streamlit",
+    page_icon="🦑",
+)
+
+st.title("TruLens ❤️ Streamlit")
+
+st.write("Learn about the Pacific Northwest, and view tracing & evaluation metrics powered by TruLens 🦑.")
+
+tru = TruSession()
+
+with_filters = st.toggle("Use Context Filter Guardrails", value=False)
+
+def generate_response(input_text):
+    if with_filters:
+        app = filtered_tru_rag
+        with filtered_tru_rag as recording:
+            response = filtered_rag.query(input_text)
+    else:
+        app = tru_rag
+        with tru_rag as recording:
+            response = rag.query(input_text)
+
+    record = recording.get()
+
+    return record, response
+
+with st.form("my_form"):
+    text = st.text_area(
+        "Enter text:", "When was the University of Washington founded?"
+    )
+    submitted = st.form_submit_button("Submit")
+    if submitted:
+        record, response = generate_response(text)
+        st.info(response)
+
+if submitted:
+    with st.expander("See the trace of this record 👀"):
+        trulens_st.trulens_trace(record=record)
+
+    trulens_st.trulens_feedback(record=record)
+
diff --git a/recipes/trulens/base.py b/recipes/trulens/base.py
new file mode 100644
index 0000000..a22cedf
--- /dev/null
+++ b/recipes/trulens/base.py
@@ -0,0 +1,132 @@
+__import__('pysqlite3')
+import sys
+sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+import streamlit as st
+from openai import OpenAI
+import numpy as np
+
+from trulens.core import TruSession
+from trulens.core.guardrails.base import context_filter
+from trulens.apps.custom import instrument
+from trulens.apps.custom import TruCustomApp
+from trulens.providers.openai import OpenAI as OpenAIProvider
+from trulens.core import Feedback
+from trulens.core import Select
+from trulens.core.guardrails.base import context_filter
+
+from feedback import feedbacks, f_guardrail
+from vector_store import vector_store
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+oai_client = OpenAI()
+
+tru = TruSession()
+
+class RAG_from_scratch:
+    @instrument
+    def retrieve(self, query: str) -> list:
+        """
+        Retrieve relevant text from vector store.
+        """
+        results = vector_store.query(query_texts=query, n_results=4)
+        # Flatten the list of lists into a single list
+        return [doc for sublist in results["documents"] for doc in sublist]
+
+    @instrument
+    def generate_completion(self, query: str, context_str: list) -> str:
+        """
+        Generate answer from context.
+        """
+        completion = (
+            oai_client.chat.completions.create(
+                model="gpt-3.5-turbo",
+                temperature=0,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"We have provided context information below. \n"
+                        f"---------------------\n"
+                        f"{context_str}"
+                        f"\n---------------------\n"
+                        f"First, say hello and that you're happy to help. \n"
+                        f"\n---------------------\n"
+                        f"Then, given this information, please answer the question: {query}",
+                    }
+                ],
+            )
+            .choices[0]
+            .message.content
+        )
+        return completion
+
+    @instrument
+    def query(self, query: str) -> str:
+        context_str = self.retrieve(query)
+        completion = self.generate_completion(query, context_str)
+        return completion
+
+class filtered_RAG_from_scratch:
+    @instrument
+    @context_filter(f_guardrail, 0.75, keyword_for_prompt="query")
+    def retrieve(self, query: str) -> list:
+        """
+        Retrieve relevant text from vector store.
+        """
+        results = vector_store.query(query_texts=query, n_results=4)
+        return [doc for sublist in results["documents"] for doc in sublist]
+
+    @instrument
+    def generate_completion(self, query: str, context_str: list) -> str:
+        """
+        Generate answer from context.
+        """
+        completion = (
+            oai_client.chat.completions.create(
+                model="gpt-3.5-turbo",
+                temperature=0,
+                messages=[
+                    {
+                        "role": "user",
+                        "content": f"We have provided context information below. \n"
+                        f"---------------------\n"
+                        f"{context_str}"
+                        f"\n---------------------\n"
+                        f"Given this information, please answer the question: {query}",
+                    }
+                ],
+            )
+            .choices[0]
+            .message.content
+        )
+        return completion
+
+    @instrument
+    def query(self, query: str) -> str:
+        context_str = self.retrieve(query=query)
+        completion = self.generate_completion(
+            query=query, context_str=context_str
+        )
+        return completion
+
+
+filtered_rag = filtered_RAG_from_scratch()
+
+rag = RAG_from_scratch()
+
+tru_rag = TruCustomApp(
+    rag,
+    app_name="RAG",
+    app_version="v1",
+    feedbacks=feedbacks,
+)
+
+filtered_tru_rag = TruCustomApp(
+    filtered_rag,
+    app_name="RAG",
+    app_version="v2",
+    feedbacks=feedbacks,
+)
\ No newline at end of file
diff --git a/recipes/trulens/feedback.py b/recipes/trulens/feedback.py
new file mode 100644
index 0000000..8a83788
--- /dev/null
+++ b/recipes/trulens/feedback.py
@@ -0,0 +1,46 @@
+__import__('pysqlite3')
+import sys
+sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+import numpy as np
+from trulens.core import Feedback
+from trulens.core import Select
+from trulens.providers.openai import OpenAI as OpenAIProvider
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+provider = OpenAIProvider(model_engine="gpt-4o-mini")
+
+# Define a groundedness feedback function
+f_groundedness = (
+    Feedback(
+        provider.groundedness_measure_with_cot_reasons, name="Groundedness"
+    )
+    .on(Select.RecordCalls.retrieve.rets.collect())
+    .on_output()
+)
+# Question/answer relevance between overall question and answer.
+f_answer_relevance = (
+    Feedback(provider.relevance_with_cot_reasons, name="Answer Relevance")
+    .on_input()
+    .on_output()
+)
+
+# Context relevance between question and each context chunk.
+f_context_relevance = (
+    Feedback(
+        provider.context_relevance_with_cot_reasons, name="Context Relevance"
+    )
+    .on_input()
+    .on(Select.RecordCalls.retrieve.rets[:])
+    .aggregate(np.mean)  # choose a different aggregation method if you wish
+)
+
+feedbacks = [f_groundedness, f_answer_relevance, f_context_relevance]
+
+# note: feedback function used for guardrail must only return a score, not also reasons
+f_guardrail = Feedback(
+    provider.context_relevance, name="Context Relevance"
+)
diff --git a/recipes/trulens/requirements.txt b/recipes/trulens/requirements.txt
new file mode 100644
index 0000000..f6382eb
--- /dev/null
+++ b/recipes/trulens/requirements.txt
@@ -0,0 +1,8 @@
+pip==24.2
+openai
+pysqlite3-binary
+chromadb
+trulens-core @ git+https://github.com/truera/trulens#egg=trulens-core&subdirectory=src/core/
+trulens-feedback @ git+https://github.com/truera/trulens#egg=trulens-feedback&subdirectory=src/feedback/
+trulens-providers-openai @ git+https://github.com/truera/trulens#egg=trulens-providers-openai&subdirectory=src/providers/openai/
+trulens-dashboard @ git+https://github.com/truera/trulens#egg=trulens-dashboard&subdirectory=src/dashboard/
\ No newline at end of file
diff --git a/recipes/trulens/vector_store.py b/recipes/trulens/vector_store.py
new file mode 100644
index 0000000..8873515
--- /dev/null
+++ b/recipes/trulens/vector_store.py
@@ -0,0 +1,61 @@
+__import__('pysqlite3')
+import sys
+sys.modules['sqlite3'] = sys.modules.pop('pysqlite3')
+
+import openai
+import os
+import chromadb
+from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction
+
+from dotenv import load_dotenv
+
+load_dotenv()
+
+uw_info = """
+The University of Washington, founded in 1861 in Seattle, is a public research university
+with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell.
+As the flagship institution of the six public universities in Washington state,
+UW encompasses over 500 buildings and 20 million square feet of space,
+including one of the largest library systems in the world.
+"""
+
+wsu_info = """
+Washington State University, commonly known as WSU, founded in 1890, is a public research university in Pullman, Washington.
+With multiple campuses across the state, it is the state's second largest institution of higher education.
+WSU is known for its programs in veterinary medicine, agriculture, engineering, architecture, and pharmacy.
+"""
+
+seattle_info = """
+Seattle, a city on Puget Sound in the Pacific Northwest, is surrounded by water, mountains and evergreen forests, and contains thousands of acres of parkland.
+It's home to a large tech industry, with Microsoft and Amazon headquartered in its metropolitan area.
+The futuristic Space Needle, a legacy of the 1962 World's Fair, is its most iconic landmark.
+"""
+
+starbucks_info = """
+Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington.
+As the world's largest coffeehouse chain, Starbucks is seen to be the main representation of the United States' second wave of coffee culture.
+"""
+
+newzealand_info = """
+New Zealand is an island country located in the southwestern Pacific Ocean. It comprises two main landmasses—the North Island and the South Island—and over 700 smaller islands.
+The country is known for its stunning landscapes, ranging from lush forests and mountains to beaches and lakes. New Zealand has a rich cultural heritage, with influences from
+both the indigenous Māori people and European settlers. The capital city is Wellington, while the largest city is Auckland. New Zealand is also famous for its adventure tourism,
+including activities like bungee jumping, skiing, and hiking.
+"""
+
+embedding_function = OpenAIEmbeddingFunction(
+    api_key=os.environ.get("OPENAI_API_KEY"),
+    model_name="text-embedding-ada-002",
+)
+
+
+chroma_client = chromadb.Client()
+vector_store = chroma_client.get_or_create_collection(
+    name="Washington", embedding_function=embedding_function
+)
+
+vector_store.add("uw_info", documents=uw_info)
+vector_store.add("wsu_info", documents=wsu_info)
+vector_store.add("seattle_info", documents=seattle_info)
+vector_store.add("starbucks_info", documents=starbucks_info)
+vector_store.add("newzealand_info", documents=newzealand_info)
\ No newline at end of file
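
Running the recipe (a minimal sketch, assuming an OpenAI API key is supplied through the
environment or a local .env file picked up by load_dotenv()):

    pip install -r recipes/trulens/requirements.txt
    export OPENAI_API_KEY=...  # assumed to be provided by you; python-dotenv will also read it from .env
    streamlit run recipes/trulens/app.py

The feedback functions in feedback.py and the OpenAIEmbeddingFunction in vector_store.py
both call the OpenAI API, so the key must be available before the Streamlit app starts.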