-
Notifications
You must be signed in to change notification settings - Fork 40
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
1 parent
45f5326
commit 13b2e25
Showing
5 changed files
with
299 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,52 @@ | ||
__import__('pysqlite3') | ||
import sys | ||
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | ||
|
||
import streamlit as st | ||
import trulens.dashboard.streamlit as trulens_st | ||
from trulens.core import TruSession | ||
|
||
from base import rag, filtered_rag, tru_rag, filtered_tru_rag | ||
|
||
# Page chrome: browser-tab title/icon first, then the visible header and blurb.
st.set_page_config(page_title="Use TruLens in Streamlit", page_icon="🦑")
st.title("TruLens ❤️ Streamlit")
st.write(
    "Learn about the Pacific Northwest, and view tracing & evaluation "
    "metrics powered by TruLens 🦑."
)

# TruLens session handle for this script run.
tru = TruSession()

# UI switch read by generate_response() to choose between the plain and the
# context-filtered RAG app; off by default.
with_filters = st.toggle("Use Context Filter Guardrails", value=False)
|
||
def generate_response(input_text):
    """Answer ``input_text`` with the selected RAG app while recording the call.

    The module-level ``with_filters`` toggle decides which pair is used:
    ``filtered_tru_rag`` / ``filtered_rag`` when it is on, otherwise
    ``tru_rag`` / ``rag``.

    Args:
        input_text: The user's question, forwarded to the app's ``query``.

    Returns:
        A ``(record, response)`` tuple: the record obtained from the recording
        context via ``recording.get()`` and the value returned by ``query``.
    """
    # Pick the recorder and the app it wraps together so they cannot drift apart.
    if with_filters:
        recorder, pipeline = filtered_tru_rag, filtered_rag
    else:
        recorder, pipeline = tru_rag, rag

    # The query has to run inside the recorder's context for the call to be
    # captured in ``recording``.
    with recorder as recording:
        response = pipeline.query(input_text)

    return recording.get(), response
|
||
# Question shown pre-filled in the text area.
default_question = "When was the University of Washington founded?"

with st.form("my_form"):
    text = st.text_area("Enter text:", default_question)
    submitted = st.form_submit_button("Submit")
    if submitted:
        # Run the RAG app and show its answer inside the form.
        record, response = generate_response(text)
        st.info(response)

# Trace and feedback are rendered outside the form, and only after a
# submission, because ``record`` exists only in that case.
if submitted:
    with st.expander("See the trace of this record 👀"):
        trulens_st.trulens_trace(record=record)

    trulens_st.trulens_feedback(record=record)
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,132 @@ | ||
__import__('pysqlite3') | ||
import sys | ||
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | ||
|
||
import streamlit as st | ||
from openai import OpenAI | ||
import numpy as np | ||
|
||
from trulens.core import TruSession | ||
from trulens.core.guardrails.base import context_filter | ||
from trulens.apps.custom import instrument | ||
from trulens.apps.custom import TruCustomApp | ||
from trulens.providers.openai import OpenAI as OpenAIProvider | ||
from trulens.core import Feedback | ||
from trulens.core import Select | ||
from trulens.core.guardrails.base import context_filter | ||
|
||
from feedback import feedbacks, f_guardrail | ||
from vector_store import vector_store | ||
|
||
from dotenv import load_dotenv | ||
|
||
# Load variables from a local .env file before any client is constructed.
load_dotenv()

# Shared OpenAI client used by the generate_completion methods in this module.
oai_client = OpenAI()

# TruLens session handle for this module.
tru = TruSession()
|
||
class RAG_from_scratch:
    """Retrieve-then-generate pipeline whose steps are wrapped by ``instrument``."""

    @instrument
    def retrieve(self, query: str) -> list:
        """Fetch stored text related to ``query`` from the vector store.

        Asks the module-level ``vector_store`` for four results and returns
        the documents as one flat list.
        """
        hits = vector_store.query(query_texts=query, n_results=4)
        # hits["documents"] is a list of lists; merge it into a single list.
        flattened = []
        for group in hits["documents"]:
            flattened.extend(group)
        return flattened

    @instrument
    def generate_completion(self, query: str, context_str: list) -> str:
        """Ask the chat model to answer ``query`` using ``context_str``.

        The context is interpolated into a single user message; the content
        of the first returned choice is the result.
        """
        prompt = (
            "We have provided context information below. \n"
            "---------------------\n"
            f"{context_str}"
            "\n---------------------\n"
            "First, say hello and that you're happy to help. \n"
            "\n---------------------\n"
            f"Then, given this information, please answer the question: {query}"
        )
        reply = oai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            temperature=0,
            messages=[{"role": "user", "content": prompt}],
        )
        return reply.choices[0].message.content

    @instrument
    def query(self, query: str) -> str:
        """Run retrieval, then generation, and return the generated answer."""
        retrieved = self.retrieve(query)
        return self.generate_completion(query, retrieved)
|
||
class filtered_RAG_from_scratch:
    """RAG pipeline whose retrieval step is wrapped by a context-filter guardrail."""

    @instrument
    @context_filter(f_guardrail, 0.75, keyword_for_prompt="query")
    def retrieve(self, query: str) -> list:
        """Fetch stored text related to ``query`` from the vector store.

        Asks the module-level ``vector_store`` for four results and returns
        the documents as one flat list. The ``context_filter`` decorator is
        configured with ``f_guardrail`` and a 0.75 threshold, and is told the
        prompt arrives in the ``query`` keyword.
        """
        hits = vector_store.query(query_texts=query, n_results=4)
        flattened = []
        for group in hits["documents"]:
            flattened.extend(group)
        return flattened

    @instrument
    def generate_completion(self, query: str, context_str: list) -> str:
        """Ask the chat model to answer ``query`` using ``context_str``.

        The context is interpolated into a single user message; the content
        of the first returned choice is the result.
        """
        prompt = (
            "We have provided context information below. \n"
            "---------------------\n"
            f"{context_str}"
            "\n---------------------\n"
            f"Given this information, please answer the question: {query}"
        )
        reply = oai_client.chat.completions.create(
            model="gpt-3.5-turbo",
            temperature=0,
            messages=[{"role": "user", "content": prompt}],
        )
        return reply.choices[0].message.content

    @instrument
    def query(self, query: str) -> str:
        """Run filtered retrieval, then generation, and return the answer."""
        # Passed by keyword to match the decorator's keyword_for_prompt="query".
        retrieved = self.retrieve(query=query)
        return self.generate_completion(query=query, context_str=retrieved)
|
||
|
||
# One instance of each pipeline, shared by everything importing this module.
filtered_rag = filtered_RAG_from_scratch()
rag = RAG_from_scratch()

# Both wrappers share the app name "RAG" and the same feedback list; they
# differ only in the wrapped pipeline and the version label.
tru_rag = TruCustomApp(rag, app_name="RAG", app_version="v1", feedbacks=feedbacks)
filtered_tru_rag = TruCustomApp(
    filtered_rag, app_name="RAG", app_version="v2", feedbacks=feedbacks
)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
__import__('pysqlite3') | ||
import sys | ||
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | ||
|
||
import numpy as np | ||
from trulens.core import Feedback | ||
from trulens.core import Select | ||
from trulens.providers.openai import OpenAI as OpenAIProvider | ||
|
||
from dotenv import load_dotenv | ||
|
||
# Load variables from a local .env file before the provider is constructed.
load_dotenv()

# Every feedback function below is evaluated through this provider.
provider = OpenAIProvider(model_engine="gpt-4o-mini")

# Groundedness: the collected return values of `retrieve` against the app output.
f_groundedness = Feedback(
    provider.groundedness_measure_with_cot_reasons,
    name="Groundedness",
).on(Select.RecordCalls.retrieve.rets.collect()).on_output()

# Answer relevance: the overall question against the overall answer.
f_answer_relevance = Feedback(
    provider.relevance_with_cot_reasons,
    name="Answer Relevance",
).on_input().on_output()

# Context relevance: the question against each returned context chunk,
# combined with the mean (swap in another aggregation if you wish).
f_context_relevance = (
    Feedback(
        provider.context_relevance_with_cot_reasons,
        name="Context Relevance",
    )
    .on_input()
    .on(Select.RecordCalls.retrieve.rets[:])
    .aggregate(np.mean)
)

feedbacks = [
    f_groundedness,
    f_answer_relevance,
    f_context_relevance,
]

# note: a feedback function used as a guardrail must return only a score,
# not a score plus reasons, hence the non-CoT `context_relevance` here.
f_guardrail = Feedback(provider.context_relevance, name="Context Relevance")
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,8 @@ | ||
pip==24.2 | ||
openai | ||
pysqlite3-binary | ||
chromadb | ||
trulens-core @ git+https://github.com/truera/trulens#egg=trulens-core&subdirectory=src/core/ | ||
trulens-feedback @ git+https://github.com/truera/trulens#egg=trulens-feedback&subdirectory=src/feedback/ | ||
trulens-providers-openai @ git+https://github.com/truera/trulens#egg=trulens-providers-openai&subdirectory=src/providers/openai/ | ||
trulens-dashboard @ git+https://github.com/truera/trulens#egg=trulens-dashboard&subdirectory=src/dashboard/ |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
__import__('pysqlite3') | ||
import sys | ||
sys.modules['sqlite3'] = sys.modules.pop('pysqlite3') | ||
|
||
import openai | ||
import os | ||
import chromadb | ||
from chromadb.utils.embedding_functions import OpenAIEmbeddingFunction | ||
|
||
from dotenv import load_dotenv | ||
|
||
load_dotenv() | ||
|
||
# Seed texts for the demo collection. newzealand_info is the only one that is
# not about Washington state.
uw_info = """
The University of Washington, founded in 1861 in Seattle, is a public research university
with over 45,000 students across three campuses in Seattle, Tacoma, and Bothell.
As the flagship institution of the six public universities in Washington state,
UW encompasses over 500 buildings and 20 million square feet of space,
including one of the largest library systems in the world.
"""

wsu_info = """
Washington State University, commonly known as WSU, founded in 1890, is a public research university in Pullman, Washington.
With multiple campuses across the state, it is the state's second largest institution of higher education.
WSU is known for its programs in veterinary medicine, agriculture, engineering, architecture, and pharmacy.
"""

seattle_info = """
Seattle, a city on Puget Sound in the Pacific Northwest, is surrounded by water, mountains and evergreen forests, and contains thousands of acres of parkland.
It's home to a large tech industry, with Microsoft and Amazon headquartered in its metropolitan area.
The futuristic Space Needle, a legacy of the 1962 World's Fair, is its most iconic landmark.
"""

starbucks_info = """
Starbucks Corporation is an American multinational chain of coffeehouses and roastery reserves headquartered in Seattle, Washington.
As the world's largest coffeehouse chain, Starbucks is seen to be the main representation of the United States' second wave of coffee culture.
"""

newzealand_info = """
New Zealand is an island country located in the southwestern Pacific Ocean. It comprises two main landmasses—the North Island and the South Island—and over 700 smaller islands.
The country is known for its stunning landscapes, ranging from lush forests and mountains to beaches and lakes. New Zealand has a rich cultural heritage, with influences from
both the indigenous Māori people and European settlers. The capital city is Wellington, while the largest city is Auckland. New Zealand is also famous for its adventure tourism,
including activities like bungee jumping, skiing, and hiking.
"""

# Embeddings come from OpenAI; the API key is read from the environment
# (populated by load_dotenv() when a .env file is present).
embedding_function = OpenAIEmbeddingFunction(
    api_key=os.environ.get("OPENAI_API_KEY"),
    model_name="text-embedding-ada-002",
)

chroma_client = chromadb.Client()
vector_store = chroma_client.get_or_create_collection(
    name="Washington", embedding_function=embedding_function
)

# Document id -> text, in the order the documents are inserted.
seed_documents = {
    "uw_info": uw_info,
    "wsu_info": wsu_info,
    "seattle_info": seattle_info,
    "starbucks_info": starbucks_info,
    "newzealand_info": newzealand_info,
}

# Insert everything in a single add() so the documents are embedded in one
# batch rather than one call per document.
vector_store.add(
    ids=list(seed_documents),
    documents=list(seed_documents.values()),
)