Skip to content

Commit

Permalink
Added test cases in Indexer to improve test coverage
Browse files Browse the repository at this point in the history
  • Loading branch information
Ananya Agrawal authored and Ananya Agrawal committed Dec 16, 2024
1 parent f2251a5 commit 9e141cc
Show file tree
Hide file tree
Showing 2 changed files with 160 additions and 1 deletion.
141 changes: 140 additions & 1 deletion indexer/tests/test_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@
import os
from contextlib import asynccontextmanager
from unittest.mock import ANY, AsyncMock, MagicMock, mock_open, patch

from r2r import ChunkingConfig, R2RBuilder, R2RConfig
import asyncpg
import pytest
from lib.data_models import Indexer
from model import InternalServerException


# Mock environment variables
Expand Down Expand Up @@ -38,6 +40,39 @@ async def mock_read_file(file_path, mode):
):
import indexing

def test_parse_file():
# Mock parser functions
with patch("indexing.pdf_parser", return_value="PDF Content") as mock_pdf_parser, \
patch("indexing.docx_parser", return_value="DOCX Content") as mock_docx_parser, \
patch("indexing.xlsx_parser", return_value="XLSX Content") as mock_xlsx_parser, \
patch("indexing.json_parser", return_value="JSON Content") as mock_json_parser, \
patch("indexing.default_parser", return_value="Default Content") as mock_default_parser:

# Test PDF file
result = indexing.parse_file("test.pdf")
assert result == "PDF Content"
mock_pdf_parser.assert_called_once_with("test.pdf")

# Test DOCX file
result = indexing.parse_file("test.docx")
assert result == "DOCX Content"
mock_docx_parser.assert_called_once_with("test.docx")

# Test XLSX file
result = indexing.parse_file("test.xlsx")
assert result == "XLSX Content"
mock_xlsx_parser.assert_called_once_with("test.xlsx")

# Test JSON file
result = indexing.parse_file("test.json")
assert result == "JSON Content"
mock_json_parser.assert_called_once_with("test.json")

# Test unsupported file extension (uses default parser)
result = indexing.parse_file("test.txt")
assert result == "Default Content"
mock_default_parser.assert_called_once_with("test.txt")

def test_docx_parser():
with patch("docx2txt.process", return_value="DOCX Content") as mock_process:
result = indexing.docx_parser("test.docx")
Expand Down Expand Up @@ -80,6 +115,28 @@ async def test_text_converter_textify():
result = await converter.textify("test.txt")
assert result == "Parsed Content"

@pytest.mark.asyncio
async def test_create_pg_vector_index_if_not_exists():
# Mock the asyncpg connection and its methods
mock_connection = MagicMock(spec=asyncpg.connection.Connection)

with patch("asyncpg.connect", return_value=mock_connection):
# Initialize the DataIndexer instance
indexer = indexing.DataIndexer()

# Call the method
await indexer.create_pg_vector_index_if_not_exists()

# Assertions
mock_connection.transaction.assert_called_once()
mock_connection.execute.assert_any_call(
"ALTER TABLE langchain_pg_embedding ALTER COLUMN embedding TYPE vector(1536)"
)
mock_connection.execute.assert_any_call(
"CREATE INDEX IF NOT EXISTS langchain_embeddings_hnsw ON langchain_pg_embedding USING hnsw (embedding vector_cosine_ops)"
)
mock_connection.close.assert_called_once()

# Test DataIndexer.index method for default input
@pytest.mark.asyncio
async def test_default_data_indexer():
Expand All @@ -105,6 +162,33 @@ async def test_default_data_indexer():
await indexer.index(indexer_input)
mock_pg_vector.assert_called_once()
indexer.create_pg_vector_index_if_not_exists.assert_awaited_once()

@pytest.mark.asyncio
async def test_get_r2r():
indexer = indexing.DataIndexer()
chunk_size = 4000
chunk_overlap = 200
mock_r2r_app = MagicMock()

with patch("indexing.R2RConfig") as MockR2RConfig, patch("indexing.R2RBuilder") as MockR2RBuilder:
mock_r2r_builder = MagicMock()
MockR2RBuilder.return_value = mock_r2r_builder
mock_r2r_builder.build.return_value = mock_r2r_app

r2r_app = await indexer.get_r2r(chunk_size=chunk_size, chunk_overlap=chunk_overlap)

MockR2RConfig.assert_called_once_with(
config_data={
"chunking": ChunkingConfig(
chunk_size=chunk_size, chunk_overlap=chunk_overlap
),
}
)

MockR2RBuilder.assert_called_once_with(config=MockR2RConfig.return_value)
mock_r2r_builder.build.assert_called_once()

assert r2r_app == mock_r2r_app

# Test DataIndexer.index method for r2r input
@pytest.mark.asyncio
Expand All @@ -124,3 +208,58 @@ async def test_r2r_data_indexer():
await indexer.index(indexer_input)
mock_r2r_app.engine.aingest_files.assert_awaited_once_with(files=[ANY])
assert os.environ["POSTGRES_VECS_COLLECTION"] == "test_collection"

@pytest.mark.asyncio
async def test_get_embeddings_azure():
# Mock environment variables for Azure setup
os.environ["OPENAI_API_TYPE"] = "azure"
os.environ["AZURE_EMBEDDING_MODEL_NAME"] = "azure-model-name"
os.environ["AZURE_DEPLOYMENT_NAME"] = "azure-deployment"
os.environ["AZURE_OPENAI_ENDPOINT"] = "https://azure-endpoint"
os.environ["AZURE_OPENAI_API_KEY"] = "azure-api-key"

indexer = indexing.DataIndexer()

# Mock the AzureOpenAIEmbeddings class and its constructor
with patch("indexing.AzureOpenAIEmbeddings") as mock_azure_embeddings:
mock_instance = MagicMock()
mock_azure_embeddings.return_value = mock_instance

# Call the get_embeddings method
embeddings = await indexer.get_embeddings()

# Assert that AzureOpenAIEmbeddings was created with the correct parameters
mock_azure_embeddings.assert_called_once_with(
model="azure-model-name",
dimensions=1536,
azure_deployment="azure-deployment",
azure_endpoint="https://azure-endpoint",
openai_api_type="azure",
openai_api_key="azure-api-key"
)

# Assert the returned value is the mocked instance of AzureOpenAIEmbeddings
assert embeddings == mock_instance


@pytest.mark.asyncio
async def test_get_embeddings_openai():
# Mock environment variables for OpenAI setup
os.environ["OPENAI_API_TYPE"] = "openai"

indexer = indexing.DataIndexer()

# Mock the OpenAIEmbeddings class and its constructor
with patch("indexing.OpenAIEmbeddings") as mock_openai_embeddings:
mock_instance = MagicMock()
mock_openai_embeddings.return_value = mock_instance

# Call the get_embeddings method
embeddings = await indexer.get_embeddings()

# Assert that OpenAIEmbeddings was created with the correct parameters
mock_openai_embeddings.assert_called_once_with(client="")

# Assert the returned value is the mocked instance of OpenAIEmbeddings
assert embeddings == mock_instance

20 changes: 20 additions & 0 deletions indexer/tests/test_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,20 @@
from model import InternalServerException, ServiceUnavailableException

# Test for InternalServerException
def test_internal_server_exception():
message = "An internal server error occurred"
exception = InternalServerException(message)

assert exception.message == message
assert exception.status_code == 500
assert str(exception) == message


# Test for ServiceUnavailableException
def test_service_unavailable_exception():
message = "Service is unavailable"
exception = ServiceUnavailableException(message)

assert exception.message == message
assert exception.status_code == 503
assert str(exception) == message

0 comments on commit 9e141cc

Please sign in to comment.