-
Notifications
You must be signed in to change notification settings - Fork 0
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add the core chatbot components * Update the chatbot packages * Update in db module * Update Env vars * Update the python version, pkgs * Fix some params and Add the embeding wrapper + Update in LLM module * Update the docker compose file * Note add regarding embedding model
- Loading branch information
1 parent
ef04ede
commit 7b4a29f
Showing
11 changed files
with
2,571 additions
and
14 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,51 @@ | ||
from dataclasses import dataclass | ||
from typing import List, Optional | ||
|
||
import requests | ||
from langchain.embeddings.base import Embeddings | ||
from utils import EmbeddingModelType | ||
|
||
|
||
@dataclass | ||
class CustomEmbeddingsWrapper(Embeddings): | ||
""" | ||
This is an Embeddings model wrapper to send request to the actual | ||
embedding models | ||
""" | ||
|
||
url: str | ||
model_name: str | ||
model_type: EmbeddingModelType | ||
base_url: Optional[str] | ||
|
||
def __post_init__(self): | ||
if self.model_type in [EmbeddingModelType.SENTENCE_TRANSFORMES, EmbeddingModelType.OPENAI]: | ||
if not (self.url and self.model_name): | ||
raise Exception("Url or model name or both are not provided.") | ||
elif self.model_type == EmbeddingModelType.OLLAMA: | ||
if not (self.url and self.model_name and self.base_url): | ||
raise Exception("Url or base_url or both are not provided.") | ||
|
||
def embed_query(self, text: str, timeout: int = 30) -> List[float]: | ||
""" | ||
Sends the request to Embedding module to | ||
embed the query to the vector representation | ||
""" | ||
payload = {"type_model": self.model_type, "name_model": self.model_name, "texts": text} | ||
try: | ||
response = requests.post(url=self.url, json=payload, timeout=timeout) | ||
except requests.Timeout as e: | ||
raise Exception(e) | ||
return response.json() | ||
|
||
def embed_documents(self, texts: List[str], timeout: int = 30) -> List[List[float]]: | ||
""" | ||
Sends the request to Embedding module to | ||
embed multiple queries to the vector representation | ||
""" | ||
payload = {"type_model": self.model_type, "name_model": self.model_name, "texts": texts} | ||
try: | ||
response = requests.post(url=self.url, json=payload, timeout=timeout) | ||
except requests.Timeout as e: | ||
raise Exception(e) | ||
return response.json() |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
import logging | ||
from dataclasses import dataclass, field | ||
from uuid import uuid4 | ||
|
||
from django.conf import settings | ||
from qdrant_client import QdrantClient | ||
from qdrant_client.http.exceptions import UnexpectedResponse | ||
from qdrant_client.http.models import Distance, VectorParams | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class QdrantDatabase: | ||
"""Qdrant Vector Database""" | ||
|
||
collection_name: str | ||
host: str | ||
port: int | ||
db_client: QdrantClient = field(init=False) | ||
|
||
def __post_init__(self): | ||
"""Initialize database client""" | ||
self.db_client = QdrantClient(host=self.host, port=self.port) | ||
|
||
def _collection_exists(self, collection_name: str) -> bool: | ||
"""Check if the collection in db already exists""" | ||
try: | ||
self.db_client.get_collection(collection_name=collection_name) | ||
return True | ||
except UnexpectedResponse: | ||
return False | ||
|
||
def set_collection(self): | ||
"""Create the database collection""" | ||
if not self._collection_exists(self.collection_name): | ||
self.db_client.create_collection( | ||
collection_name=self.collection_name, | ||
vectors_config=VectorParams(size=settings.EMBEDDING_MODEL_VECTOR_SIZE, distance=Distance.COSINE), | ||
) | ||
else: | ||
logger.info(f"Collection {self.collection_name} already exists. Using the existing one.") | ||
|
||
def store_data(self, data: list) -> None: | ||
"""Stores data in vector db""" | ||
point_vectors = [ | ||
{"id": str(uuid4()), "vector": v_representation, "payload": metadata} for v_representation, metadata in data | ||
] | ||
|
||
response = self.db_client.upsert(collection_name=self.collection_name, points=point_vectors) | ||
return response | ||
|
||
def data_search( | ||
self, collection_names: list, query_vector: list, top_n_retrieval: int = 5, score_threshold: float = 0.7 | ||
): | ||
"""Search data in the database""" | ||
results = [ | ||
self.db_client.search( | ||
collection_name=collection_name, | ||
query_vector=query_vector, | ||
top=top_n_retrieval, | ||
score_threshold=score_threshold, | ||
) | ||
for collection_name in collection_names | ||
] | ||
# Note the results shall contain score key; sort the results using score key and get top 5 among them. | ||
return results |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,61 @@ | ||
from dataclasses import dataclass | ||
from typing import List | ||
|
||
from langchain.schema import Document | ||
from langchain.text_splitter import RecursiveCharacterTextSplitter | ||
from langchain_community.document_loaders import WebBaseLoader | ||
|
||
|
||
@dataclass(kw_only=True) | ||
class DocumentLoader: | ||
""" | ||
Base Class for Document Loaders | ||
""" | ||
|
||
chunk_size: int = 200 | ||
chunk_overlap: int = 20 | ||
|
||
def _get_split_documents(self, documents: List[Document]): | ||
""" | ||
Splits documents into multiple chunks | ||
""" | ||
splitter = RecursiveCharacterTextSplitter( | ||
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len | ||
) | ||
|
||
return splitter.split_documents(documents) | ||
|
||
|
||
@dataclass | ||
class LoaderFromText(DocumentLoader): | ||
""" | ||
Document loader for plain texts | ||
""" | ||
|
||
text: str | ||
|
||
def create_document_chunks(self): | ||
""" | ||
Creates multiple documents from the input texts | ||
""" | ||
documents = [Document(page_content=self.text)] | ||
doc_chunks = self._get_split_documents(documents=documents) | ||
return doc_chunks | ||
|
||
|
||
@dataclass | ||
class LoaderFromWeb(DocumentLoader): | ||
""" | ||
Document loader for the web url | ||
""" | ||
|
||
url: str | ||
|
||
def create_document_chunks(self): | ||
""" | ||
Creates multiple documents from the input url | ||
""" | ||
loader = WebBaseLoader(web_path=self.url) | ||
docs = loader.load() | ||
doc_chunks = self._get_split_documents(documents=docs) | ||
return doc_chunks |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,162 @@ | ||
import logging | ||
from dataclasses import dataclass, field | ||
from typing import Any, Optional | ||
|
||
from custom_embeddings import CustomEmbeddingsWrapper | ||
from django.conf import settings | ||
from langchain.chains.combine_documents import create_stuff_documents_chain | ||
from langchain.chains.history_aware_retriever import create_history_aware_retriever | ||
from langchain.chains.retrieval import create_retrieval_chain | ||
from langchain.memory import ConversationBufferWindowMemory | ||
from langchain.schema import AIMessage, HumanMessage | ||
from langchain_community.llms.ollama import Ollama | ||
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder | ||
from langchain_openai import ChatOpenAI | ||
from langchain_qdrant import QdrantVectorStore | ||
from qdrant_client import QdrantClient | ||
|
||
logger = logging.getLogger(__name__) | ||
|
||
|
||
@dataclass | ||
class LLMBase: | ||
"""LLM Base containing common methods""" | ||
|
||
qdrant_client: QdrantClient = field(init=False) | ||
llm_model: Any = field(init=False) | ||
memory: Any = field(init=False) | ||
embedding_model: CustomEmbeddingsWrapper = field(init=False) | ||
rag_chain: Optional[Any] = None | ||
|
||
def __post_init__(self, mem_key: str = "chat_history", conversation_max_window: int = 3): | ||
self.llm_model = None | ||
self.qdrant_client = None | ||
self.memory = None | ||
|
||
try: | ||
self.qdrant_client = QdrantClient(host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT) | ||
except Exception as e: | ||
raise Exception(f"Qdrant client is not properly setup. {str(e)}") | ||
self.memory = ConversationBufferWindowMemory(k=conversation_max_window, memory_key=mem_key, return_messages=True) | ||
|
||
self.embedding_model = CustomEmbeddingsWrapper( | ||
url=settings.EMBEDDING_MODEL_URL, | ||
model_name=settings.EMBEDDING_MODEL_NAME, | ||
model_type=settings.EMBEDDING_MODEL_TYPE, | ||
base_url=settings.OLLAMA_EMBEDDING_MODEL_BASE_URL, | ||
) | ||
|
||
def _system_prompt_for_retrieval(self): | ||
"""System prompt for information retrieval""" | ||
return ( | ||
"Given a chat history and the latest user question " | ||
"which might reference context in the chat history, " | ||
"formulate a standalone question which can be understood " | ||
"without the chat history. Do NOT answer the question, " | ||
"just reformulate it if needed and otherwise return it as is." | ||
) | ||
|
||
def _system_prompt_for_response(self): | ||
""" | ||
System prompt for response generation | ||
""" | ||
system_prompt = """ | ||
You are an assistant for question-answering tasks.\n, | ||
Use the following retrieved context to answer the question.\n, | ||
If you don't get the answer from the provided context, \n, | ||
say that 'I don't know. How can I help with the office related queries ?' | ||
\n\n, | ||
Context: {context} | ||
""" | ||
return system_prompt | ||
|
||
def get_prompt_template_for_retrieval(self): | ||
"""Get the prompt template""" | ||
system_prompt = self._system_prompt_for_retrieval() | ||
context_prompt_template = ChatPromptTemplate.from_messages( | ||
[("system", system_prompt), MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}")] | ||
) | ||
return context_prompt_template | ||
|
||
def get_prompt_template_for_response(self): | ||
"""Get the prompt template for response generation""" | ||
system_prompt = self._system_prompt_for_response() | ||
llm_response_prompt = ChatPromptTemplate.from_messages( | ||
[("system", system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}")] | ||
) | ||
return llm_response_prompt | ||
|
||
def get_db_retriever(self, collection_name: str, top_k_items: int = 5, score_threshold: float = 0.5): | ||
"""Get the database retriever""" | ||
db_retriever = QdrantVectorStore( | ||
client=self.qdrant_client, collection_name=collection_name, embedding=self.embedding_model | ||
) | ||
retriever = db_retriever.as_retriever( | ||
search_type="similarity_score_threshold", search_kwargs={"k": top_k_items, "score_threshold": score_threshold} | ||
) | ||
return retriever | ||
|
||
def create_chain(self, db_collection_name: str): | ||
"""Creates a llm chain""" | ||
if not self.llm_model: | ||
raise Exception("The LLM model is not loaded.") | ||
|
||
context_prompt_template = self.get_prompt_template_for_retrieval() | ||
response_prompt_template = self.get_prompt_template_for_response() | ||
|
||
retriever = self.get_db_retriever(collection_name=db_collection_name) | ||
|
||
history_aware_retriever = create_history_aware_retriever(self.llm_model, retriever, context_prompt_template) | ||
|
||
chat_response_chain = create_stuff_documents_chain(self.llm_model, response_prompt_template) | ||
|
||
rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain) | ||
return rag_chain | ||
|
||
def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME): | ||
""" | ||
Executes the chain | ||
""" | ||
if not self.rag_chain: | ||
self.rag_chain = self.create_chain(db_collection_name=db_collection_name) | ||
|
||
response = self.rag_chain.invoke({"input": query, "chat_history": self.get_message_history()["chat_history"]}) | ||
self.memory.chat_memory.add_message(HumanMessage(content=query)) | ||
self.memory.chat_memory.add_message(AIMessage(content=response["answer"])) | ||
return response["answer"] if "answer" in response else "" | ||
|
||
def get_message_history(self): | ||
""" | ||
Returns the historical conversational data | ||
""" | ||
return self.memory.load_memory_variables({}) | ||
|
||
|
||
@dataclass | ||
class OpenAIHandler(LLMBase): | ||
"""LLM handler using OpenAI for RAG""" | ||
|
||
temperature: float = 0.2 | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
try: | ||
self.llm_model = ChatOpenAI(model=settings.LLM_MODEL_NAME, temperature=self.temperature) | ||
except Exception as e: | ||
raise Exception(f"OpenAI LLM model is not successfully loaded. {str(e)}") | ||
|
||
|
||
@dataclass | ||
class OllamaHandler(LLMBase): | ||
"""LLM Handler using Ollama for RAG""" | ||
|
||
temperature: float = 0.2 | ||
|
||
def __post_init__(self): | ||
super().__post_init__() | ||
try: | ||
self.llm_model = Ollama( | ||
model=settings.LLM_MODEL_NAME, base_url=settings.LLM_OLLAMA_BASE_URL, temperature=self.temperature | ||
) | ||
except Exception as e: | ||
raise Exception(f"Ollama LLM model is not successfully loaded. {str(e)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,27 @@ | ||
import re | ||
from enum import Enum | ||
|
||
|
||
class EmbeddingModelType(Enum): | ||
"""Embedding Model Types""" | ||
|
||
SENTENCE_TRANSFORMES = 1 | ||
OLLAMA = 2 | ||
OPENAI = 3 | ||
|
||
|
||
class LLMType(Enum): | ||
"""LLM Types""" | ||
|
||
OLLAMA = 1 | ||
OPENAI = 2 | ||
|
||
|
||
def preprocess_text(texts: list[str]) -> list[str]: | ||
""" | ||
Preprocessing of the texts | ||
""" | ||
pattern = r"\s+" | ||
|
||
results = [re.sub(pattern, "", text) for text in texts] | ||
return results |
Oops, something went wrong.