
Commit

Refactor vector DB initialization: replace the per-backend Qdrant collection helpers with generic initialize_vector_db / generate_collection functions, fix the Ollama collection-name config key, and add embedding sizes to config/main.yml.
mfmezger committed Jun 1, 2024
1 parent b6e519b commit 0f7d6d2
Showing 5 changed files with 29 additions and 110 deletions.
7 changes: 2 additions & 5 deletions agent/api.py
@@ -32,7 +32,7 @@
create_tmp_folder,
validate_token,
)
from agent.utils.vdb import initialize_aleph_alpha_vector_db, initialize_cohere_vector_db, initialize_gpt4all_vector_db, initialize_open_ai_vector_db, load_vec_db_conn
from agent.utils.vdb import initialize_all_vector_dbs, load_vec_db_conn, generate_collection

nltk.download("punkt")
# add file logger for loguru
@@ -427,10 +427,7 @@ def delete(


# initialize the databases
initialize_open_ai_vector_db()
initialize_aleph_alpha_vector_db()
initialize_gpt4all_vector_db()
initialize_cohere_vector_db()
initialize_all_vector_dbs()

# for debugging useful.
if __name__ == "__main__":
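
For orientation, a hedged sketch of the startup block that agent/api.py keeps after this change; the import and call are taken directly from the diff above, and running it assumes the repository's config/main.yml plus a reachable Qdrant instance.

    # Sketch: a single call replaces the four backend-specific initializers.
    from agent.utils.vdb import initialize_all_vector_dbs

    initialize_all_vector_dbs()  # creates any missing Qdrant collections for all configured backends
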
4 changes: 2 additions & 2 deletions agent/backend/aleph_alpha_service.py
@@ -31,7 +31,7 @@
SearchParams,
)
from agent.utils.utility import extract_text_from_langchain_documents, generate_prompt, load_prompt_template
from agent.utils.vdb import generate_collection_aleph_alpha, init_vdb
from agent.utils.vdb import init_vdb

nltk.download("punkt") # This needs to be installed for the tokenizer to work.
load_dotenv()
@@ -98,7 +98,7 @@ def create_collection(self, name: str) -> bool:
name (str): The name of the new collection.
"""
generate_collection_aleph_alpha(self.vector_db.client, name, self.cfg.aleph_alpha_embeddings.size)
generate_collection(self.vector_db.client, name, self.cfg.aleph_alpha_embeddings.size)
return True

def summarize_text(self, text: str) -> str:
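
The hunk above reduces create_collection to a thin delegation: the backend now passes its configured embedding size to the shared helper instead of calling a backend-specific one. A minimal, repository-independent sketch of that pattern follows; every class name and the size value are stand-ins, not the real implementation.

    class FakeQdrantClient:
        """Stand-in client so the sketch runs without a Qdrant server."""

        def recreate_collection(self, collection_name: str, vectors_config: dict) -> None:
            print(f"created {collection_name} with {vectors_config}")


    def generate_collection(client, collection_name: str, embeddings_size: int) -> None:
        # Signature mirrors agent/utils/vdb.py; vectors_config is simplified to a dict here.
        client.recreate_collection(collection_name=collection_name, vectors_config={"size": embeddings_size, "distance": "cosine"})


    class AlephAlphaBackendSketch:
        def __init__(self, client, embedding_size: int) -> None:
            self.client = client
            self.embedding_size = embedding_size  # in the repository: cfg.aleph_alpha_embeddings.size

        def create_collection(self, name: str) -> bool:
            generate_collection(self.client, name, self.embedding_size)
            return True


    AlephAlphaBackendSketch(FakeQdrantClient(), 512).create_collection("demo")  # 512 is a placeholder size
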
4 changes: 2 additions & 2 deletions agent/backend/ollama_service.py
@@ -39,11 +39,11 @@ def __init__(self, cfg: DictConfig, collection_name: str | None, token: str | No
if collection_name:
self.collection_name = collection_name
else:
self.collection_name = self.cfg.qdrant.collection_name_Ollama
self.collection_name = self.cfg.qdrant.collection_name_ollama

embedding = OllamaEmbeddings(model=self.cfg.ollama_embeddings.embedding_model_name)

template = load_prompt_template(prompt_name="ollama_chat.j2", task="chat")
template = load_prompt_template(prompt_name="cohere_chat.j2", task="chat")
self.prompt = ChatPromptTemplate.from_template(template=template, template_format="jinja2")

self.vector_db = init_vdb(self.cfg, self.collection_name, embedding=embedding)
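
Aside: a small standalone example of the jinja2 prompt-template pattern this service relies on. The template string below is invented for illustration; the repository loads the actual template (e.g. cohere_chat.j2) from disk via load_prompt_template.

    from langchain_core.prompts import ChatPromptTemplate

    # Invented template text; the real template lives in a .j2 file in the repository.
    template = "Context: {{ context }}\nQuestion: {{ question }}\nAnswer briefly."
    prompt = ChatPromptTemplate.from_template(template=template, template_format="jinja2")
    print(prompt.format_messages(context="Qdrant stores vectors.", question="What does Qdrant store?"))
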
120 changes: 19 additions & 101 deletions agent/utils/vdb.py
@@ -39,128 +39,46 @@ def load_vec_db_conn(cfg: DictConfig) -> QdrantClient:
"""Load the Vector Database Connection."""
return QdrantClient(cfg.qdrant.url, port=cfg.qdrant.port, api_key=os.getenv("QDRANT_API_KEY"), prefer_grpc=cfg.qdrant.prefer_grpc), cfg


def initialize_aleph_alpha_vector_db() -> None:
"""Initializes the Aleph Alpha vector db."""
qdrant_client, cfg = load_vec_db_conn()
try:
qdrant_client.get_collection(collection_name=cfg.qdrant.collection_name_aa)
logger.info(f"SUCCESS: Collection {cfg.qdrant.collection_name_aa} already exists.")
except UnexpectedResponse:
generate_collection_aleph_alpha(qdrant_client, collection_name=cfg.qdrant.collection_name_aa, embeddings_size=cfg.aleph_alpha_embeddings.size)


def generate_collection_aleph_alpha(qdrant_client: Qdrant, collection_name: str, embeddings_size: int) -> None:
"""Generate a collection for the Aleph Alpha Backend.
Args:
----
qdrant_client (Qdrant): Qdrant Connection Client.
collection_name (str): Name of the Collection in the VDB.
embeddings_size (int): Size of the Embeddings
"""
qdrant_client.recreate_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
)
logger.info(f"SUCCESS: Collection {collection_name} created.")


def initialize_open_ai_vector_db() -> None:
"""Initializes the OpenAI vector db.
def initialize_vector_db(collection_name: str, embeddings_size: int) -> None:
"""Initializes the vector db for a given backend.
Args:
----
cfg (DictConfig): Configuration from the file
"""
qdrant_client, cfg = load_vec_db_conn()

try:
qdrant_client.get_collection(collection_name=cfg.qdrant.collection_name_openai)
logger.info(f"SUCCESS: Collection {cfg.qdrant.collection_name_openai} already exists.")
except UnexpectedResponse:
generate_collection_openai(qdrant_client, collection_name=cfg.qdrant.collection_name_openai)


def generate_collection_openai(qdrant_client: Qdrant, collection_name: str) -> None:
"""Generate a collection for the OpenAI Backend.
Args:
----
qdrant_client (Qdrant): Qdrant Client Langchain.
collection_name (str): Name of the Collection
embeddings_size (int): Size of the Embeddings
"""
qdrant_client.recreate_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(size=1536, distance=models.Distance.COSINE),
)
logger.info(f"SUCCESS: Collection {collection_name} created.")


def initialize_gpt4all_vector_db() -> None:
"""Initializes the GPT4ALL vector db.
Args:
----
cfg (DictConfig): Configuration from the file
"""
qdrant_client, cfg = load_vec_db_conn()
qdrant_client, _ = load_vec_db_conn()

try:
qdrant_client.get_collection(collection_name=cfg.qdrant.collection_name_gpt4all)
logger.info(f"SUCCESS: Collection {cfg.qdrant.collection_name_gpt4all} already exists.")
qdrant_client.get_collection(collection_name=collection_name)
logger.info(f"SUCCESS: Collection {collection_name} already exists.")
except UnexpectedResponse:
generate_collection_gpt4all(qdrant_client, collection_name=cfg.qdrant.collection_name_gpt4all)
generate_collection(qdrant_client, collection_name=collection_name, embeddings_size=embeddings_size)


def generate_collection_gpt4all(qdrant_client: Qdrant, collection_name: str) -> None:
"""Generate a collection for the GPT4ALL Backend.
def generate_collection(qdrant_client: Qdrant, collection_name: str, embeddings_size: int) -> None:
"""Generate a collection for a given backend.
Args:
----
qdrant_client (Qdrant): Qdrant Client
collection_name (str): Name of the Collection
embeddings_size (int): Size of the Embeddings
"""
qdrant_client.recreate_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(size=2048, distance=models.Distance.COSINE),
vectors_config=models.VectorParams(size=embeddings_size, distance=models.Distance.COSINE),
)
logger.info(f"SUCCESS: Collection {collection_name} created.")


def initialize_cohere_vector_db() -> None:
"""Initializes the Cohere vector db.
Args:
----
cfg (DictConfig): Configuration from the file
"""
qdrant_client, cfg = load_vec_db_conn()

try:
qdrant_client.get_collection(collection_name=cfg.qdrant.collection_name_cohere)
logger.info(f"SUCCESS: Collection {cfg.qdrant.collection_name_cohere} already exists.")
except UnexpectedResponse:
generate_collection_cohere(qdrant_client, collection_name=cfg.qdrant.collection_name_cohere)


def generate_collection_cohere(qdrant_client: Qdrant, collection_name: str) -> None:
"""Generate a collection for the OpenAI Backend.
Args:
----
qdrant_client (Qdrant): Qdrant Client Langchain.
collection_name (str): Name of the Collection
"""
qdrant_client.recreate_collection(
collection_name=collection_name,
vectors_config=models.VectorParams(size=1024, distance=models.Distance.COSINE),
)
logger.info(f"SUCCESS: Collection {collection_name} created.")
@load_config("config/main.yml")
def initialize_all_vector_dbs(cfg: DictConfig) -> None:
"""Initializes all vector dbs."""
initialize_vector_db(cfg.qdrant.collection_name_aleph_alpha, cfg.aleph_alpha_embeddings.size)
initialize_vector_db(cfg.qdrant.collection_name_openai, cfg.openai_embeddings.size)
initialize_vector_db(cfg.qdrant.collection_name_gpt4all, 2048)
initialize_vector_db(cfg.qdrant.collection_name_cohere, cfg.cohere_embeddings.size)
initialize_vector_db(cfg.qdrant.collection_name_ollama, cfg.ollama_embeddings.size)
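
For readers unfamiliar with the get-or-create pattern that the new initialize_vector_db follows, here is a minimal standalone sketch using qdrant-client directly; the URL, collection name, and vector size are placeholders, not values from this repository.

    from qdrant_client import QdrantClient, models
    from qdrant_client.http.exceptions import UnexpectedResponse

    def ensure_collection(client: QdrantClient, name: str, size: int) -> None:
        """Create the collection only if it does not already exist."""
        try:
            client.get_collection(collection_name=name)  # raises UnexpectedResponse when the collection is missing
        except UnexpectedResponse:
            client.recreate_collection(
                collection_name=name,
                vectors_config=models.VectorParams(size=size, distance=models.Distance.COSINE),
            )

    ensure_collection(QdrantClient(url="http://localhost:6333"), "demo_collection", 1536)
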
4 changes: 4 additions & 0 deletions config/main.yml
@@ -7,6 +7,7 @@ qdrant:
collection_name_gpt4all: gpt4all
collection_name_openai: openai
collection_name_cohere: cohere
collection_name_ollama: ollama

aleph_alpha_embeddings:
normalize: True
@@ -19,10 +20,12 @@ ollama_embeddings:

ollama:
model: phi3
size: 768

openai_embeddings:
azure: False
embedding_model_name: text-embedding-ada-002
size: 1536
openai_api_version: 2024-02-15-preview

aleph_alpha_completion:
@@ -53,3 +56,4 @@ openai_completion:

cohere_embeddings:
embedding_model_name: "embed-multilingual-v3.0"
size: 2048
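
A quick way to verify the new size keys, using OmegaConf (the loader behind the repository's DictConfig objects) and assuming the file is read from the repository root:

    from omegaconf import OmegaConf

    cfg = OmegaConf.load("config/main.yml")
    print(cfg.openai_embeddings.size)  # 1536, added in this commit
    print(cfg.cohere_embeddings.size)  # 2048, added in this commit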
