Showing 1 changed file with 37 additions and 18 deletions.

@@ -1,10 +1,12 @@
# -*- coding: utf-8 -*-
# Author: Reinaldo Chaves ([email protected])
# This project implements a conversational Retrieval-Augmented Generation (RAG) system
# using Streamlit, LangChain, and large language models to interview content from URLs
# Response generation uses the llama-3.2-90b-text-preview model from Meta
# Text embeddings use the all-MiniLM-L6-v2 model from Hugging Face
# I am grateful for Krish C Naik's classes (https://www.youtube.com/user/krishnaik06)

# Import necessary libraries
import streamlit as st
from langchain.chains import create_history_aware_retriever, create_retrieval_chain
from langchain.chains.combine_documents import create_stuff_documents_chain

@@ -26,10 +28,11 @@
from langchain_groq import ChatGroq
from pydantic import Field

# Configure Streamlit page settings
st.set_page_config(page_title="RAG Q&A Conversacional", layout="wide", initial_sidebar_state="expanded", page_icon="🤖", menu_items=None)

# Apply dark theme using custom CSS
# This section includes CSS to style various Streamlit components for a dark theme
st.markdown("""
<style>
/* Global style */

@@ -135,7 +138,7 @@
</style>
""", unsafe_allow_html=True)

# Sidebar with guidelines
st.sidebar.markdown("<h2 class='orange-title'>Orientações</h2>", unsafe_allow_html=True)
st.sidebar.markdown("""
* Se encontrar erros de processamento, reinicie com F5.

@@ -162,14 +165,15 @@
Este aplicativo foi desenvolvido por Reinaldo Chaves. Para mais informações, contribuições e feedback, visite o [repositório do projeto no GitHub](https://github.com/reichaves/entrevista_url_llama3).
""")

# Main title and description
st.markdown("<h1 class='yellow-title'>Chatbot com modelos opensource - entrevista URLs ✏️</h1>", unsafe_allow_html=True)
st.write("Insira uma URL e converse com o conteúdo dela - aqui é usado o modelo de LLM llama-3.2-90b-text-preview e a plataforma de embeddings é all-MiniLM-L6-v2")

# Request API keys from the user
groq_api_key = st.text_input("Insira sua chave de API Groq (depois pressione Enter):", type="password")
huggingface_api_token = st.text_input("Insira seu token de API HuggingFace (depois pressione Enter):", type="password")

# Custom wrapper for ChatGroq with rate limiting
class RateLimitedChatGroq(BaseChatModel):
    llm: ChatGroq = Field(default_factory=lambda: ChatGroq())

@@ -199,34 +203,37 @@ def _generate(self, messages, stop=None, run_manager=None, **kwargs) -> ChatResult:
    def _llm_type(self):
        return "rate_limited_chat_groq"
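
# Illustrative sketch, not part of this commit: the hunk above elides most of
# the wrapper's body. One common way to implement the rate limiting it names
# is a minimum-interval decorator like the one below; every name and value
# here is an assumption for illustration, not taken from this repository.
import time

def min_interval(seconds: float):
    # Decorator enforcing a minimum delay between successive calls.
    def wrap(fn):
        last = [0.0]  # time of the previous call, shared via closure
        def inner(*args, **kwargs):
            wait = seconds - (time.monotonic() - last[0])
            if wait > 0:
                time.sleep(wait)
            last[0] = time.monotonic()
            return fn(*args, **kwargs)
        return inner
    return wrap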

# Main application logic
if groq_api_key and huggingface_api_token:
    # Set API tokens as environment variables
    os.environ["HUGGINGFACEHUB_API_TOKEN"] = huggingface_api_token
    os.environ["GROQ_API_KEY"] = groq_api_key

    # Initialize language model and embeddings
    rate_limited_llm = RateLimitedChatGroq(groq_api_key=groq_api_key, model_name="llama-3.2-90b-text-preview", temperature=0)
    embeddings = HuggingFaceEmbeddings(model_name="all-MiniLM-L6-v2")
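
    # Illustrative sketch, not part of this commit: all-MiniLM-L6-v2 maps each
    # text to a 384-dimensional vector, and retrieval compares those vectors.
    # A hypothetical helper (name and probe string are assumptions) showing
    # the shape of the embed_query API:
    def embedding_dim() -> int:
        # HuggingFaceEmbeddings.embed_query returns one vector per text;
        # for all-MiniLM-L6-v2 that vector has length 384.
        return len(embeddings.embed_query("dimension probe"))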

    # Create a session ID for chat history management
    session_id = st.text_input("Session ID", value="default_session")

    # Initialize session state for storing chat history
    if 'store' not in st.session_state:
        st.session_state.store = {}

    # Get URL input from user
    url = st.text_input("Insira a URL para análise:")

    if url:
        try:
            # Fetch and process the webpage content
            response = requests.get(url)
            response.raise_for_status()
            soup = BeautifulSoup(response.text, 'html.parser')

            # Extract text from the webpage
            text = soup.get_text(separator='\n', strip=True)

            # Limit the text to a maximum number of characters
            max_chars = 50000
            if len(text) > max_chars:
                text = text[:max_chars]

@@ -235,16 +242,19 @@ def _llm_type(self):
            # Create a Document object
            document = Document(page_content=text, metadata={"source": url})

            # Split the document into smaller chunks
            text_splitter = RecursiveCharacterTextSplitter(chunk_size=5000, chunk_overlap=500)
            splits = text_splitter.split_documents([document])
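
            # Illustrative sketch, not part of this commit: with chunk_size=5000
            # and chunk_overlap=500, consecutive chunks share ~500 characters, so
            # sentences straddling a boundary are not lost. A simplified
            # character-based version of that sliding window (LangChain's real
            # splitter also respects separators such as paragraphs):
            def naive_overlap_split(s: str, size: int = 5000, overlap: int = 500) -> list:
                # Advance by (size - overlap) so each chunk repeats the tail
                # of the previous one.
                step = size - overlap
                return [s[i:i + size] for i in range(0, max(len(s) - overlap, 1), step)]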

            # Create FAISS vector store for efficient similarity search
            vectorstore = FAISS.from_documents(splits, embeddings)

            st.success(f"Processado {len(splits)} pedaços de documentos (chunks) da URL.")

            # Set up the retriever
            retriever = vectorstore.as_retriever()
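
            # Illustrative sketch, not part of this commit: under the hood the
            # retriever does nearest-neighbour search over the chunk vectors.
            # A minimal cosine-similarity version of that idea (FAISS uses
            # optimised index structures rather than this brute force):
            import numpy as np

            def top_k_by_cosine(query_vec, doc_vecs, k: int = 4):
                # Cosine similarity between the query and every chunk vector.
                q = np.asarray(query_vec)
                d = np.asarray(doc_vecs)
                sims = d @ q / (np.linalg.norm(d, axis=1) * np.linalg.norm(q) + 1e-9)
                return np.argsort(-sims)[:k]  # indices of the k best chunks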

            # Define the system prompt for contextualizing questions
            contextualize_q_system_prompt = (
                "Given a chat history and the latest user question "
                "which might reference context in the chat history, "

@@ -258,8 +268,10 @@ def _llm_type(self):
("human", "{input}"), | ||
]) | ||
|
||
# Create a history-aware retriever | ||
history_aware_retriever = create_history_aware_retriever(rate_limited_llm, retriever, contextualize_q_prompt) | ||
|
||
# Define the main system prompt for the chatbot | ||
system_prompt = ( | ||
"Você é um assistente especializado em analisar conteúdo de páginas web. " | ||
"Sempre coloque no final das respostas: 'Todas as informações devem ser checadas com a(s) fonte(s) original(ais)'" | ||
|

@@ -283,27 +295,32 @@ def _llm_type(self):
"Sempre termine as respostas com: 'Todas as informações precisam ser checadas com as fontes das informações'." | ||
) | ||
|
||
# Create the question-answering prompt template | ||
qa_prompt = ChatPromptTemplate.from_messages([ | ||
("system", system_prompt), | ||
MessagesPlaceholder("chat_history"), | ||
("human", "{input}"), | ||
]) | ||
|
||
# Create the question-answering chain | ||
question_answer_chain = create_stuff_documents_chain(rate_limited_llm, qa_prompt) | ||
rag_chain = create_retrieval_chain(history_aware_retriever, question_answer_chain) | ||
|
||
# Function to get or create session history | ||
def get_session_history(session: str) -> BaseChatMessageHistory: | ||
if session not in st.session_state.store: | ||
st.session_state.store[session] = ChatMessageHistory() | ||
return st.session_state.store[session] | ||
|
||
# Create a conversational RAG chain with message history | ||
conversational_rag_chain = RunnableWithMessageHistory( | ||
rag_chain, get_session_history, | ||
input_messages_key="input", | ||
history_messages_key="chat_history", | ||
output_messages_key="answer" | ||
) | ||
|
||
# Get user input and process the question | ||
user_input = st.text_input("Sua pergunta:") | ||
if user_input: | ||
with st.spinner("Processando sua pergunta..."): | ||
|

@@ -313,10 +330,12 @@ def get_session_history(session: str) -> BaseChatMessageHistory:
config={"configurable": {"session_id": session_id}}, | ||
) | ||
st.write("Assistente:", response['answer']) | ||
|
||
|
||
# Display chat history | ||
with st.expander("Ver histórico do chat"): | ||
for message in session_history.messages: | ||
st.write(f"**{message.type}:** {message.content}") | ||
# Error handling | ||
except requests.RequestException as e: | ||
st.error(f"Erro ao acessar a URL: {str(e)}") | ||
except Exception as e: | ||
|