Skip to content

Refactor VectorStoreFactory to use registration functionality like StorageFactory #2006

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 7 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 4 additions & 0 deletions .semversioner/next-release/patch-20250722124545405651.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
{
"type": "patch",
"description": "Refactor VectorStoreFactory to use registration functionality like StorageFactory"
}
2 changes: 1 addition & 1 deletion graphrag/config/defaults.py
Original file line number Diff line number Diff line change
Expand Up @@ -239,7 +239,7 @@ class GlobalSearchDefaults:
class StorageDefaults:
"""Default values for storage."""

type = StorageType.file
type: StorageType = StorageType.file
base_dir: str = DEFAULT_OUTPUT_BASE_DIR
connection_string: None = None
container_name: None = None
Expand Down
18 changes: 7 additions & 11 deletions graphrag/storage/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,6 @@

from __future__ import annotations

from contextlib import suppress
from typing import TYPE_CHECKING, ClassVar

from graphrag.config.enums import StorageType
Expand All @@ -30,7 +29,6 @@ class StorageFactory:
"""

_storage_registry: ClassVar[dict[str, Callable[..., PipelineStorage]]] = {}
storage_types: ClassVar[dict[str, type]] = {} # For backward compatibility

@classmethod
def register(
Expand All @@ -41,18 +39,16 @@ def register(
Args:
storage_type: The type identifier for the storage.
creator: A callable that creates an instance of the storage.

Raises
------
TypeError: If creator is a class type instead of a factory function.
"""
if isinstance(creator, type):
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
raise TypeError(msg)
cls._storage_registry[storage_type] = creator

# For backward compatibility with code that may access storage_types directly
if (
callable(creator)
and hasattr(creator, "__annotations__")
and "return" in creator.__annotations__
):
with suppress(TypeError, KeyError):
cls.storage_types[storage_type] = creator.__annotations__["return"]

@classmethod
def create_storage(
cls, storage_type: StorageType | str, kwargs: dict
Expand Down
119 changes: 97 additions & 22 deletions graphrag/vector_stores/factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,13 +3,15 @@

"""A package containing a factory and supported vector store types."""

from __future__ import annotations

from enum import Enum
from typing import ClassVar
from typing import TYPE_CHECKING, ClassVar

if TYPE_CHECKING:
from collections.abc import Callable

from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore
from graphrag.vector_stores.base import BaseVectorStore
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore
from graphrag.vector_stores.lancedb import LanceDBVectorStore
from graphrag.vector_stores.base import BaseVectorStore


class VectorStoreType(str, Enum):
Expand All @@ -24,29 +26,102 @@ class VectorStoreFactory:
"""A factory for vector stores.

Includes a method for users to register a custom vector store implementation.

Configuration arguments are passed to each vector store implementation as kwargs
for individual enforcement of required/optional arguments.
"""

vector_store_types: ClassVar[dict[str, type]] = {}
_vector_store_registry: ClassVar[dict[str, Callable[..., BaseVectorStore]]] = {}

@classmethod
def register(cls, vector_store_type: str, vector_store: type):
"""Register a custom vector store implementation."""
cls.vector_store_types[vector_store_type] = vector_store
def register(
cls, vector_store_type: str, creator: Callable[..., BaseVectorStore]
) -> None:
"""Register a custom vector store implementation.

Args:
vector_store_type: The type identifier for the vector store.
creator: A callable that creates an instance of the vector store.

Raises
------
TypeError: If creator is a class type instead of a factory function.
"""
if isinstance(creator, type):
msg = "Registering classes directly is no longer supported. Please provide a factory function instead."
raise TypeError(msg)
cls._vector_store_registry[vector_store_type] = creator

@classmethod
def create_vector_store(
cls, vector_store_type: VectorStoreType | str, kwargs: dict
) -> BaseVectorStore:
"""Create or get a vector store from the provided type."""
match vector_store_type:
case VectorStoreType.LanceDB:
return LanceDBVectorStore(**kwargs)
case VectorStoreType.AzureAISearch:
return AzureAISearchVectorStore(**kwargs)
case VectorStoreType.CosmosDB:
return CosmosDBVectorStore(**kwargs)
case _:
if vector_store_type in cls.vector_store_types:
return cls.vector_store_types[vector_store_type](**kwargs)
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)
"""Create a vector store object from the provided type.

Args:
vector_store_type: The type of vector store to create.
kwargs: Additional keyword arguments for the vector store constructor.

Returns
-------
A BaseVectorStore instance.

Raises
------
ValueError: If the vector store type is not registered.
"""
vector_store_type_str = (
vector_store_type.value
if isinstance(vector_store_type, VectorStoreType)
else vector_store_type
)

if vector_store_type_str not in cls._vector_store_registry:
msg = f"Unknown vector store type: {vector_store_type}"
raise ValueError(msg)

return cls._vector_store_registry[vector_store_type_str](**kwargs)

@classmethod
def get_vector_store_types(cls) -> list[str]:
"""Get the registered vector store implementations."""
return list(cls._vector_store_registry.keys())

@classmethod
def is_supported_vector_store_type(cls, vector_store_type: str) -> bool:
"""Check if the given vector store type is supported."""
return vector_store_type in cls._vector_store_registry


# --- Factory functions for built-in vector stores ---


def create_lancedb_vector_store(**kwargs) -> BaseVectorStore:
"""Create a LanceDB vector store."""
from graphrag.vector_stores.lancedb import LanceDBVectorStore

return LanceDBVectorStore(**kwargs)


def create_azure_ai_search_vector_store(**kwargs) -> BaseVectorStore:
"""Create an Azure AI Search vector store."""
from graphrag.vector_stores.azure_ai_search import AzureAISearchVectorStore

return AzureAISearchVectorStore(**kwargs)


def create_cosmosdb_vector_store(**kwargs) -> BaseVectorStore:
"""Create a CosmosDB vector store."""
from graphrag.vector_stores.cosmosdb import CosmosDBVectorStore

return CosmosDBVectorStore(**kwargs)


# --- Register default implementations ---
VectorStoreFactory.register(VectorStoreType.LanceDB.value, create_lancedb_vector_store)
VectorStoreFactory.register(
VectorStoreType.AzureAISearch.value, create_azure_ai_search_vector_store
)
VectorStoreFactory.register(
VectorStoreType.CosmosDB.value, create_cosmosdb_vector_store
)
61 changes: 54 additions & 7 deletions tests/integration/storage/test_factory.py
Original file line number Diff line number Diff line change
Expand Up @@ -96,13 +96,60 @@ def test_get_storage_types():
assert StorageType.cosmosdb.value in storage_types


def test_backward_compatibility():
"""Test that the storage_types attribute is still accessible for backward compatibility."""
assert hasattr(StorageFactory, "storage_types")
# The storage_types attribute should be a dict
assert isinstance(StorageFactory.storage_types, dict)


def test_create_unknown_storage():
with pytest.raises(ValueError, match="Unknown storage type: unknown"):
StorageFactory.create_storage("unknown", {})


def test_register_class_directly_raises_error():
"""Test that registering a class directly raises a TypeError."""
import re
from collections.abc import Iterator
from typing import Any

from graphrag.storage.pipeline_storage import PipelineStorage

class CustomStorage(PipelineStorage):
def __init__(self, **kwargs):
pass

def find(
self,
file_pattern: re.Pattern[str],
base_dir: str | None = None,
file_filter: dict[str, Any] | None = None,
max_count=-1,
) -> Iterator[tuple[str, dict[str, Any]]]:
return iter([])

async def get(
self, key: str, as_bytes: bool | None = None, encoding: str | None = None
) -> Any:
return None

async def set(self, key: str, value: Any, encoding: str | None = None) -> None:
pass

async def delete(self, key: str) -> None:
pass

async def has(self, key: str) -> bool:
return False

async def clear(self) -> None:
pass

def child(self, name: str | None) -> "PipelineStorage":
return self

def keys(self) -> list[str]:
return []

async def get_creation_date(self, key: str) -> str:
return "2024-01-01 00:00:00 +0000"

# Attempting to register a class directly should raise TypeError
with pytest.raises(
TypeError, match="Registering classes directly is no longer supported"
):
StorageFactory.register("custom_class", CustomStorage)
Loading
Loading