-
Notifications
You must be signed in to change notification settings - Fork 0
/
RAG_model_V1.py
79 lines (63 loc) · 2.35 KB
/
RAG_model_V1.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import os
from langchain_community.document_loaders import DirectoryLoader, TextLoader
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_openai import OpenAIEmbeddings, ChatOpenAI
from langchain_community.vectorstores import Chroma
from langchain.chains import RetrievalQA
from tqdm import tqdm
# --- Configuration -----------------------------------------------------------
# Security fix: never hard-code API keys in source (the original embedded a
# placeholder string, which both leaks secrets via version control and
# guarantees authentication failures). Require the key from the environment
# and fail fast with an actionable message.
if not os.environ.get("OPENAI_API_KEY"):
    raise RuntimeError(
        "OPENAI_API_KEY is not set. Export it before running, e.g. "
        "export OPENAI_API_KEY=sk-..."
    )

# Initialize components
# Recursively load every .txt file under ./datasets.
loader = DirectoryLoader('./datasets', glob="**/*.txt", loader_cls=TextLoader)
# ~1000-char chunks with a 20-char overlap so context isn't lost at boundaries.
text_splitter = RecursiveCharacterTextSplitter(chunk_size=1000, chunk_overlap=20)
embeddings = OpenAIEmbeddings(model="text-embedding-3-small")
def process_documents():
    """Load .txt documents, split them into chunks, and build a persisted
    Chroma vector store under ./chroma_db.

    Returns:
        The populated Chroma vector store, or None if loading failed or
        no documents were found.
    """
    print("Loading and processing documents...")
    try:
        documents = loader.load()
        print(f"Loaded {len(documents)} documents")
    except Exception as e:
        # Broad catch is acceptable at this top-level boundary: report the
        # failure and signal it to the caller via None rather than crashing.
        print(f"Error loading documents: {str(e)}")
        return None
    if not documents:
        # Guard: an empty ./datasets directory would otherwise hand an empty
        # batch to Chroma.from_documents and produce a useless (or failing)
        # vector store.
        print("No documents found in ./datasets. Nothing to index.")
        return None
    print("Splitting documents into chunks...")
    chunks = text_splitter.split_documents(documents)
    print(f"Created {len(chunks)} chunks")
    print("Creating vector store...")
    vectorstore = Chroma.from_documents(
        documents=chunks,
        embedding=embeddings,
        persist_directory="./chroma_db"
    )
    # NOTE(review): newer langchain-chroma versions persist automatically and
    # have removed .persist(); kept here since this file targets the
    # community Chroma wrapper — confirm against the installed version.
    vectorstore.persist()
    print("Vector store created and persisted")
    return vectorstore
def setup_qa_chain(vectorstore):
    """Build a RetrievalQA chain over *vectorstore*.

    Retrieves the top-2 most similar chunks per query, feeds them to a
    deterministic (temperature=0) gpt-3.5-turbo model via the "stuff"
    chain, and includes the source documents in each result.
    """
    return RetrievalQA.from_chain_type(
        llm=ChatOpenAI(model_name="gpt-3.5-turbo", temperature=0),
        chain_type="stuff",
        retriever=vectorstore.as_retriever(search_kwargs={"k": 2}),
        return_source_documents=True
    )
def main():
    """Entry point: build (or rebuild) the index, then run an interactive
    question-answering loop until the user types 'quit'."""
    vectorstore = process_documents()
    if vectorstore is None:
        print("Failed to process documents. Exiting.")
        return
    qa_chain = setup_qa_chain(vectorstore)
    while True:
        try:
            question = input("Enter your question (or 'quit' to exit): ").strip()
        except (EOFError, KeyboardInterrupt):
            # Exit cleanly on Ctrl-D / Ctrl-C instead of dumping a traceback.
            print()
            break
        if question.lower() == 'quit':
            break
        if not question:
            # Skip blank input rather than spending an LLM call on it.
            continue
        result = qa_chain({"query": question})
        print(f"\nQuestion: {question}")
        print(f"Answer: {result['result']}")
        print("\nSources:")
        for doc in result['source_documents']:
            print(f"- {doc.metadata['source']}")
        print()


if __name__ == "__main__":
    main()