Commit: ran black

Daethyra committed Feb 11, 2024
1 parent 3fd06a3 commit 9d83aa6
Showing 2 changed files with 56 additions and 36 deletions.
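Both files were reformatted with the Black formatter, evidently under its default settings; the result can be reproduced locally (assuming Black is installed, e.g. via `pip install black`) by running `black freestream/` from the repository root. The changes below are Black's usual normalizations: double-quoted strings, trailing commas in exploded calls, lines wrapped at the 88-character limit, two spaces before inline comments, two blank lines around top-level definitions, and stripped trailing whitespace (which is presumably why a few -/+ pairs below look identical).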
57 changes: 32 additions & 25 deletions freestream/main.py
@@ -4,7 +4,12 @@
 from langchain.memory import ConversationBufferMemory
 from langchain_community.chat_message_histories import StreamlitChatMessageHistory
 from langchain.chains import ConversationalRetrievalChain
-from utility_funcs import configure_retriever, StreamHandler, PrintRetrievalHandler, set_llm
+from utility_funcs import (
+    configure_retriever,
+    StreamHandler,
+    PrintRetrievalHandler,
+    set_llm,
+)
 
 # Initialize LangSmith tracing
 os.environ["LANGCHAIN_TRACING_V2"] = "true"
@@ -24,7 +29,7 @@
     label="Upload a PDF or text file",
     type=["pdf", "doc", "docx", "txt"],
     help="Types supported: pdf, doc, docx, txt",
-    accept_multiple_files=True
+    accept_multiple_files=True,
 )
 if not uploaded_files:
     st.info("Please upload documents to continue.")
@@ -34,41 +39,41 @@
 
 # Setup memory for contextual conversation
 msgs = StreamlitChatMessageHistory()
-memory = ConversationBufferMemory(memory_key="chat_history", chat_memory=msgs, return_messages=True)
+memory = ConversationBufferMemory(
+    memory_key="chat_history", chat_memory=msgs, return_messages=True
+)
 
 # Create a dictionary with keys to chat model classes
 model_names = {
-    "ChatOpenAI GPT-3.5 Turbo": ChatOpenAI( # Define a dictionary entry for the "ChatOpenAI GPT-3.5 Turbo" model
-        model_name="gpt-3.5-turbo-0125", # Set the OpenAI model name
-        openai_api_key=st.secrets.OPENAI.openai_api_key, # Set the OpenAI API key from the Streamlit secrets manager
-        temperature=0.7, # Set the temperature for the model's responses
-        streaming=True # Enable streaming responses for the model
-    ),
+    "ChatOpenAI GPT-3.5 Turbo": ChatOpenAI(  # Define a dictionary entry for the "ChatOpenAI GPT-3.5 Turbo" model
+        model_name="gpt-3.5-turbo-0125",  # Set the OpenAI model name
+        openai_api_key=st.secrets.OPENAI.openai_api_key,  # Set the OpenAI API key from the Streamlit secrets manager
+        temperature=0.7,  # Set the temperature for the model's responses
+        streaming=True,  # Enable streaming responses for the model
+    ),
 }
 
 # Create a dropdown menu for selecting a chat model
 selected_model = st.selectbox(
-    label="Choose your chat model:", # Set the label for the dropdown menu
-    options=list(model_names.keys()), # Set the available model options
-    key="model_selector", # Set a unique key for the dropdown menu
-    on_change=set_llm # Set the callback function
+    label="Choose your chat model:",  # Set the label for the dropdown menu
+    options=list(model_names.keys()),  # Set the available model options
+    key="model_selector",  # Set a unique key for the dropdown menu
+    on_change=set_llm,  # Set the callback function
 )
 
 # Load the selected model dynamically
-llm = model_names[selected_model] # Get the selected model \
-# from the `model_names` dictionary
+llm = model_names[
+    selected_model
+]  # Get the selected model from the `model_names` dictionary
 
 # Create a chain that ties everything together
 qa_chain = ConversationalRetrievalChain.from_llm(
-    llm,
-    retriever=retriever,
-    memory=memory,
-    verbose=True
+    llm, retriever=retriever, memory=memory, verbose=True
 )
 
 # if the length of messages is 0, or when the user \
-# clicks the clear button,
-# show a default message from the AI
+# clicks the clear button,
+# show a default message from the AI
 if len(msgs.messages) == 0 or st.sidebar.button("Clear message history"):
     msgs.clear()
     # show a default message from the AI
@@ -82,14 +87,16 @@
 # Display user input field and enter button
 if user_query := st.chat_input(placeholder="Ask me anything!"):
     st.chat_message("user").write(user_query)
 
     # Display assistant response
     with st.chat_message("assistant"):
         # Check for the presence of the "messages" key in session state
-        if 'messages' not in st.session_state:
+        if "messages" not in st.session_state:
             st.session_state.messages = []
 
         retrieval_handler = PrintRetrievalHandler(st.container())
         stream_handler = StreamHandler(st.empty())
-        response = qa_chain.run(user_query, callbacks=[retrieval_handler, stream_handler])
-        st.toast('Success!', icon="✅")
+        response = qa_chain.run(
+            user_query, callbacks=[retrieval_handler, stream_handler]
+        )
+        st.toast("Success!", icon="✅")
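Taken together, the main.py pieces above are standard LangChain 0.1-era conversational-RAG wiring: a retriever over the uploaded files, buffer memory keyed to the Streamlit message history, and a ConversationalRetrievalChain joining them. A minimal standalone sketch of the same wiring with Streamlit removed (the toy corpus, the import paths, and reading the OpenAI key from an environment variable are assumptions for illustration, not part of this commit):

    # Sketch only: assumes faiss-cpu, sentence-transformers, and openai are installed,
    # and OPENAI_API_KEY is set in the environment.
    import os

    from langchain.chains import ConversationalRetrievalChain
    from langchain.memory import ConversationBufferMemory
    from langchain_community.chat_models import ChatOpenAI
    from langchain_community.embeddings import HuggingFaceEmbeddings
    from langchain_community.vectorstores import FAISS

    # Stand-in for configure_retriever(): embed a toy corpus into FAISS.
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
    vectordb = FAISS.from_texts(
        ["FreeStream lets users ask questions about uploaded documents."], embeddings
    )
    retriever = vectordb.as_retriever(
        search_type="mmr", search_kwargs={"k": 3, "fetch_k": 7}
    )

    # Same memory and chain construction as main.py, minus StreamlitChatMessageHistory.
    memory = ConversationBufferMemory(memory_key="chat_history", return_messages=True)
    llm = ChatOpenAI(
        model_name="gpt-3.5-turbo-0125",
        openai_api_key=os.environ["OPENAI_API_KEY"],
        temperature=0.7,
    )
    qa_chain = ConversationalRetrievalChain.from_llm(
        llm, retriever=retriever, memory=memory
    )

    print(qa_chain.run("What does FreeStream do?"))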
35 changes: 24 additions & 11 deletions freestream/utility_funcs.py
@@ -14,10 +14,11 @@
 logging.basicConfig(level=logging.WARNING, stream=sys.stdout)
 logger = logging.getLogger(__name__)
 
-@st.cache_resource(ttl="1h") # Cache the resource
+
+@st.cache_resource(ttl="1h")  # Cache the resource
 def configure_retriever(uploaded_files):
     """
-    This function configures and returns a retriever object for a given list of uploaded files.
+    This function configures and returns a retriever object for a given list of uploaded files.
 
     The function performs the following steps:
     1. Reads the documents from the uploaded files.
@@ -48,14 +49,18 @@ def configure_retriever(uploaded_files):
     chunks = text_splitter.split_documents(docs)
 
     # Create embeddings and store in vectordb
-    # quickly create a GPU detection line for model_kwargs
+    # quickly create a GPU detection line for model_kwargs
     model_kwargs = {"device": "cuda" if torch.cuda.is_available() else "cpu"}
-    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2", model_kwargs=model_kwargs)
+    embeddings = HuggingFaceEmbeddings(
+        model_name="all-MiniLM-L6-v2", model_kwargs=model_kwargs
+    )
     vectordb = FAISS.from_documents(chunks, embeddings)
 
     # Define retriever
-    retriever = vectordb.as_retriever(search_type="mmr", search_kwargs={"k": 3, "fetch_k": 7})
-
+    retriever = vectordb.as_retriever(
+        search_type="mmr", search_kwargs={"k": 3, "fetch_k": 7}
+    )
+
     return retriever


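Aside: with `search_type="mmr"`, the retriever fetches `fetch_k` (here 7) candidate chunks by embedding similarity, then keeps the `k` (here 3) that maximal marginal relevance ranks as most relevant while least redundant with each other. A quick way to inspect what it returns, given a retriever built as above (the query string is a placeholder):

    docs = retriever.get_relevant_documents("What is this document about?")
    print([doc.page_content[:80] for doc in docs])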
@@ -70,7 +75,10 @@ class StreamHandler(BaseCallbackHandler):
         text (str): The text that has been generated by the model.
         run_id_ignore_token (str): The run ID for ignoring the rephrased question as output.
     """
-    def __init__(self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""):
+
+    def __init__(
+        self, container: st.delta_generator.DeltaGenerator, initial_text: str = ""
+    ):
         """
         Initialize the StreamHandler object.
 
@@ -117,13 +125,14 @@ class PrintRetrievalHandler(BaseCallbackHandler):
     """
     A callback handler for printing the context retrieval status.
 
-    This handler updates the status of the retrieval process, including the question, document sources,
+    This handler updates the status of the retrieval process, including the question, document sources,
     and page contents. It also changes the status label and state according to the retrieval process.
 
     Attributes:
         container (Container): The container object that contains the status object.
         status (Status): The status object for updating the retrieval process status.
     """
+
     def __init__(self, container):
         """
         Initialize the PrintRetrievalHandler object.
@@ -163,6 +172,7 @@ def on_retriever_end(self, documents, **kwargs):
             self.status.markdown(doc.page_content)
         self.status.update(state="complete")
 
+
 # Define a callback function for when a model is selected
 def set_llm():
     """
@@ -182,13 +192,16 @@ def set_llm():
     """
     # Set the model in session state
     st.session_state.llm = model_names[selected_model]
+
     # Show an alert based on what model was selected
     if st.session_state.model_selector == model_names["ChatOpenAI GPT-3.5 Turbo"]:
         st.warning(body="Switched to ChatGPT 3.5-Turbo!", icon="⚠️")
 
     # Add more if statements for each added model
     # if st.session_state.model_selector == model_names["GPT-4"]:
     # ...
     else:
-        st.warning(body="Failed to change model! \nPlease contact the website builder.", icon="⚠️")
+        st.warning(
+            body="Failed to change model! \nPlease contact the website builder.",
+            icon="⚠️",
+        )

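One behavioral note survives the reformatting: `st.session_state.model_selector` holds the option string chosen in the selectbox, while `model_names["ChatOpenAI GPT-3.5 Turbo"]` is the `ChatOpenAI` instance itself, so the `if` in `set_llm` compares a string against a model object and cannot match. A sketch of the presumable intent, comparing against the dictionary key instead:

    if st.session_state.model_selector == "ChatOpenAI GPT-3.5 Turbo":
        st.warning(body="Switched to ChatGPT 3.5-Turbo!", icon="⚠️")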