Skip to content

Commit

Permalink
Merge pull request #375 from Marker-Inc-Korea/Feature/#372
Browse files: browse the repository at this point in the history
Add a test for ingesting large embedding batches and pin the minimum version to the latest ChromaDB
  • Loading branch information
bwook00 authored Apr 26, 2024
2 parents 0a3071f + 45ab34f commit a598015
Show file tree
Hide file tree
Showing 3 changed files with 33 additions and 2 deletions.
7 changes: 6 additions & 1 deletion autorag/nodes/retrieval/vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@

import chromadb
import pandas as pd
from chromadb.utils.batch_utils import create_batches
from llama_index.core.embeddings import BaseEmbedding
from llama_index.embeddings.openai import OpenAIEmbedding

Expand Down Expand Up @@ -95,4 +96,8 @@ def vectordb_ingest(collection: chromadb.Collection, corpus_data: pd.DataFrame,

new_ids = new_passage['doc_id'].tolist()
embedded_contents = embedding_model.get_text_embedding_batch(new_contents, show_progress=True)
collection.add(ids=new_ids, embeddings=embedded_contents)
input_batches = create_batches(api=collection._client, ids=new_ids, embeddings=embedded_contents)
for batch in input_batches:
ids = batch[0]
embed_content = batch[1]
collection.add(ids=ids, embeddings=embed_content)
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ sacrebleu # for bleu score
evaluate # for meteor and other scores
rouge_score # for rouge score
rich # for pretty logging
chromadb # for vectordb retrieval
chromadb>=0.5.0 # for vectordb retrieval
click # for cli
fastapi # for api server
uvicorn # for api server
Expand Down
26 changes: 26 additions & 0 deletions tests/autorag/nodes/retrieval/test_vectordb.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,9 @@
import pathlib
import shutil
import tempfile
import uuid
from datetime import datetime
from unittest.mock import patch

import chromadb
import pandas as pd
Expand Down Expand Up @@ -32,6 +34,15 @@ def ingested_vectordb():
yield collection


@pytest.fixture
def empty_chromadb():
    """Yield a fresh, empty Chroma collection configured for cosine distance.

    The collection is stored in a temporary directory that is cleaned up
    once the test using this fixture has finished.
    """
    with tempfile.TemporaryDirectory() as tmp_dir:
        client = chromadb.PersistentClient(path=tmp_dir)
        yield client.create_collection(
            metadata={"hnsw:space": "cosine"},
            name="test_vectordb_retrieval",
        )


@pytest.fixture
def project_dir_for_vectordb_node():
with tempfile.TemporaryDirectory() as test_project_dir:
Expand Down Expand Up @@ -88,3 +99,18 @@ def test_long_text_vectordb_ingest(ingested_vectordb):
vectordb_ingest(ingested_vectordb, new_corpus_df, embedding_model)

assert ingested_vectordb.count() == 7


def mock_get_text_embedding_batch(self, texts, **kwargs):
    """Offline stand-in for an embedding model's ``get_text_embedding_batch``.

    :param self: The embedding model instance (unused).
    :param texts: Iterable of texts to "embed"; only the number of items matters.
    :param kwargs: Accepted and ignored so callers may pass options such as
        ``show_progress``.
    :return: A list holding one fixed 3-dimensional vector per input text.
    """
    # Iterate the input directly instead of range(len(...)) so that any
    # iterable works; a new inner list is created for every text.
    return [[3.0, 4.1, 3.2] for _ in texts]


@patch.object(OpenAIEmbedding, 'get_text_embedding_batch', mock_get_text_embedding_batch)
def test_long_ids_ingest(empty_chromadb):
    """Ingest a very large corpus into an empty collection in one call.

    The embedding call is patched with a local mock, so no request is sent
    to the embedding API. The test passes when ingestion completes without
    raising and every generated passage ends up in the collection.
    """
    # Single source of truth for the corpus size used by all three columns.
    row_count = 100_000
    embedding_model = OpenAIEmbedding()
    content_df = pd.DataFrame({
        'doc_id': [str(uuid.uuid4()) for _ in range(row_count)],
        'contents': ['havertz' for _ in range(row_count)],
        'metadata': [{'last_modified_datetime': datetime.now()} for _ in range(row_count)],
    })
    vectordb_ingest(empty_chromadb, content_df, embedding_model)

    # All doc_ids are unique and the collection started empty, so every row
    # must have been stored.
    assert empty_chromadb.count() == row_count

0 comments on commit a598015

Please sign in to comment.