diff --git a/packages/ragbits-document-search/src/ragbits/document_search/_main.py b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
index afbf88b5..72cd2e25 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/_main.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/_main.py
@@ -16,6 +16,7 @@ from ragbits.core.vector_stores.base import VectorStoreOptions
 from ragbits.document_search.documents.document import Document, DocumentMeta
 from ragbits.document_search.documents.element import Element, ImageElement
+from ragbits.document_search.documents.source_resolver import SourceResolver
 from ragbits.document_search.documents.sources import Source
 from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
 from ragbits.document_search.ingestion.processor_strategies import (
@@ -197,19 +198,27 @@ async def search(self, query: str, config: SearchConfig | None = None) -> Sequen
     @traceable
     async def ingest(
         self,
-        documents: Sequence[DocumentMeta | Document | Source],
+        documents: str | Sequence[DocumentMeta | Document | Source],
         document_processor: BaseProvider | None = None,
     ) -> None:
-        """
-        Ingest multiple documents.
+        """Ingest documents into the search index.
 
         Args:
-            documents: The documents or metadata of the documents to ingest.
+            documents: Either:
+                - A sequence of `Document`, `DocumentMeta`, or `Source` objects
+                - A source-specific URI string identifying the source location(s), for example:
+                    - "file:///path/to/files/*.txt"
+                    - "gcs://bucket/folder/*"
+                    - "huggingface://dataset/split/row"
             document_processor: The document processor to use. If not provided, the document processor will be
                 determined based on the document metadata.
         """
+        if isinstance(documents, str):
+            sources: Sequence[DocumentMeta | Document | Source] = await SourceResolver.resolve(documents)
+        else:
+            sources = documents
         elements = await self.processing_strategy.process_documents(
-            documents, self.document_processor_router, document_processor
+            sources, self.document_processor_router, document_processor
         )
         await self._remove_entries_with_same_sources(elements)
         await self.insert_elements(elements)
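
Before the per-file details, a minimal sketch of what the new string-accepting `ingest` enables. The stub embedder mirrors the unit tests further down; the path and its glob are illustrative.

```python
import asyncio
from unittest.mock import AsyncMock

from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
from ragbits.document_search import DocumentSearch


async def main() -> None:
    # Stub embedder, as in the unit tests below; any real embedder works the same way.
    embedder = AsyncMock()
    embedder.embed_text.return_value = [[0.1, 0.1]]
    document_search = DocumentSearch(embedder=embedder, vector_store=InMemoryVectorStore())

    # New in this diff: a plain URI string is routed through SourceResolver.resolve(...)
    await document_search.ingest("file:///tmp/docs/*.txt")  # illustrative path

    # The existing sequence-based call is unchanged:
    # await document_search.ingest([some_document_meta, some_source])


asyncio.run(main())
```
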
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py
new file mode 100644
index 00000000..7efad108
--- /dev/null
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/source_resolver.py
@@ -0,0 +1,59 @@
+from collections.abc import Sequence
+from typing import TYPE_CHECKING, ClassVar
+
+if TYPE_CHECKING:
+    from ragbits.document_search.documents.sources import Source
+
+
+class SourceResolver:
+    """Registry for source URI protocols and their handlers.
+
+    This class provides a mechanism to register and resolve different source protocols
+    (like 'file://', 'gcs://', etc.) to their corresponding Source implementations.
+
+    Example:
+        >>> SourceResolver.register_protocol("gcs", GCSSource)
+        >>> sources = await SourceResolver.resolve("gcs://my-bucket/path/to/files/*")
+    """
+
+    _protocol_handlers: ClassVar[dict[str, type["Source"]]] = {}
+
+    @classmethod
+    def register_protocol(cls, protocol: str, source_class: type["Source"]) -> None:
+        """Register a source class for a specific protocol.
+
+        Args:
+            protocol: The protocol identifier (e.g., 'file', 'gcs', 's3')
+            source_class: The Source subclass that handles this protocol
+        """
+        cls._protocol_handlers[protocol] = source_class
+
+    @classmethod
+    async def resolve(cls, uri: str) -> Sequence["Source"]:
+        """Resolve a URI into a sequence of Source objects.
+
+        The URI format should be: protocol://path
+        For example:
+        - file:///path/to/files/*
+        - gcs://bucket/prefix/*
+
+        Args:
+            uri: The URI to resolve
+
+        Returns:
+            A sequence of Source objects
+
+        Raises:
+            ValueError: If the URI format is invalid or the protocol is not supported
+        """
+        try:
+            protocol, path = uri.split("://", 1)
+        except ValueError as err:
+            raise ValueError(f"Invalid URI format: {uri}. Expected format: protocol://path") from err
+
+        if protocol not in cls._protocol_handlers:
+            supported = ", ".join(sorted(cls._protocol_handlers.keys()))
+            raise ValueError(f"Unsupported protocol: {protocol}. Supported protocols are: {supported}")
+
+        handler_class = cls._protocol_handlers[protocol]
+        return await handler_class.from_uri(path)
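
Handlers enter this registry either through an explicit `register_protocol` call (as in the class docstring) or automatically when a `Source` subclass declares a `protocol` class attribute — the wiring added to `Source.__init_subclass__` in `sources.py` below. A sketch with a hypothetical `S3Source` (not part of this diff):

```python
from collections.abc import Sequence
from pathlib import Path
from typing import ClassVar

from ragbits.document_search.documents.source_resolver import SourceResolver
from ragbits.document_search.documents.sources import Source


class S3Source(Source):
    """Hypothetical source, shown only to illustrate protocol registration."""

    protocol: ClassVar[str] = "s3"  # picked up automatically by Source.__init_subclass__
    bucket: str
    key: str

    @property
    def id(self) -> str:
        return f"s3:{self.bucket}/{self.key}"

    async def fetch(self) -> Path:
        raise NotImplementedError  # the actual download logic would live here

    @classmethod
    async def from_uri(cls, path: str) -> Sequence["S3Source"]:
        # "bucket/key" -> one source; pattern handling is up to each implementation
        bucket, key = path.split("/", 1)
        return [cls(bucket=bucket, key=key)]


# Explicit registration is equivalent to declaring `protocol` above:
SourceResolver.register_protocol("s3", S3Source)
```
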
diff --git a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
index 77282f0e..d9a32c67 100644
--- a/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
+++ b/packages/ragbits-document-search/src/ragbits/document_search/documents/sources.py
@@ -2,6 +2,7 @@
 import re
 import tempfile
 from abc import ABC, abstractmethod
+from collections.abc import Sequence
 from contextlib import suppress
 from pathlib import Path
 from typing import Any, ClassVar
@@ -11,14 +12,16 @@
 from pydantic_core import CoreSchema, core_schema
 
 with suppress(ImportError):
-    from gcloud.aio.storage import Storage
+    from gcloud.aio.storage import Storage as StorageClient
 
 with suppress(ImportError):
     from datasets import load_dataset
     from datasets.exceptions import DatasetNotFoundError
 
+
 from ragbits.core.utils.decorators import requires_dependencies
 from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError
+from ragbits.document_search.documents.source_resolver import SourceResolver
 
 LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"
 
@@ -30,6 +33,7 @@ class Source(BaseModel, ABC):
 
     # Registry of all subclasses by their unique identifier
     _registry: ClassVar[dict[str, type["Source"]]] = {}
+    protocol: ClassVar[str | None] = None
 
     @classmethod
     def class_identifier(cls) -> str:
@@ -64,10 +68,34 @@ async def fetch(self) -> Path:
             The path to the source.
         """
 
+    @classmethod
+    @abstractmethod
+    async def from_uri(cls, path: str) -> Sequence["Source"]:
+        """Create Source instances from a URI path.
+
+        The path can contain glob patterns (asterisks) to match multiple sources, but pattern support
+        varies by source type. Each source implementation defines which patterns it supports:
+
+        - LocalFileSource: Supports full glob patterns ('*', '**', etc.) via Path.glob
+        - GCSSource: Supports simple prefix matching with '*' at the end of path
+        - HuggingFaceSource: Does not support glob patterns
+
+        Args:
+            path: The path part of the URI (after protocol://). Pattern support depends on source type.
+
+        Returns:
+            A sequence of Source objects matching the path pattern
+
+        Raises:
+            ValueError: If the path contains an unsupported pattern for this source type
+        """
+
     @classmethod
     def __init_subclass__(cls, **kwargs: Any) -> None:  # noqa: ANN401
-        Source._registry[cls.class_identifier()] = cls
         super().__init_subclass__(**kwargs)
+        Source._registry[cls.class_identifier()] = cls
+        if cls.protocol is not None:
+            SourceResolver.register_protocol(cls.protocol, cls)
 
 
 class SourceDiscriminator:
@@ -112,6 +140,7 @@ class LocalFileSource(Source):
     """
 
     path: Path
+    protocol: ClassVar[str] = "file"
 
     @property
     def id(self) -> str:
@@ -151,14 +180,71 @@ def list_sources(cls, path: Path, file_pattern: str = "*") -> list["LocalFileSou
         """
         return [cls(path=file_path) for file_path in path.glob(file_pattern)]
 
+    @classmethod
+    async def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
+        """Create LocalFileSource instances from a URI path.
+
+        Supports full glob patterns via Path.glob:
+        - "**/*.txt" - all .txt files in any subdirectory
+        - "*.py" - all Python files in the current directory
+        - "**/*" - all files in any subdirectory
+        - '?' matches exactly one character
+
+        Args:
+            path: The path part of the URI (after file://). Supports full glob patterns.
+
+        Returns:
+            A sequence of LocalFileSource objects
+        """
+        path_obj: Path = Path(path)
+        base_path, pattern = cls._split_path_and_pattern(path=path_obj)
+        if base_path.is_file():
+            return [cls(path=base_path)]
+        if not pattern:
+            return []
+        return [cls(path=f) for f in base_path.glob(pattern) if f.is_file()]
+
+    @staticmethod
+    def _split_path_and_pattern(path: Path) -> tuple[Path, str]:
+        parts = path.parts
+        # Find the first part containing '*' or '?'
+        for i, part in enumerate(parts):
+            if "*" in part or "?" in part:
+                base_path = Path(*parts[:i])
+                pattern = str(Path(*parts[i:]))
+                return base_path, pattern
+        return path, ""
+
+ """ + if cls._storage is None: + cls._storage = StorageClient() + return cls._storage @property def id(self) -> str: @@ -192,8 +278,8 @@ async def fetch(self) -> Path: path = bucket_local_dir / self.object_name if not path.is_file(): - async with Storage() as client: # type: ignore - # TODO: Add error handling for download + storage = await self._get_storage() + async with storage as client: content = await client.download(self.bucket, self.object_name) Path(bucket_local_dir / self.object_name).parent.mkdir(parents=True, exist_ok=True) with open(path, mode="wb+") as file_object: @@ -204,8 +290,7 @@ async def fetch(self) -> Path: @classmethod @requires_dependencies(["gcloud.aio.storage"], "gcs") async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]: - """ - List all sources in the given GCS bucket, matching the prefix. + """List all sources in the given GCS bucket, matching the prefix. Args: bucket: The GCS bucket. @@ -217,12 +302,47 @@ async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]: Raises: ImportError: If the required 'gcloud-aio-storage' package is not installed """ - async with Storage() as client: - objects = await client.list_objects(bucket, params={"prefix": prefix}) - sources = [] - for obj in objects["items"]: - sources.append(cls(bucket=bucket, object_name=obj["name"])) - return sources + async with await cls._get_storage() as storage: + result = await storage.list_objects(bucket, params={"prefix": prefix}) + items = result.get("items", []) + return [cls(bucket=bucket, object_name=item["name"]) for item in items] + + @classmethod + async def from_uri(cls, path: str) -> Sequence["GCSSource"]: + """Create GCSSource instances from a URI path. + + Supports simple prefix matching with '*' at the end of path. + For example: + - "bucket/folder/*" - matches all files in the folder + - "bucket/folder/prefix*" - matches all files starting with prefix + + More complex patterns like '**' or '?' are not supported. + + Args: + path: The path part of the URI (after gcs://). Can end with '*' for pattern matching. + + Returns: + A sequence of GCSSource objects matching the pattern + + Raises: + ValueError: If an unsupported pattern is used + """ + if "**" in path or "?" in path: + raise ValueError( + "GCSSource only supports '*' at the end of path. Patterns like '**' or '?' are not supported." + ) + + # Split into bucket and prefix + bucket, prefix = path.split("/", 1) if "/" in path else (path, "") + + if "*" in prefix: + if not prefix.endswith("*"): + raise ValueError(f"GCSSource only supports '*' at the end of path. Invalid pattern: {prefix}") + # Remove the trailing * for GCS prefix listing + prefix = prefix[:-1] + return await cls.list_sources(bucket=bucket, prefix=prefix) + + return [cls(bucket=bucket, object_name=prefix)] class HuggingFaceSource(Source): @@ -233,6 +353,7 @@ class HuggingFaceSource(Source): path: str split: str = "train" row: int + protocol: ClassVar[str] = "huggingface" @property def id(self) -> str: @@ -280,6 +401,33 @@ async def fetch(self) -> Path: return path + @classmethod + async def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]: + """Create HuggingFaceSource instances from a URI path. + + Pattern matching is not supported. 
@@ -233,6 +353,7 @@
     path: str
     split: str = "train"
    row: int
+    protocol: ClassVar[str] = "huggingface"
 
     @property
     def id(self) -> str:
@@ -280,6 +401,33 @@ async def fetch(self) -> Path:
 
         return path
 
+    @classmethod
+    async def from_uri(cls, path: str) -> Sequence["HuggingFaceSource"]:
+        """Create HuggingFaceSource instances from a URI path.
+
+        Pattern matching is not supported. The path must be in the format:
+        huggingface://dataset_path/split/row
+
+        Args:
+            path: The path part of the URI (after huggingface://)
+
+        Returns:
+            A sequence containing a single HuggingFaceSource
+
+        Raises:
+            ValueError: If the path contains patterns or has an invalid format
+        """
+        if "*" in path or "?" in path:
+            raise ValueError(
+                "HuggingFaceSource does not support patterns. Path must be in format: dataset_path/split/row"
+            )
+
+        try:
+            dataset_path, split, row = path.split("/")
+            return [cls(path=dataset_path, split=split, row=int(row))]
+        except ValueError as err:
+            raise ValueError("Invalid HuggingFace path format. Expected: dataset_path/split/row") from err
+
     @classmethod
     async def list_sources(cls, path: str, split: str) -> list["HuggingFaceSource"]:
         """
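
Taken together, the three `from_uri` implementations above give `SourceResolver.resolve` these per-backend pattern rules. A sketch — the bucket and dataset names are made up:

```python
import asyncio

from ragbits.document_search.documents import sources  # noqa: F401 - importing registers the protocols
from ragbits.document_search.documents.source_resolver import SourceResolver


async def main() -> None:
    # LocalFileSource: full glob patterns via Path.glob ('*', '?', '**')
    local = await SourceResolver.resolve("file:///var/data/reports/**/*.txt")

    # GCSSource: only a single trailing '*' (turned into a prefix listing);
    # '**', '?', or a '*' anywhere else raise ValueError
    gcs = await SourceResolver.resolve("gcs://my-bucket/reports/2024-*")

    # HuggingFaceSource: no patterns at all, exactly dataset_path/split/row
    hf = await SourceResolver.resolve("huggingface://my_dataset/train/0")

    print(len(local), len(gcs), len(hf))


asyncio.run(main())
```
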
""" - config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG) - config.update(providers_config if providers_config is not None else {}) + config: MutableMapping[DocumentType, Callable[[], BaseProvider] | BaseProvider] = copy.deepcopy( + DEFAULT_PROVIDERS_CONFIG + ) + config.update(providers if providers is not None else {}) return cls(providers=config) diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py index 181b420b..e5fc00d4 100644 --- a/packages/ragbits-document-search/tests/unit/test_document_search.py +++ b/packages/ragbits-document-search/tests/unit/test_document_search.py @@ -1,6 +1,9 @@ +import os import tempfile -from collections.abc import Callable +from collections.abc import Callable, Mapping from pathlib import Path +from typing import cast +from unittest import mock from unittest.mock import AsyncMock import pytest @@ -8,11 +11,20 @@ from ragbits.core.vector_stores.in_memory import InMemoryVectorStore from ragbits.document_search import DocumentSearch from ragbits.document_search._main import SearchConfig -from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType +from ragbits.document_search.documents.document import ( + Document, + DocumentMeta, + DocumentType, +) from ragbits.document_search.documents.element import TextElement -from ragbits.document_search.documents.sources import LocalFileSource +from ragbits.document_search.documents.sources import ( + GCSSource, + LocalFileSource, +) from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter -from ragbits.document_search.ingestion.processor_strategies.batched import BatchedAsyncProcessing +from ragbits.document_search.ingestion.processor_strategies.batched import ( + BatchedAsyncProcessing, +) from ragbits.document_search.ingestion.providers import BaseProvider from ragbits.document_search.ingestion.providers.dummy import DummyProvider @@ -49,14 +61,14 @@ async def test_document_search_from_config(document: DocumentMeta, expected: str first_result = results[0] assert isinstance(first_result, TextElement) - assert first_result.content == expected # type: ignore + assert first_result.content == expected async def test_document_search_ingest_from_source(): embeddings_mock = AsyncMock() embeddings_mock.embed_text.return_value = [[0.1, 0.1]] - providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()} + providers: Mapping[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()} router = DocumentProcessorRouter.from_config(providers) document_search = DocumentSearch( @@ -76,7 +88,7 @@ async def test_document_search_ingest_from_source(): first_result = results[0] assert isinstance(first_result, TextElement) - assert first_result.content == "Name of Peppa's brother is George" # type: ignore + assert first_result.content == "Name of Peppa's brother is George" @pytest.mark.parametrize( @@ -102,7 +114,7 @@ async def test_document_search_ingest(document: DocumentMeta | Document): first_result = results[0] assert isinstance(first_result, TextElement) - assert first_result.content == "Name of Peppa's brother is George" # type: ignore + assert first_result.content == "Name of Peppa's brother is George" async def test_document_search_insert_elements(): @@ -125,7 +137,7 @@ async def test_document_search_insert_elements(): first_result = results[0] assert isinstance(first_result, TextElement) - 
diff --git a/packages/ragbits-document-search/tests/unit/test_document_search.py b/packages/ragbits-document-search/tests/unit/test_document_search.py
index 181b420b..e5fc00d4 100644
--- a/packages/ragbits-document-search/tests/unit/test_document_search.py
+++ b/packages/ragbits-document-search/tests/unit/test_document_search.py
@@ -1,6 +1,9 @@
+import os
 import tempfile
-from collections.abc import Callable
+from collections.abc import Callable, Mapping
 from pathlib import Path
+from typing import cast
+from unittest import mock
 from unittest.mock import AsyncMock
 
 import pytest
@@ -8,11 +11,20 @@
 from ragbits.core.vector_stores.in_memory import InMemoryVectorStore
 from ragbits.document_search import DocumentSearch
 from ragbits.document_search._main import SearchConfig
-from ragbits.document_search.documents.document import Document, DocumentMeta, DocumentType
+from ragbits.document_search.documents.document import (
+    Document,
+    DocumentMeta,
+    DocumentType,
+)
 from ragbits.document_search.documents.element import TextElement
-from ragbits.document_search.documents.sources import LocalFileSource
+from ragbits.document_search.documents.sources import (
+    GCSSource,
+    LocalFileSource,
+)
 from ragbits.document_search.ingestion.document_processor import DocumentProcessorRouter
-from ragbits.document_search.ingestion.processor_strategies.batched import BatchedAsyncProcessing
+from ragbits.document_search.ingestion.processor_strategies.batched import (
+    BatchedAsyncProcessing,
+)
 from ragbits.document_search.ingestion.providers import BaseProvider
 from ragbits.document_search.ingestion.providers.dummy import DummyProvider
 
@@ -49,14 +61,14 @@ async def test_document_search_from_config(document: DocumentMeta, expected: str
 
     first_result = results[0]
     assert isinstance(first_result, TextElement)
-    assert first_result.content == expected  # type: ignore
+    assert first_result.content == expected
 
 
 async def test_document_search_ingest_from_source():
     embeddings_mock = AsyncMock()
     embeddings_mock.embed_text.return_value = [[0.1, 0.1]]
 
-    providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()}
+    providers: Mapping[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {DocumentType.TXT: DummyProvider()}
     router = DocumentProcessorRouter.from_config(providers)
 
     document_search = DocumentSearch(
@@ -76,7 +88,7 @@ async def test_document_search_ingest_from_source():
 
     first_result = results[0]
     assert isinstance(first_result, TextElement)
-    assert first_result.content == "Name of Peppa's brother is George"  # type: ignore
+    assert first_result.content == "Name of Peppa's brother is George"
 
 
 @pytest.mark.parametrize(
@@ -102,7 +114,7 @@ async def test_document_search_ingest(document: DocumentMeta | Document):
 
     first_result = results[0]
     assert isinstance(first_result, TextElement)
-    assert first_result.content == "Name of Peppa's brother is George"  # type: ignore
+    assert first_result.content == "Name of Peppa's brother is George"
 
 
 async def test_document_search_insert_elements():
@@ -125,7 +137,7 @@ async def test_document_search_insert_elements():
 
     first_result = results[0]
     assert isinstance(first_result, TextElement)
-    assert first_result.content == "Name of Peppa's brother is George"  # type: ignore
+    assert first_result.content == "Name of Peppa's brother is George"
 
 
 async def test_document_search_with_no_results():
@@ -150,7 +162,8 @@ async def test_document_search_with_search_config():
     results = await document_search.search("Peppa's brother", config=SearchConfig(vector_store_kwargs={"k": 1}))
 
     assert len(results) == 1
-    assert results[0].content == "Name of Peppa's brother is George"  # type: ignore
+    assert isinstance(results[0], TextElement)
+    assert cast(TextElement, results[0]).content == "Name of Peppa's brother is George"
 
 
 async def test_document_search_ingest_multiple_from_sources():
@@ -165,7 +178,8 @@ async def test_document_search_ingest_multiple_from_sources():
     results = await document_search.search("foo")
 
     assert len(results) == 2
-    assert {result.content for result in results} == {"foo", "bar"}  # type: ignore
+    assert all(isinstance(result, TextElement) for result in results)
+    assert {cast(TextElement, result).content for result in results} == {"foo", "bar"}
 
 
 async def test_document_search_with_batched():
@@ -202,3 +216,292 @@
 
     assert len(await vectore_store.list()) == 12
     assert len(results) == 12
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_uri_basic():
+    # Setup
+    with tempfile.TemporaryDirectory() as temp_dir:
+        test_file = Path(temp_dir) / "test.txt"
+        test_file.write_text("Test content")
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test ingesting from URI
+        await document_search.ingest(f"file://{test_file}")
+
+        # Verify
+        results = await document_search.search("Test content")
+        assert len(results) == 1
+        assert isinstance(results[0], TextElement)
+        assert isinstance(results[0].document_meta.source, LocalFileSource)
+        assert str(cast(LocalFileSource, results[0].document_meta.source).path) == str(test_file)
+        assert results[0].content == "Test content"
+
+
+@pytest.mark.asyncio
+@pytest.mark.parametrize(
+    ("pattern", "dir_pattern", "search_query", "expected_contents", "expected_filenames"),
+    [
+        (
+            "test*.txt",
+            None,
+            "test content",  # We search for "test content"
+            {"First test content", "Second test content"},
+            {"test1.txt", "test2.txt"},
+        ),
+        (
+            "othe?.txt",
+            None,
+            "Other content",
+            {"Other content"},
+            {"other.txt"},
+        ),
+        (
+            "te??*.txt",
+            "**",
+            "test content",  # We search for "test content"
+            {"First test content", "Second test content"},
+            {"test1.txt", "test2.txt"},
+        ),
+    ],
+)
+async def test_document_search_ingest_from_uri_with_wildcard(
+    pattern: str, dir_pattern: str | None, search_query: str, expected_contents: set, expected_filenames: set
+):
+    # Setup temporary directory
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Create multiple test files
+        test_files = [
+            (Path(temp_dir) / "test1.txt", "First test content"),
+            (Path(temp_dir) / "test2.txt", "Second test content"),
+            (Path(temp_dir) / "other.txt", "Other content"),
+        ]
+        for file_path, content in test_files:
+            file_path.write_text(content)
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Use the parametrized glob pattern
+        dir_pattern = f"{str(Path(temp_dir).parent)}/{dir_pattern}" if dir_pattern is not None else temp_dir
+        await document_search.ingest(f"file://{dir_pattern}/{pattern}")
+
+        # Perform the search
+        results = await document_search.search(search_query)
+
+        # Check that we have the expected number of results
+        assert len(results) == len(expected_contents), (
+            f"Expected {len(expected_contents)} result(s) but got {len(results)}"
+        )
+
+        # Verify each result is a TextElement
+        assert all(isinstance(result, TextElement) for result in results)
+
+        # Collect the actual text contents
+        contents = {cast(TextElement, result).content for result in results}
+        assert contents == expected_contents, f"Expected contents: {expected_contents}, got: {contents}"
+
+        # Verify the sources (file paths) match
+        sources = {str(cast(LocalFileSource, result.document_meta.source).path).split("/")[-1] for result in results}
+        # We compare only the filenames; if you need full paths, compare the full str(...) instead
+        assert sources == expected_filenames, f"Expected sources: {expected_filenames}, got: {sources}"
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_basic():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.download = mock.AsyncMock(return_value=b"GCS test content")
+    storage_mock.list_objects = mock.AsyncMock(
+        return_value={"items": [{"name": "folder/test1.txt"}, {"name": "folder/test2.txt"}]}
+    )
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test single file
+        await document_search.ingest("gcs://test-bucket/folder/test1.txt")
+        results = await document_search.search("GCS test content")
+        assert len(results) == 1
+        assert isinstance(results[0], TextElement)
+        assert results[0].content == "GCS test content"
+
+        # Clean up
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_with_wildcard():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.download = mock.AsyncMock(side_effect=[b"GCS test content 1", b"GCS test content 2"])
+    storage_mock.list_objects = mock.AsyncMock(
+        return_value={"items": [{"name": "folder/test1.txt"}, {"name": "folder/test2.txt"}]}
+    )
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test wildcard ingestion
+        await document_search.ingest("gcs://test-bucket/folder/*")
+
+        # Verify both files were ingested
+        results = await document_search.search("GCS test content")
+        assert len(results) == 2
+
+        # Verify first file
+        assert isinstance(results[0], TextElement)
+        assert results[0].content == "GCS test content 1"
+
+        # Verify second file
+        assert isinstance(results[1], TextElement)
+        assert results[1].content == "GCS test content 2"
+
+        # Clean up
+        storage_mock = mock.AsyncMock()
+        storage_mock.download = mock.AsyncMock(return_value=b"")
+        storage_mock.list_objects = mock.AsyncMock(return_value={"items": []})
+        storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+        storage_mock.__aexit__ = mock.AsyncMock()
+        GCSSource.set_storage(storage_mock)
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_gcs_uri_invalid_pattern():
+    # Create mock storage client
+    storage_mock = mock.AsyncMock()
+    storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+    storage_mock.__aexit__ = mock.AsyncMock()
+
+    # Create mock storage factory
+    mock_storage = mock.Mock()
+    mock_storage.return_value = storage_mock
+
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set up local storage dir
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+
+        # Inject the mock storage
+        GCSSource.set_storage(mock_storage())
+
+        document_search = DocumentSearch.from_config(CONFIG)
+
+        # Test invalid patterns
+        with pytest.raises(ValueError, match="GCSSource only supports '\\*' at the end of path"):
+            await document_search.ingest("gcs://test-bucket/folder/**.txt")
+
+        with pytest.raises(ValueError, match="GCSSource only supports '\\*' at the end of path"):
+            await document_search.ingest("gcs://test-bucket/folder/test?.txt")
+
+        with pytest.raises(ValueError, match="GCSSource only supports '\\*' at the end of path"):
+            await document_search.ingest("gcs://test-bucket/folder/test*file.txt")
+
+        # Test empty list response
+        storage_mock.list_objects = mock.AsyncMock(return_value={"items": []})
+        await document_search.ingest("gcs://test-bucket/folder/*")
+        results = await document_search.search("GCS test content")
+        assert len(results) == 0
+
+        # Clean up
+        storage_mock = mock.AsyncMock()
+        storage_mock.download = mock.AsyncMock(return_value=b"")
+        storage_mock.list_objects = mock.AsyncMock(return_value={"items": []})
+        storage_mock.__aenter__ = mock.AsyncMock(return_value=storage_mock)
+        storage_mock.__aexit__ = mock.AsyncMock()
+        GCSSource.set_storage(storage_mock)
+        del os.environ["LOCAL_STORAGE_DIR"]
+
+
+@pytest.mark.asyncio
+async def test_document_search_ingest_from_huggingface_uri_basic():
+    # Create mock data
+    mock_data = [
+        {
+            "content": "HuggingFace test content",
+            "source": "dataset_name/train/test.txt",  # Must be .txt for TextDocument
+        }
+    ]
+
+    # Create a simple dataset class that supports skip/take
+    class MockDataset:
+        def __init__(self, data: list):
+            self.data = data
+            self.current_index = 0
+
+        def skip(self, n: int) -> "MockDataset":
+            self.current_index = n
+            return self
+
+        def take(self, n: int) -> "MockDataset":
+            return self
+
+        def __iter__(self):
+            if self.current_index < len(self.data):
+                return iter(self.data[self.current_index : self.current_index + 1])
+            return iter([])
+
+    # Mock dataset loading and embeddings
+    dataset = MockDataset(mock_data)
+    embeddings_mock = AsyncMock()
+    embeddings_mock.embed_text.return_value = [[0.1, 0.1]]  # Non-zero embeddings
+
+    # Create providers dict with actual provider instance
+    providers: Mapping[DocumentType, BaseProvider] = {DocumentType.TXT: DummyProvider()}
+
+    # Mock vector store to track operations
+    vector_store = InMemoryVectorStore()
+
+    # Create a temporary directory for storing test files
+    with tempfile.TemporaryDirectory() as temp_dir:
+        # Set the environment variable for local storage
+        os.environ["LOCAL_STORAGE_DIR"] = temp_dir
+        storage_dir = Path(temp_dir)
+
+        # Create the source directory and file
+        source_dir = storage_dir / "dataset_name/train"
+        source_dir.mkdir(parents=True, exist_ok=True)
+        source_file = source_dir / "test.txt"
+        with open(source_file, mode="w", encoding="utf-8") as file:
+            file.write("HuggingFace test content")
+
+        with (
mock.patch("ragbits.document_search.documents.sources.load_dataset", return_value=dataset), + mock.patch("ragbits.document_search.documents.sources.get_local_storage_dir", return_value=storage_dir), + ): + document_search = DocumentSearch( + embedder=embeddings_mock, + vector_store=vector_store, + document_processor_router=DocumentProcessorRouter.from_config(providers), + ) + + await document_search.ingest("huggingface://dataset_name/train/0") + + results = await document_search.search("HuggingFace test content") + assert len(results) == 1 + assert isinstance(results[0], TextElement) + assert results[0].content == "HuggingFace test content" diff --git a/packages/ragbits-document-search/tests/unit/test_sources.py b/packages/ragbits-document-search/tests/unit/test_sources.py index 1c90df6e..ba6785dd 100644 --- a/packages/ragbits-document-search/tests/unit/test_sources.py +++ b/packages/ragbits-document-search/tests/unit/test_sources.py @@ -1,22 +1,91 @@ import os from pathlib import Path +from types import TracebackType +from typing import Any, TypeVar from unittest.mock import MagicMock, patch -from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, GCSSource, HuggingFaceSource +from aiohttp import ClientSession + +from ragbits.document_search.documents.sources import ( + LOCAL_STORAGE_DIR_ENV, + GCSSource, + HuggingFaceSource, +) os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix() +try: + from gcloud.aio.storage import Storage as StorageClient +except ImportError: + StorageClient = TypeVar("StorageClient") # type: ignore + + +class MockStorage(StorageClient): + """Mock GCS storage client.""" + + def __init__(self) -> None: + """Initialize mock storage.""" + self.objects: dict[str, bytes] = {} + self.downloaded_files: list[tuple[str, str]] = [] + + async def download( + self, + bucket: str, + object_name: str, + *, + headers: dict[str, Any] | None = None, + timeout: int = 60, + session: ClientSession | None = None, + ) -> bytes: + """Mock download method.""" + key = f"{bucket}/{object_name}" + self.downloaded_files.append((bucket, object_name)) + return self.objects.get(key, b"This is the content of the file.") + + async def list_objects( + self, + bucket: str, + *, + params: dict[str, str] | None = None, + headers: dict[str, Any] | None = None, + session: ClientSession | None = None, + timeout: int = 60, + ) -> dict[str, Any]: + """Mock list_objects method.""" + prefix = params.get("prefix", "") if params else "" + items = [] + for key in self.objects: + if key.startswith(f"{bucket}/{prefix}"): + items.append({"name": key.split("/", 1)[1]}) + return {"items": items} + + async def __aenter__(self) -> "MockStorage": + """Enter async context.""" + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + """Exit async context.""" + pass + + async def test_gcs_source_fetch() -> None: - data = b"This is the content of the file." 
diff --git a/packages/ragbits-document-search/tests/unit/test_sources.py b/packages/ragbits-document-search/tests/unit/test_sources.py
index 1c90df6e..ba6785dd 100644
--- a/packages/ragbits-document-search/tests/unit/test_sources.py
+++ b/packages/ragbits-document-search/tests/unit/test_sources.py
@@ -1,22 +1,91 @@
 import os
 from pathlib import Path
+from types import TracebackType
+from typing import Any, TypeVar
 from unittest.mock import MagicMock, patch
 
-from ragbits.document_search.documents.sources import LOCAL_STORAGE_DIR_ENV, GCSSource, HuggingFaceSource
+from aiohttp import ClientSession
+
+from ragbits.document_search.documents.sources import (
+    LOCAL_STORAGE_DIR_ENV,
+    GCSSource,
+    HuggingFaceSource,
+)
 
 os.environ[LOCAL_STORAGE_DIR_ENV] = Path(__file__).parent.as_posix()
 
+try:
+    from gcloud.aio.storage import Storage as StorageClient
+except ImportError:
+    StorageClient = TypeVar("StorageClient")  # type: ignore
+
+
+class MockStorage(StorageClient):
+    """Mock GCS storage client."""
+
+    def __init__(self) -> None:
+        """Initialize mock storage."""
+        self.objects: dict[str, bytes] = {}
+        self.downloaded_files: list[tuple[str, str]] = []
+
+    async def download(
+        self,
+        bucket: str,
+        object_name: str,
+        *,
+        headers: dict[str, Any] | None = None,
+        timeout: int = 60,
+        session: ClientSession | None = None,
+    ) -> bytes:
+        """Mock download method."""
+        key = f"{bucket}/{object_name}"
+        self.downloaded_files.append((bucket, object_name))
+        return self.objects.get(key, b"This is the content of the file.")
+
+    async def list_objects(
+        self,
+        bucket: str,
+        *,
+        params: dict[str, str] | None = None,
+        headers: dict[str, Any] | None = None,
+        session: ClientSession | None = None,
+        timeout: int = 60,
+    ) -> dict[str, Any]:
+        """Mock list_objects method."""
+        prefix = params.get("prefix", "") if params else ""
+        items = []
+        for key in self.objects:
+            if key.startswith(f"{bucket}/{prefix}"):
+                items.append({"name": key.split("/", 1)[1]})
+        return {"items": items}
+
+    async def __aenter__(self) -> "MockStorage":
+        """Enter async context."""
+        return self
+
+    async def __aexit__(
+        self,
+        exc_type: type[BaseException] | None,
+        exc_val: BaseException | None,
+        exc_tb: TracebackType | None,
+    ) -> None:
+        """Exit async context."""
+        pass
+
 
 async def test_gcs_source_fetch() -> None:
-    data = b"This is the content of the file."
-    source = GCSSource(bucket="", object_name="doc.md")
+    """Test fetching a file from GCS."""
+    mock_storage = MockStorage()
+    source = GCSSource(bucket="test-bucket", object_name="doc.md")
+    source.set_storage(mock_storage)
 
-    with patch("ragbits.document_search.documents.sources.Storage.download", return_value=data):
-        path = await source.fetch()
+    path = await source.fetch()
 
-    assert source.id == "gcs:gs:///doc.md"
+    assert source.id == "gcs:gs://test-bucket/doc.md"
     assert path.name == "doc.md"
     assert path.read_text() == "This is the content of the file."
+    assert mock_storage.downloaded_files == [("test-bucket", "doc.md")]
 
     path.unlink()
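
One editorial note on the tests: `GCSSource.set_storage` mutates class-level state, which the tests above restore by hand. A pytest fixture guarantees the reset even when a test fails — a sketch, not part of this diff (`MockStorage` is the helper defined in `test_sources.py` above):

```python
import pytest

from ragbits.document_search.documents.sources import GCSSource


@pytest.fixture
def mock_gcs_storage():
    """Inject a MockStorage for one test, then restore lazy client creation."""
    storage = MockStorage()
    GCSSource.set_storage(storage)
    yield storage
    # Passing None makes the next _get_storage() call create a real client again.
    GCSSource.set_storage(None)
```
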