Feat/chatbot core #10

Merged · 20 commits · Sep 13, 2024
1 change: 1 addition & 0 deletions Dockerfile
@@ -21,6 +21,7 @@ RUN apt-get update -y \
&& poetry config virtualenvs.create false \
&& poetry install --no-root \
# Clean-up
&& rm -rf /root/.cache/pypoetry \
&& pip uninstall -y poetry virtualenv-clone virtualenv \
&& apt-get remove -y gcc libc-dev libproj-dev \
&& apt-get autoremove -y \
51 changes: 51 additions & 0 deletions chatbot-core/custom_embeddings.py
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import List, Optional

import requests
from langchain.embeddings.base import Embeddings
from utils import EmbeddingModelType


@dataclass
class CustomEmbeddingsWrapper(Embeddings):
"""
This is an Embeddings model wrapper to send request to the actual
embedding models
"""

url: str
model_name: str
model_type: EmbeddingModelType
    base_url: Optional[str] = None

def __post_init__(self):
        if self.model_type in [EmbeddingModelType.SENTENCE_TRANSFORMERS, EmbeddingModelType.OPENAI]:
            if not (self.url and self.model_name):
                raise Exception("Url or model name or both are not provided.")
        elif self.model_type == EmbeddingModelType.OLLAMA:
            if not (self.url and self.model_name and self.base_url):
                raise Exception("Url, model name or base_url is not provided.")

def embed_query(self, text: str, timeout: int = 30) -> List[float]:
"""
Sends the request to Embedding module to
embed the query to the vector representation
"""
payload = {"type_model": self.model_type, "name_model": self.model_name, "texts": text}
try:
response = requests.post(url=self.url, json=payload, timeout=timeout)
except requests.Timeout as e:
raise Exception(e)
return response.json()

def embed_documents(self, texts: List[str], timeout: int = 30) -> List[List[float]]:
"""
Sends the request to Embedding module to
embed multiple queries to the vector representation
"""
payload = {"type_model": self.model_type, "name_model": self.model_name, "texts": texts}
try:
response = requests.post(url=self.url, json=payload, timeout=timeout)
except requests.Timeout as e:
raise Exception(e)
return response.json()
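
A minimal usage sketch of the wrapper for review; the endpoint URL and model name below are placeholders, not values from this PR:

```python
from custom_embeddings import CustomEmbeddingsWrapper
from utils import EmbeddingModelType

# Hypothetical embedding service endpoint and model name
embeddings = CustomEmbeddingsWrapper(
    url="http://localhost:8001/embed",
    model_name="all-MiniLM-L6-v2",
    model_type=EmbeddingModelType.SENTENCE_TRANSFORMERS,
)

query_vector = embeddings.embed_query("Where is the office located?")
doc_vectors = embeddings.embed_documents(["First document", "Second document"])
```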
67 changes: 67 additions & 0 deletions chatbot-core/database.py
@@ -0,0 +1,67 @@
import logging
from dataclasses import dataclass, field
from uuid import uuid4

from django.conf import settings
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Distance, VectorParams

logger = logging.getLogger(__name__)


@dataclass
class QdrantDatabase:
"""Qdrant Vector Database"""

collection_name: str
host: str
port: int
db_client: QdrantClient = field(init=False)

def __post_init__(self):
"""Initialize database client"""
self.db_client = QdrantClient(host=self.host, port=self.port)

def _collection_exists(self, collection_name: str) -> bool:
"""Check if the collection in db already exists"""
try:
self.db_client.get_collection(collection_name=collection_name)
return True
except UnexpectedResponse:
return False

def set_collection(self):
"""Create the database collection"""
if not self._collection_exists(self.collection_name):
self.db_client.create_collection(
collection_name=self.collection_name,
vectors_config=VectorParams(size=settings.EMBEDDING_MODEL_VECTOR_SIZE, distance=Distance.COSINE),
)
else:
logger.info(f"Collection {self.collection_name} already exists. Using the existing one.")

    def store_data(self, data: list):
        """Stores (vector, metadata) pairs in the vector db and returns the upsert response"""
        point_vectors = [
            {"id": str(uuid4()), "vector": v_representation, "payload": metadata} for v_representation, metadata in data
        ]

        response = self.db_client.upsert(collection_name=self.collection_name, points=point_vectors)
        return response

    def data_search(
        self, collection_names: list, query_vector: list, top_n_retrieval: int = 5, score_threshold: float = 0.7
    ):
        """Search data across the given collections"""
        results = [
            self.db_client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=top_n_retrieval,
                score_threshold=score_threshold,
            )
            for collection_name in collection_names
        ]
        # Note: each hit carries a score; sort the combined hits by score and keep the top n
        return results
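
A usage sketch of `QdrantDatabase`, assuming Django settings are configured and the embedding model produces 384-dimensional vectors; the host, port, collection name and payload are illustrative:

```python
from database import QdrantDatabase

db = QdrantDatabase(collection_name="office-docs", host="localhost", port=6333)
db.set_collection()  # creates the collection only if it does not already exist

# Each entry pairs an embedding vector with its metadata payload
data = [([0.1] * 384, {"source": "handbook", "text": "The office opens at 9am."})]
db.store_data(data)

hits = db.data_search(collection_names=["office-docs"], query_vector=[0.1] * 384)
```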
61 changes: 61 additions & 0 deletions chatbot-core/doc_loaders.py
@@ -0,0 +1,61 @@
from dataclasses import dataclass
from typing import List

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader


@dataclass(kw_only=True)
class DocumentLoader:
"""
Base Class for Document Loaders
"""

chunk_size: int = 200
chunk_overlap: int = 20

def _get_split_documents(self, documents: List[Document]):
"""
Splits documents into multiple chunks
"""
splitter = RecursiveCharacterTextSplitter(
chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len
)

return splitter.split_documents(documents)


@dataclass
class LoaderFromText(DocumentLoader):
"""
    Document loader for plain text
"""

text: str

def create_document_chunks(self):
"""
        Creates multiple document chunks from the input text
"""
documents = [Document(page_content=self.text)]
doc_chunks = self._get_split_documents(documents=documents)
return doc_chunks


@dataclass
class LoaderFromWeb(DocumentLoader):
"""
    Document loader for a web URL
"""

url: str

def create_document_chunks(self):
"""
        Creates multiple document chunks from the input URL
"""
loader = WebBaseLoader(web_path=self.url)
docs = loader.load()
doc_chunks = self._get_split_documents(documents=docs)
return doc_chunks
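
For review, a sketch of how the two loaders might be driven; the URL is a placeholder:

```python
from doc_loaders import LoaderFromText, LoaderFromWeb

# Defaults from DocumentLoader: chunk_size=200, chunk_overlap=20
text_chunks = LoaderFromText(text="A long policy document ...").create_document_chunks()

# Larger chunks for web pages; the URL stands in for whatever page gets indexed
web_chunks = LoaderFromWeb(url="https://example.com/handbook", chunk_size=500).create_document_chunks()
```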
162 changes: 162 additions & 0 deletions chatbot-core/llm.py
@@ -0,0 +1,162 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Optional

from custom_embeddings import CustomEmbeddingsWrapper
from django.conf import settings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import AIMessage, HumanMessage
from langchain_community.llms.ollama import Ollama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

logger = logging.getLogger(__name__)


@dataclass
class LLMBase:
"""LLM Base containing common methods"""

qdrant_client: QdrantClient = field(init=False)
llm_model: Any = field(init=False)
memory: Any = field(init=False)
embedding_model: CustomEmbeddingsWrapper = field(init=False)
rag_chain: Optional[Any] = None

def __post_init__(self, mem_key: str = "chat_history", conversation_max_window: int = 3):
self.llm_model = None
self.qdrant_client = None
self.memory = None

try:
self.qdrant_client = QdrantClient(host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT)
except Exception as e:
raise Exception(f"Qdrant client is not properly setup. {str(e)}")
self.memory = ConversationBufferWindowMemory(k=conversation_max_window, memory_key=mem_key, return_messages=True)

self.embedding_model = CustomEmbeddingsWrapper(
url=settings.EMBEDDING_MODEL_URL,
model_name=settings.EMBEDDING_MODEL_NAME,
model_type=settings.EMBEDDING_MODEL_TYPE,
base_url=settings.OLLAMA_EMBEDDING_MODEL_BASE_URL,
)

def _system_prompt_for_retrieval(self):
"""System prompt for information retrieval"""
return (
"Given a chat history and the latest user question "
"which might reference context in the chat history, "
"formulate a standalone question which can be understood "
"without the chat history. Do NOT answer the question, "
"just reformulate it if needed and otherwise return it as is."
)

def _system_prompt_for_response(self):
"""
System prompt for response generation
"""
system_prompt = """
You are an assistant for question-answering tasks.\n,
Use the following retrieved context to answer the question.\n,
If you don't get the answer from the provided context, \n,
say that 'I don't know. How can I help with the office related queries ?'
\n\n,
Context: {context}
"""
return system_prompt

def get_prompt_template_for_retrieval(self):
"""Get the prompt template"""
system_prompt = self._system_prompt_for_retrieval()
context_prompt_template = ChatPromptTemplate.from_messages(
[("system", system_prompt), MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}")]
)
return context_prompt_template

def get_prompt_template_for_response(self):
"""Get the prompt template for response generation"""
system_prompt = self._system_prompt_for_response()
llm_response_prompt = ChatPromptTemplate.from_messages(
[("system", system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}")]
)
return llm_response_prompt

def get_db_retriever(self, collection_name: str, top_k_items: int = 5, score_threshold: float = 0.5):
"""Get the database retriever"""
db_retriever = QdrantVectorStore(
client=self.qdrant_client, collection_name=collection_name, embedding=self.embedding_model
)
retriever = db_retriever.as_retriever(
search_type="similarity_score_threshold", search_kwargs={"k": top_k_items, "score_threshold": score_threshold}
)
return retriever

def create_chain(self, db_collection_name: str):
"""Creates a llm chain"""
if not self.llm_model:
raise Exception("The LLM model is not loaded.")

context_prompt_template = self.get_prompt_template_for_retrieval()
response_prompt_template = self.get_prompt_template_for_response()

retriever = self.get_db_retriever(collection_name=db_collection_name)

history_aware_retriever = create_history_aware_retriever(self.llm_model, retriever, context_prompt_template)

chat_response_chain = create_stuff_documents_chain(self.llm_model, response_prompt_template)

rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
return rag_chain

def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
"""
Executes the chain
"""
        if not self.rag_chain:
            self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

        response = self.rag_chain.invoke({"input": query, "chat_history": self.get_message_history()["chat_history"]})
        # Look up the answer before writing to memory so a missing key cannot
        # raise after the human message has already been recorded
        answer = response.get("answer", "")
        self.memory.chat_memory.add_message(HumanMessage(content=query))
        self.memory.chat_memory.add_message(AIMessage(content=answer))
        return answer

def get_message_history(self):
"""
Returns the historical conversational data
"""
return self.memory.load_memory_variables({})


@dataclass
class OpenAIHandler(LLMBase):
"""LLM handler using OpenAI for RAG"""

temperature: float = 0.2

def __post_init__(self):
super().__post_init__()
try:
self.llm_model = ChatOpenAI(model=settings.LLM_MODEL_NAME, temperature=self.temperature)
except Exception as e:
raise Exception(f"OpenAI LLM model is not successfully loaded. {str(e)}")


@dataclass
class OllamaHandler(LLMBase):
"""LLM Handler using Ollama for RAG"""

temperature: float = 0.2

def __post_init__(self):
super().__post_init__()
try:
self.llm_model = Ollama(
model=settings.LLM_MODEL_NAME, base_url=settings.LLM_OLLAMA_BASE_URL, temperature=self.temperature
)
except Exception as e:
raise Exception(f"Ollama LLM model is not successfully loaded. {str(e)}")
27 changes: 27 additions & 0 deletions chatbot-core/utils.py
@@ -0,0 +1,27 @@
import re
from enum import Enum


class EmbeddingModelType(Enum):
"""Embedding Model Types"""

    SENTENCE_TRANSFORMERS = 1
OLLAMA = 2
OPENAI = 3


class LLMType(Enum):
"""LLM Types"""

OLLAMA = 1
OPENAI = 2


def preprocess_text(texts: list[str]) -> list[str]:
    """
    Preprocesses texts by collapsing runs of whitespace
    """
    pattern = r"\s+"

    # Replace whitespace runs with a single space (an empty replacement would
    # glue adjacent words together), then trim the ends
    results = [re.sub(pattern, " ", text).strip() for text in texts]
    return results
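
A quick check of the helper's behaviour after the whitespace fix:

```python
from utils import preprocess_text

# Runs of spaces, tabs and newlines collapse into single spaces
print(preprocess_text(["hello   world\n", "a\tb"]))  # ['hello world', 'a b']
```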