Skip to content

feat: support for Halfvec and Sparsevec vector types #226

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 4 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 2 additions & 2 deletions langchain_postgres/v2/async_vectorstore.py
Original file line number Diff line number Diff line change
Expand Up @@ -935,15 +935,15 @@ async def aapply_vector_index(
text(f"CREATE EXTENSION IF NOT EXISTS {index.extension_name}")
)
await conn.commit()
function = index.get_index_function()

operator_class = index.operator_class()
filter = f"WHERE ({index.partial_indexes})" if index.partial_indexes else ""
params = "WITH " + index.index_options()
if name is None:
if index.name is None:
index.name = self.table_name + DEFAULT_INDEX_NAME_SUFFIX
name = index.name
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {function}) {params} {filter};'
stmt = f'CREATE INDEX {"CONCURRENTLY" if concurrently else ""} "{name}" ON "{self.schema_name}"."{self.table_name}" USING {index.index_type} ({self.embedding_column} {operator_class}) {params} {filter};'

if concurrently:
async with self.engine.connect() as conn:
Expand Down
16 changes: 15 additions & 1 deletion langchain_postgres/v2/engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
from sqlalchemy.ext.asyncio import AsyncEngine, create_async_engine

from .hybrid_search_config import HybridSearchConfig
from .indexes import DEFAULT_VECTOR_TYPE, VectorType

T = TypeVar("T")

Expand Down Expand Up @@ -150,6 +151,7 @@ async def _ainit_vectorstore_table(
table_name: str,
vector_size: int,
*,
vector_type: VectorType = DEFAULT_VECTOR_TYPE,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
Expand All @@ -166,6 +168,8 @@ async def _ainit_vectorstore_table(
Args:
table_name (str): The database table name.
vector_size (int): Vector size for the embedding model to be used.
vector_type (VectorType): Type of the vector column to store embeddings.
Default: VectorType.VECTOR.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Expand Down Expand Up @@ -194,6 +198,8 @@ async def _ainit_vectorstore_table(
hybrid_search_default_column_name = content_column + "_tsv"
content_column = self._escape_postgres_identifier(content_column)
embedding_column = self._escape_postgres_identifier(embedding_column)
embedding_column_type = f"{vector_type.value}({vector_size})"

if metadata_columns is None:
metadata_columns = []
else:
Expand Down Expand Up @@ -246,7 +252,7 @@ async def _ainit_vectorstore_table(
query = f"""CREATE TABLE "{schema_name}"."{table_name}"(
"{id_column_name}" {id_data_type} PRIMARY KEY,
"{content_column}" TEXT NOT NULL,
"{embedding_column}" vector({vector_size}) NOT NULL
"{embedding_column}" {embedding_column_type} NOT NULL
{hybrid_search_column}"""
for column in metadata_columns:
if isinstance(column, Column):
Expand All @@ -268,6 +274,7 @@ async def ainit_vectorstore_table(
table_name: str,
vector_size: int,
*,
vector_type: VectorType = DEFAULT_VECTOR_TYPE,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
Expand All @@ -284,6 +291,8 @@ async def ainit_vectorstore_table(
Args:
table_name (str): The database table name.
vector_size (int): Vector size for the embedding model to be used.
vector_type (VectorType): Type of the vector column to store embeddings.
Default: VectorType.VECTOR.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Expand All @@ -308,6 +317,7 @@ async def ainit_vectorstore_table(
self._ainit_vectorstore_table(
table_name,
vector_size,
vector_type=vector_type,
schema_name=schema_name,
content_column=content_column,
embedding_column=embedding_column,
Expand All @@ -325,6 +335,7 @@ def init_vectorstore_table(
table_name: str,
vector_size: int,
*,
vector_type: VectorType = DEFAULT_VECTOR_TYPE,
schema_name: str = "public",
content_column: str = "content",
embedding_column: str = "embedding",
Expand All @@ -341,6 +352,8 @@ def init_vectorstore_table(
Args:
table_name (str): The database table name.
vector_size (int): Vector size for the embedding model to be used.
vector_type (VectorType): Type of the vector column to store embeddings.
Default: VectorType.VECTOR.
schema_name (str): The schema name.
Default: "public".
content_column (str): Name of the column to store document content.
Expand All @@ -365,6 +378,7 @@ def init_vectorstore_table(
self._ainit_vectorstore_table(
table_name,
vector_size,
vector_type=vector_type,
schema_name=schema_name,
content_column=content_column,
embedding_column=embedding_column,
Expand Down
40 changes: 34 additions & 6 deletions langchain_postgres/v2/indexes.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,21 +15,30 @@
class StrategyMixin:
operator: str
search_function: str
index_function: str
operator_class_suffix: str


class DistanceStrategy(StrategyMixin, enum.Enum):
"""Enumerator of the Distance strategies."""

EUCLIDEAN = "<->", "l2_distance", "vector_l2_ops"
COSINE_DISTANCE = "<=>", "cosine_distance", "vector_cosine_ops"
INNER_PRODUCT = "<#>", "inner_product", "vector_ip_ops"
EUCLIDEAN = "<->", "l2_distance", "l2_ops"
COSINE_DISTANCE = "<=>", "cosine_distance", "cosine_ops"
INNER_PRODUCT = "<#>", "inner_product", "ip_ops"


DEFAULT_DISTANCE_STRATEGY: DistanceStrategy = DistanceStrategy.COSINE_DISTANCE
DEFAULT_INDEX_NAME_SUFFIX: str = "langchainvectorindex"


class VectorType(enum.Enum):
VECTOR = "vector"
HALFVEC = "halfvec"
SPARSEVEC = "sparsevec"


DEFAULT_VECTOR_TYPE: VectorType = VectorType.VECTOR


def validate_identifier(identifier: str) -> None:
if re.match(r"^[a-zA-Z_][a-zA-Z0-9_]*$", identifier) is None:
raise ValueError(
Expand All @@ -47,6 +56,8 @@ class BaseIndex(ABC):
index_type (str): A string identifying the type of index. Defaults to "base".
distance_strategy (DistanceStrategy): The strategy used to calculate distances
between vectors in the index. Defaults to DistanceStrategy.COSINE_DISTANCE.
vector_type (VectorType): The type of vector column,
on which the index will be created. Defaults to VectorType.VECTOR
partial_indexes (Optional[list[str]]): A list of names of partial indexes. Defaults to None.
extension_name (Optional[str]): The name of the extension to be created for the index, if any. Defaults to None.
"""
Expand All @@ -56,6 +67,7 @@ class BaseIndex(ABC):
distance_strategy: DistanceStrategy = field(
default_factory=lambda: DistanceStrategy.COSINE_DISTANCE
)
vector_type: VectorType = DEFAULT_VECTOR_TYPE
partial_indexes: Optional[list[str]] = None
extension_name: Optional[str] = None

Expand All @@ -66,8 +78,11 @@ def index_options(self) -> str:
"index_options method must be implemented by subclass"
)

def get_index_function(self) -> str:
return self.distance_strategy.index_function
def operator_class(self) -> str:
"""Returns index operator class, based on vector type and distance strategy."""
return (
f"{self.vector_type.value}_{self.distance_strategy.operator_class_suffix}"
)

def __post_init__(self) -> None:
"""Check if initialization parameters are valid.
Expand Down Expand Up @@ -133,6 +148,19 @@ class IVFFlatIndex(BaseIndex):
index_type: str = "ivfflat"
lists: int = 100

def __post_init__(self) -> None:
"""Check if vector_type is valid.

Raises:
ValueError: if vector_type is SPARSEVEC
"""
super().__post_init__()

if self.vector_type is VectorType.SPARSEVEC:
raise ValueError(
"IVFFlatIndex does not support sparsevec, use VECTOR or HALFVEC instead"
)

def index_options(self) -> str:
"""Set index query options for vector store initialization."""
return f"(lists = {self.lists})"
Expand Down
31 changes: 29 additions & 2 deletions tests/unit_tests/v2/test_async_pg_vectorstore_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,15 +11,22 @@
from langchain_postgres import PGEngine
from langchain_postgres.v2.async_vectorstore import AsyncPGVectorStore
from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig
from langchain_postgres.v2.indexes import DistanceStrategy, HNSWIndex, IVFFlatIndex
from langchain_postgres.v2.indexes import (
DistanceStrategy,
HNSWIndex,
IVFFlatIndex,
VectorType,
)
from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING

uuid_str = str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE = "default" + uuid_str
DEFAULT_HYBRID_TABLE = "hybrid" + uuid_str
SIMPLE_TABLE = "simple" + uuid_str
HALFVEC_TABLE = "halfvec" + uuid_str

DEFAULT_INDEX_NAME = "index" + uuid_str
VECTOR_SIZE = 768
SIMPLE_TABLE = "default_table"

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)

Expand Down Expand Up @@ -57,6 +64,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
await engine._adrop_table(DEFAULT_TABLE)
await engine._adrop_table(DEFAULT_HYBRID_TABLE)
await engine._adrop_table(SIMPLE_TABLE)
await engine._adrop_table(HALFVEC_TABLE)
await engine.close()

@pytest_asyncio.fixture(scope="class")
Expand Down Expand Up @@ -94,6 +102,25 @@ async def test_aapply_vector_index(self, vs: AsyncPGVectorStore) -> None:
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
await vs.adrop_vector_index(DEFAULT_INDEX_NAME)

async def test_aapply_vector_index_halfvec(self, engine: PGEngine) -> None:
await engine._ainit_vectorstore_table(
HALFVEC_TABLE,
VECTOR_SIZE,
vector_type=VectorType.HALFVEC,
overwrite_existing=True,
)
vs = await AsyncPGVectorStore.create(
engine,
embedding_service=embeddings_service,
table_name=HALFVEC_TABLE,
)
await vs.aadd_texts(texts, ids=ids)
await vs.adrop_vector_index()
index = HNSWIndex(name=DEFAULT_INDEX_NAME, vector_type=VectorType.HALFVEC)
await vs.aapply_vector_index(index)
assert await vs.is_valid_index(DEFAULT_INDEX_NAME)
await vs.adrop_vector_index(DEFAULT_INDEX_NAME)

async def test_aapply_vector_index_non_hybrid_search_vs(
self, vs: AsyncPGVectorStore
) -> None:
Expand Down
39 changes: 39 additions & 0 deletions tests/unit_tests/v2/test_engine.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,18 +12,21 @@

from langchain_postgres import Column, PGEngine
from langchain_postgres.v2.hybrid_search_config import HybridSearchConfig
from langchain_postgres.v2.indexes import VectorType
from tests.utils import VECTORSTORE_CONNECTION_STRING as CONNECTION_STRING

DEFAULT_TABLE = "default" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE = "custom" + str(uuid.uuid4()).replace("-", "_")
HYBRID_SEARCH_TABLE = "hybrid" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TYPEDDICT_TABLE = "custom_td" + str(uuid.uuid4()).replace("-", "_")
INT_ID_CUSTOM_TABLE = "custom_int_id" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_VECTOR_TYPE_TABLE = "custom_vt" + str(uuid.uuid4()).replace("-", "_")
DEFAULT_TABLE_SYNC = "default_sync" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TABLE_SYNC = "custom_sync" + str(uuid.uuid4()).replace("-", "_")
HYBRID_SEARCH_TABLE_SYNC = "hybrid_sync" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_TYPEDDICT_TABLE_SYNC = "custom_td_sync" + str(uuid.uuid4()).replace("-", "_")
INT_ID_CUSTOM_TABLE_SYNC = "custom_int_id_sync" + str(uuid.uuid4()).replace("-", "_")
CUSTOM_VECTOR_TYPE_TABLE_SYNC = "custom_vt_sync" + str(uuid.uuid4()).replace("-", "_")
VECTOR_SIZE = 768

embeddings_service = DeterministicFakeEmbedding(size=VECTOR_SIZE)
Expand Down Expand Up @@ -76,6 +79,7 @@ async def engine(self) -> AsyncIterator[PGEngine]:
await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_VECTOR_TYPE_TABLE}"')
await engine.close()

async def test_init_table(self, engine: PGEngine) -> None:
Expand Down Expand Up @@ -219,6 +223,22 @@ async def test_init_table_with_int_id(self, engine: PGEngine) -> None:
for row in results:
assert row in expected

async def test_init_table_custom_vector_type(self, engine: PGEngine) -> None:
await engine.ainit_vectorstore_table(
CUSTOM_VECTOR_TYPE_TABLE,
VECTOR_SIZE,
vector_type=VectorType.HALFVEC,
embedding_column="my_embedding",
)
stmt = (
"SELECT column_name, udt_name "
f"FROM information_schema.columns "
f"WHERE table_name = '{CUSTOM_VECTOR_TYPE_TABLE}' AND column_name = 'my_embedding';"
)

results = await afetch(engine, stmt)
assert results == [{"column_name": "my_embedding", "udt_name": "halfvec"}]

async def test_from_engine(self) -> None:
engine = create_async_engine(
CONNECTION_STRING,
Expand Down Expand Up @@ -264,6 +284,9 @@ async def engine(self) -> AsyncIterator[PGEngine]:
await aexecute(engine, f'DROP TABLE IF EXISTS "{DEFAULT_TABLE_SYNC}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{INT_ID_CUSTOM_TABLE_SYNC}"')
await aexecute(engine, f'DROP TABLE IF EXISTS "{CUSTOM_TYPEDDICT_TABLE_SYNC}"')
await aexecute(
engine, f'DROP TABLE IF EXISTS "{CUSTOM_VECTOR_TYPE_TABLE_SYNC}"'
)
await engine.close()

async def test_init_table(self, engine: PGEngine) -> None:
Expand Down Expand Up @@ -403,6 +426,22 @@ async def test_init_table_with_int_id(self, engine: PGEngine) -> None:
for row in results:
assert row in expected

async def test_init_table_custom_vector_type(self, engine: PGEngine) -> None:
engine.init_vectorstore_table(
CUSTOM_VECTOR_TYPE_TABLE_SYNC,
VECTOR_SIZE,
vector_type=VectorType.HALFVEC,
embedding_column="my_embedding",
)
stmt = (
"SELECT column_name, udt_name "
f"FROM information_schema.columns "
f"WHERE table_name = '{CUSTOM_VECTOR_TYPE_TABLE_SYNC}' AND column_name = 'my_embedding';"
)

results = await afetch(engine, stmt)
assert results == [{"column_name": "my_embedding", "udt_name": "halfvec"}]

async def test_engine_constructor_key(
self,
engine: PGEngine,
Expand Down
Loading