Commit

Feat/chatbot core (#10)
Add the core chatbot components

* Update the chatbot packages
* Update the db module
* Update env vars
* Update the Python version and packages
* Fix some params, add the embedding wrapper, and update the LLM module
* Update the docker compose file
* Add a note regarding the embedding model
ranjan-stha authored and sudan45 committed Sep 13, 2024
1 parent ef04ede commit 7b4a29f
Showing 11 changed files with 2,571 additions and 14 deletions.
1 change: 1 addition & 0 deletions Dockerfile
@@ -21,6 +21,7 @@ RUN apt-get update -y \
    && poetry config virtualenvs.create false \
    && poetry install --no-root \
    # Clean-up
    && rm -rf /root/.cache/pypoetry \
    && pip uninstall -y poetry virtualenv-clone virtualenv \
    && apt-get remove -y gcc libc-dev libproj-dev \
    && apt-get autoremove -y \
51 changes: 51 additions & 0 deletions chatbot-core/custom_embeddings.py
@@ -0,0 +1,51 @@
from dataclasses import dataclass
from typing import List, Optional

import requests
from langchain.embeddings.base import Embeddings
from utils import EmbeddingModelType


@dataclass
class CustomEmbeddingsWrapper(Embeddings):
    """
    Embeddings model wrapper that forwards requests to the actual
    embedding model service
    """

    url: str
    model_name: str
    model_type: EmbeddingModelType
    base_url: Optional[str]

    def __post_init__(self):
        if self.model_type in [EmbeddingModelType.SENTENCE_TRANSFORMERS, EmbeddingModelType.OPENAI]:
            if not (self.url and self.model_name):
                raise Exception("url or model_name or both are not provided.")
        elif self.model_type == EmbeddingModelType.OLLAMA:
            if not (self.url and self.model_name and self.base_url):
                raise Exception("url, model_name, or base_url is not provided.")

    def embed_query(self, text: str, timeout: int = 30) -> List[float]:
        """
        Sends a request to the embedding module to
        embed the query into its vector representation
        """
        # Enum members are not JSON serializable, so send the underlying value.
        payload = {"type_model": self.model_type.value, "name_model": self.model_name, "texts": text}
        try:
            response = requests.post(url=self.url, json=payload, timeout=timeout)
        except requests.Timeout as e:
            raise Exception(f"Request to the embedding service timed out: {e}")
        return response.json()

    def embed_documents(self, texts: List[str], timeout: int = 30) -> List[List[float]]:
        """
        Sends a request to the embedding module to
        embed multiple texts into their vector representations
        """
        payload = {"type_model": self.model_type.value, "name_model": self.model_name, "texts": texts}
        try:
            response = requests.post(url=self.url, json=payload, timeout=timeout)
        except requests.Timeout as e:
            raise Exception(f"Request to the embedding service timed out: {e}")
        return response.json()
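
A minimal usage sketch for the wrapper above, assuming the embedding service is reachable at a local endpoint; the URL and model name below are hypothetical placeholders:

# Usage sketch (hypothetical service URL and model name):
from custom_embeddings import CustomEmbeddingsWrapper
from utils import EmbeddingModelType

embeddings = CustomEmbeddingsWrapper(
    url="http://localhost:8000/embed",  # hypothetical embedding-service endpoint
    model_name="all-MiniLM-L6-v2",  # hypothetical model name
    model_type=EmbeddingModelType.SENTENCE_TRANSFORMERS,
    base_url=None,  # only required for Ollama-backed models
)
query_vector = embeddings.embed_query("What are the office hours?")
doc_vectors = embeddings.embed_documents(["First passage.", "Second passage."])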
67 changes: 67 additions & 0 deletions chatbot-core/database.py
@@ -0,0 +1,67 @@
import logging
from dataclasses import dataclass, field
from uuid import uuid4

from django.conf import settings
from qdrant_client import QdrantClient
from qdrant_client.http.exceptions import UnexpectedResponse
from qdrant_client.http.models import Distance, VectorParams

logger = logging.getLogger(__name__)


@dataclass
class QdrantDatabase:
    """Qdrant Vector Database"""

    collection_name: str
    host: str
    port: int
    db_client: QdrantClient = field(init=False)

    def __post_init__(self):
        """Initialize database client"""
        self.db_client = QdrantClient(host=self.host, port=self.port)

    def _collection_exists(self, collection_name: str) -> bool:
        """Check whether the collection already exists in the db"""
        try:
            self.db_client.get_collection(collection_name=collection_name)
            return True
        except UnexpectedResponse:
            return False

    def set_collection(self):
        """Create the database collection"""
        if not self._collection_exists(self.collection_name):
            self.db_client.create_collection(
                collection_name=self.collection_name,
                vectors_config=VectorParams(size=settings.EMBEDDING_MODEL_VECTOR_SIZE, distance=Distance.COSINE),
            )
        else:
            logger.info(f"Collection {self.collection_name} already exists. Using the existing one.")

    def store_data(self, data: list):
        """Stores data in the vector db and returns the upsert response"""
        point_vectors = [
            {"id": str(uuid4()), "vector": v_representation, "payload": metadata} for v_representation, metadata in data
        ]

        response = self.db_client.upsert(collection_name=self.collection_name, points=point_vectors)
        return response

    def data_search(
        self, collection_names: list, query_vector: list, top_n_retrieval: int = 5, score_threshold: float = 0.7
    ):
        """Search data in the database"""
        results = [
            self.db_client.search(
                collection_name=collection_name,
                query_vector=query_vector,
                limit=top_n_retrieval,
                score_threshold=score_threshold,
            )
            for collection_name in collection_names
        ]
        # Note: the results contain a score key; sort the results by score and take the top 5 among them.
        return results
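
A short sketch of wiring up the class above against a local Qdrant instance; the host, port, collection name, and vector contents are illustrative, and in practice the vectors come from the embedding service:

# Usage sketch (illustrative host/port/collection; the vector size must match
# settings.EMBEDDING_MODEL_VECTOR_SIZE):
db = QdrantDatabase(collection_name="documents", host="localhost", port=6333)
db.set_collection()

vector = [0.1] * 384  # placeholder embedding vector
db.store_data([(vector, {"source": "faq.txt"})])

hits = db.data_search(collection_names=["documents"], query_vector=vector)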
61 changes: 61 additions & 0 deletions chatbot-core/doc_loaders.py
@@ -0,0 +1,61 @@
from dataclasses import dataclass
from typing import List

from langchain.schema import Document
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain_community.document_loaders import WebBaseLoader


@dataclass(kw_only=True)
class DocumentLoader:
    """
    Base Class for Document Loaders
    """

    chunk_size: int = 200
    chunk_overlap: int = 20

    def _get_split_documents(self, documents: List[Document]):
        """
        Splits documents into multiple chunks
        """
        splitter = RecursiveCharacterTextSplitter(
            chunk_size=self.chunk_size, chunk_overlap=self.chunk_overlap, length_function=len
        )

        return splitter.split_documents(documents)


@dataclass
class LoaderFromText(DocumentLoader):
    """
    Document loader for plain texts
    """

    text: str

    def create_document_chunks(self):
        """
        Creates multiple document chunks from the input text
        """
        documents = [Document(page_content=self.text)]
        doc_chunks = self._get_split_documents(documents=documents)
        return doc_chunks


@dataclass
class LoaderFromWeb(DocumentLoader):
    """
    Document loader for a web url
    """

    url: str

    def create_document_chunks(self):
        """
        Creates multiple document chunks from the input url
        """
        loader = WebBaseLoader(web_path=self.url)
        docs = loader.load()
        doc_chunks = self._get_split_documents(documents=docs)
        return doc_chunks
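
A brief sketch of both loaders; the input text, URL, and chunking parameters are placeholders:

# Usage sketch (placeholder text and URL):
text_chunks = LoaderFromText(text="A long plain-text document ...").create_document_chunks()
web_chunks = LoaderFromWeb(url="https://example.com/handbook", chunk_size=300, chunk_overlap=30).create_document_chunks()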
162 changes: 162 additions & 0 deletions chatbot-core/llm.py
@@ -0,0 +1,162 @@
import logging
from dataclasses import dataclass, field
from typing import Any, Optional

from custom_embeddings import CustomEmbeddingsWrapper
from django.conf import settings
from langchain.chains.combine_documents import create_stuff_documents_chain
from langchain.chains.history_aware_retriever import create_history_aware_retriever
from langchain.chains.retrieval import create_retrieval_chain
from langchain.memory import ConversationBufferWindowMemory
from langchain.schema import AIMessage, HumanMessage
from langchain_community.llms.ollama import Ollama
from langchain_core.prompts import ChatPromptTemplate, MessagesPlaceholder
from langchain_openai import ChatOpenAI
from langchain_qdrant import QdrantVectorStore
from qdrant_client import QdrantClient

logger = logging.getLogger(__name__)


@dataclass
class LLMBase:
    """LLM Base containing common methods"""

    qdrant_client: QdrantClient = field(init=False)
    llm_model: Any = field(init=False)
    memory: Any = field(init=False)
    embedding_model: CustomEmbeddingsWrapper = field(init=False)
    rag_chain: Optional[Any] = None

    def __post_init__(self, mem_key: str = "chat_history", conversation_max_window: int = 3):
        self.llm_model = None
        self.qdrant_client = None
        self.memory = None

        try:
            self.qdrant_client = QdrantClient(host=settings.QDRANT_DB_HOST, port=settings.QDRANT_DB_PORT)
        except Exception as e:
            raise Exception(f"Qdrant client is not properly set up. {str(e)}")
        self.memory = ConversationBufferWindowMemory(k=conversation_max_window, memory_key=mem_key, return_messages=True)

        self.embedding_model = CustomEmbeddingsWrapper(
            url=settings.EMBEDDING_MODEL_URL,
            model_name=settings.EMBEDDING_MODEL_NAME,
            model_type=settings.EMBEDDING_MODEL_TYPE,
            base_url=settings.OLLAMA_EMBEDDING_MODEL_BASE_URL,
        )

    def _system_prompt_for_retrieval(self):
        """System prompt for information retrieval"""
        return (
            "Given a chat history and the latest user question "
            "which might reference context in the chat history, "
            "formulate a standalone question which can be understood "
            "without the chat history. Do NOT answer the question, "
            "just reformulate it if needed and otherwise return it as is."
        )

    def _system_prompt_for_response(self):
        """
        System prompt for response generation
        """
        system_prompt = (
            "You are an assistant for question-answering tasks.\n"
            "Use the following retrieved context to answer the question.\n"
            "If you don't get the answer from the provided context,\n"
            "say 'I don't know. How can I help with the office related queries?'\n\n"
            "Context: {context}"
        )
        return system_prompt

    def get_prompt_template_for_retrieval(self):
        """Get the prompt template"""
        system_prompt = self._system_prompt_for_retrieval()
        context_prompt_template = ChatPromptTemplate.from_messages(
            [("system", system_prompt), MessagesPlaceholder(variable_name="chat_history"), ("human", "{input}")]
        )
        return context_prompt_template

    def get_prompt_template_for_response(self):
        """Get the prompt template for response generation"""
        system_prompt = self._system_prompt_for_response()
        llm_response_prompt = ChatPromptTemplate.from_messages(
            [("system", system_prompt), MessagesPlaceholder("chat_history"), ("human", "{input}")]
        )
        return llm_response_prompt

    def get_db_retriever(self, collection_name: str, top_k_items: int = 5, score_threshold: float = 0.5):
        """Get the database retriever"""
        db_retriever = QdrantVectorStore(
            client=self.qdrant_client, collection_name=collection_name, embedding=self.embedding_model
        )
        retriever = db_retriever.as_retriever(
            search_type="similarity_score_threshold", search_kwargs={"k": top_k_items, "score_threshold": score_threshold}
        )
        return retriever

    def create_chain(self, db_collection_name: str):
        """Creates a RAG chain"""
        if not self.llm_model:
            raise Exception("The LLM model is not loaded.")

        context_prompt_template = self.get_prompt_template_for_retrieval()
        response_prompt_template = self.get_prompt_template_for_response()

        retriever = self.get_db_retriever(collection_name=db_collection_name)

        history_aware_retriever = create_history_aware_retriever(self.llm_model, retriever, context_prompt_template)

        chat_response_chain = create_stuff_documents_chain(self.llm_model, response_prompt_template)

        rag_chain = create_retrieval_chain(history_aware_retriever, chat_response_chain)
        return rag_chain

    def execute_chain(self, query: str, db_collection_name: str = settings.QDRANT_DB_COLLECTION_NAME):
        """
        Executes the chain
        """
        if not self.rag_chain:
            self.rag_chain = self.create_chain(db_collection_name=db_collection_name)

        response = self.rag_chain.invoke({"input": query, "chat_history": self.get_message_history()["chat_history"]})
        # Guard against a missing "answer" key before recording the exchange.
        answer = response.get("answer", "")
        self.memory.chat_memory.add_message(HumanMessage(content=query))
        self.memory.chat_memory.add_message(AIMessage(content=answer))
        return answer

    def get_message_history(self):
        """
        Returns the historical conversational data
        """
        return self.memory.load_memory_variables({})


@dataclass
class OpenAIHandler(LLMBase):
    """LLM handler using OpenAI for RAG"""

    temperature: float = 0.2

    def __post_init__(self):
        super().__post_init__()
        try:
            self.llm_model = ChatOpenAI(model=settings.LLM_MODEL_NAME, temperature=self.temperature)
        except Exception as e:
            raise Exception(f"OpenAI LLM model was not successfully loaded. {str(e)}")


@dataclass
class OllamaHandler(LLMBase):
    """LLM Handler using Ollama for RAG"""

    temperature: float = 0.2

    def __post_init__(self):
        super().__post_init__()
        try:
            self.llm_model = Ollama(
                model=settings.LLM_MODEL_NAME, base_url=settings.LLM_OLLAMA_BASE_URL, temperature=self.temperature
            )
        except Exception as e:
            raise Exception(f"Ollama LLM model was not successfully loaded. {str(e)}")
27 changes: 27 additions & 0 deletions chatbot-core/utils.py
@@ -0,0 +1,27 @@
import re
from enum import Enum


class EmbeddingModelType(Enum):
    """Embedding Model Types"""

    SENTENCE_TRANSFORMERS = 1
    OLLAMA = 2
    OPENAI = 3


class LLMType(Enum):
    """LLM Types"""

    OLLAMA = 1
    OPENAI = 2


def preprocess_text(texts: list[str]) -> list[str]:
    """
    Preprocess the texts by collapsing whitespace runs
    """
    pattern = r"\s+"

    # Replace runs of whitespace with a single space (substituting with "" would
    # glue adjacent words together) and trim the ends.
    results = [re.sub(pattern, " ", text).strip() for text in texts]
    return results
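
For example, whitespace runs collapse to single spaces:

# Example: runs of whitespace are collapsed before the texts are embedded.
preprocess_text(["hello   world\n", "  tabs\tand\nnewlines  "])
# -> ["hello world", "tabs and newlines"]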
