Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
maciejklimek committed Dec 30, 2024
1 parent 13da8bf commit d3b1e38
Show file tree
Hide file tree
Showing 5 changed files with 195 additions and 99 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,7 @@ async def ingest(

sources = await SourceResolver.resolve(documents)
else:
sources = documents
sources = documents # type: ignore[assignment]

elements = await self.processing_strategy.process_documents(
sources, self.document_processor_router, document_processor
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,8 @@
from collections.abc import Sequence
from contextlib import suppress
from pathlib import Path
from typing import Any, ClassVar
from types import TracebackType
from typing import Any, ClassVar, Protocol, runtime_checkable

from pydantic import BaseModel, GetCoreSchemaHandler, computed_field
from pydantic.alias_generators import to_snake
Expand All @@ -17,12 +18,56 @@
from datasets import load_dataset
from datasets.exceptions import DatasetNotFoundError

from aiohttp import ClientSession

from ragbits.core.utils.decorators import requires_dependencies
from ragbits.document_search.documents.exceptions import SourceConnectionError, SourceNotFoundError

LOCAL_STORAGE_DIR_ENV = "LOCAL_STORAGE_DIR"


@runtime_checkable
class StorageProtocol(Protocol):
    """Structural interface for async object-storage clients.

    Declared as a runtime-checkable ``Protocol`` so that any object with the
    matching methods (e.g. a real GCS client or a test double injected via
    ``GCSSource.set_storage``) satisfies it without inheriting from a
    concrete class.
    """

    async def download(
        self,
        bucket: str,
        object_name: str,
        *,
        headers: dict[str, Any] | None = None,
        timeout: int = 60,
        session: ClientSession | None = None,
    ) -> bytes:
        """Download a single object from *bucket* and return its raw bytes."""
        ...

    async def list_objects(
        self,
        bucket: str,
        *,
        params: dict[str, str] | None = None,
        headers: dict[str, Any] | None = None,
        session: ClientSession | None = None,
        timeout: int = 60,
    ) -> dict[str, Any]:
        """List objects in *bucket*; callers read the ``"items"`` key of the returned payload."""
        ...

    async def __aenter__(self) -> "StorageProtocol":
        """Enter async context."""
        ...

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        """Exit async context."""
        ...


class Source(BaseModel, ABC):
"""
An object representing a source.
Expand Down Expand Up @@ -88,7 +133,7 @@ async def from_uri(cls, path: str) -> Sequence["Source"]:
"""

@classmethod
def __init_subclass__(cls, **kwargs) -> None:
def __init_subclass__(cls, **kwargs: Any) -> None: # noqa: ANN401
super().__init_subclass__(**kwargs)
Source._registry[cls.class_identifier()] = cls
if cls.protocol is not None:
Expand Down Expand Up @@ -190,18 +235,18 @@ async def from_uri(cls, path: str) -> Sequence["LocalFileSource"]:
A sequence of LocalFileSource objects
"""
# Handle absolute paths
path = Path(path)
if not path.is_absolute():
path_obj: Path = Path(path)
if not path_obj.is_absolute():
# For relative paths, use current directory as base
path = Path.cwd() / path
path_obj = Path.cwd() / path_obj

if "*" in str(path):
if "*" in str(path_obj):
# If path contains wildcards, use its parent as base
base_path = path.parent
pattern = path.name
base_path = path_obj.parent
pattern = path_obj.name
return [cls(path=file_path) for file_path in base_path.glob(pattern)]

return [cls(path=path)]
return [cls(path=path_obj)]


class GCSSource(Source):
Expand All @@ -210,28 +255,28 @@ class GCSSource(Source):
bucket: str
object_name: str
protocol: ClassVar[str] = "gcs"
_storage: "StorageClient | None" = None # Storage client for dependency injection
_storage: "StorageProtocol | None" = None # Storage client for dependency injection

@classmethod
def set_storage(cls, storage: "StorageClient") -> None:
def set_storage(cls, storage: "StorageProtocol | None") -> None:
"""Set the storage client for all instances.
Args:
storage: The storage client to use (in tests, this can be a mock)
"""
cls._storage = storage

async def _get_storage(self) -> "StorageClient":
@classmethod
@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def _get_storage(cls) -> "StorageProtocol":
"""Get the storage client.
Returns:
The storage client to use. If none was injected, creates a new one.
"""
if self._storage is not None:
return self._storage

from gcloud.aio.storage import Storage
return Storage()
if cls._storage is None:
cls._storage = StorageClient()
return cls._storage

@property
def id(self) -> str:
Expand Down Expand Up @@ -277,8 +322,7 @@ async def fetch(self) -> Path:
@classmethod
@requires_dependencies(["gcloud.aio.storage"], "gcs")
async def list_sources(cls, bucket: str, prefix: str = "") -> list["GCSSource"]:
    """List all sources in the given GCS bucket, matching the prefix.

    Args:
        bucket: The GCS bucket.
        prefix: The prefix that returned object names must start with.

    Returns:
        One ``GCSSource`` per object whose name matches the prefix.

    Raises:
        ImportError: If the required 'gcloud-aio-storage' package is not installed
    """
    storage = await cls._get_storage()
    async with storage as client:
        listing = await client.list_objects(bucket, params={"prefix": prefix})
    sources: list[GCSSource] = []
    for entry in listing.get("items", []):
        sources.append(cls(bucket=bucket, object_name=entry["name"]))
    return sources

@classmethod
async def from_uri(cls, path: str) -> Sequence["GCSSource"]:
Expand Down
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import copy
from collections.abc import Callable
from collections.abc import Callable, Mapping, MutableMapping
from typing import cast

from ragbits.core.utils.config_handling import ObjectContructionConfig
Expand All @@ -10,10 +10,10 @@
from ragbits.document_search.ingestion.providers.unstructured.pdf import UnstructuredPdfProvider

# TODO consider defining with some defined schema
ProvidersConfig = dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]
ProvidersConfig = Mapping[DocumentType, Callable[[], BaseProvider] | BaseProvider]


DEFAULT_PROVIDERS_CONFIG: ProvidersConfig = {
DEFAULT_PROVIDERS_CONFIG: MutableMapping[DocumentType, Callable[[], BaseProvider] | BaseProvider] = {
DocumentType.TXT: UnstructuredDefaultProvider,
DocumentType.MD: UnstructuredDefaultProvider,
DocumentType.PDF: UnstructuredPdfProvider,
Expand Down Expand Up @@ -43,7 +43,7 @@ class DocumentProcessorRouter:
metadata such as the document type.
"""

def __init__(self, providers: dict[DocumentType, Callable[[], BaseProvider] | BaseProvider]):
def __init__(self, providers: ProvidersConfig):
self._providers = providers

@staticmethod
Expand Down Expand Up @@ -71,7 +71,7 @@ def from_dict_to_providers_config(dict_config: dict[str, ObjectContructionConfig
return providers_config

@classmethod
def from_config(cls, providers_config: ProvidersConfig | None = None) -> "DocumentProcessorRouter":
def from_config(cls, providers: ProvidersConfig | None = None) -> "DocumentProcessorRouter":
"""
Create a DocumentProcessorRouter from a configuration. If the configuration is not provided, the default
configuration will be used. If the configuration is provided, it will be merged with the default configuration,
Expand All @@ -83,14 +83,16 @@ def from_config(cls, providers_config: ProvidersConfig | None = None) -> "Docume
}
Args:
providers_config: The dictionary with the providers configuration, mapping the document types to the
providers: The dictionary with the providers configuration, mapping the document types to the
provider class.
Returns:
The DocumentProcessorRouter.
"""
config = copy.deepcopy(DEFAULT_PROVIDERS_CONFIG)
config.update(providers_config if providers_config is not None else {})
config: MutableMapping[DocumentType, Callable[[], BaseProvider] | BaseProvider] = copy.deepcopy(
DEFAULT_PROVIDERS_CONFIG
)
config.update(providers if providers is not None else {})

return cls(providers=config)

Expand Down
Loading

0 comments on commit d3b1e38

Please sign in to comment.