Skip to content

Commit

Permalink
Improve SQLAlchemy FloatVector type implementation
Browse files Browse the repository at this point in the history
  • Loading branch information
amotl committed Dec 21, 2023
1 parent 4c8c2c4 commit 5c30efc
Show file tree
Hide file tree
Showing 2 changed files with 13 additions and 30 deletions.
2 changes: 1 addition & 1 deletion target_cratedb/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -137,7 +137,7 @@ def pick_individual_type(jsonschema_type: dict):
if "type" in storage_properties and storage_properties["type"] == "vector":

Check warning on line 137 in target_cratedb/connector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/connector.py#L136-L137

Added lines #L136 - L137 were not covered by tests
# On PostgreSQL/pgvector, use the corresponding type definition
# from its SQLAlchemy dialect.
return FloatVector(storage_properties["dim"])
return FloatVector(dimensions=storage_properties["dim"])

Check warning on line 140 in target_cratedb/connector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/connector.py#L140

Added line #L140 was not covered by tests

# Discover/translate inner types.
inner_type = resolve_array_inner_type(jsonschema_type)
Expand Down
41 changes: 12 additions & 29 deletions target_cratedb/sqlalchemy/vector.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,8 +6,8 @@
import sqlalchemy as sa
from crate.client.sqlalchemy.compiler import CrateTypeCompiler
from crate.client.sqlalchemy.dialect import TYPES_MAP
from sqlalchemy import TypeDecorator
from sqlalchemy.sql import sqltypes
from sqlalchemy.sql.type_api import TypeEngine

__all__ = ["FloatVector"]

Expand Down Expand Up @@ -41,7 +41,8 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]:
return value

Check warning on line 41 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L41

Added line #L41 was not covered by tests


class FloatVector(TypeEngine[t.Sequence[t.Any]]):
class FloatVector(TypeDecorator[t.Sequence[float]]):

"""
An improved implementation of the `FloatVector` data type for CrateDB,
compared to the previous implementation written for the LangChain adapter.
Expand Down Expand Up @@ -96,32 +97,25 @@ class FloatVector(TypeEngine[t.Sequence[t.Any]]):
implementations correspondingly.
"""

cache_ok = True
cache_ok = False

__visit_name__ = "FLOAT_VECTOR"

_is_array = True

zero_indexes = False

def __init__(self, dim: t.Optional[int] = None, as_tuple: bool = False) -> None:
self.dim = dim
self.as_tuple = as_tuple

@property
def hashable(self):
return self.as_tuple
impl = sa.ARRAY

@property
def python_type(self):
return list
def __init__(self, dimensions: int = None):
super().__init__(sa.FLOAT, dimensions=dimensions)

Check warning on line 111 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L111

Added line #L111 was not covered by tests

def as_generic(self):
return sqltypes.ARRAY

Check warning on line 114 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L114

Added line #L114 was not covered by tests

def bind_processor(self, dialect: sa.Dialect) -> t.Callable:
def process(value: t.Iterable) -> t.Optional[t.List]:
return to_db(value, self.dim)
return to_db(value, self.dimensions)

Check warning on line 118 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L117-L118

Added lines #L117 - L118 were not covered by tests

return process

Check warning on line 120 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L120

Added line #L120 was not covered by tests

Expand All @@ -131,27 +125,16 @@ def process(value: t.Any) -> t.Optional[npt.ArrayLike]:

return process

Check warning on line 126 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L126

Added line #L126 was not covered by tests

"""
CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`.
-- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55
On the other hand, pgvector uses a comparator to apply different similarity functions as operators,
see `pgvector.sqlalchemy.Vector.comparator_factory`.
<->: l2/euclidean_distance
<#>: max_inner_product
<=>: cosine_distance
TODO: Discuss.
""" # noqa: E501


# Accompanies the type definition for reverse type lookups.
TYPES_MAP["float_vector"] = FloatVector


def visit_FLOAT_VECTOR(self, type_, **kw):
return f"FLOAT_VECTOR({type_.dim})"
dimensions = type_.dimensions
if dimensions is None:
raise ValueError("FloatVector must be initialized with dimension size")
return f"FLOAT_VECTOR({dimensions})"

Check warning on line 137 in target_cratedb/sqlalchemy/vector.py

View check run for this annotation

Codecov / codecov/patch

target_cratedb/sqlalchemy/vector.py#L134-L137

Added lines #L134 - L137 were not covered by tests


CrateTypeCompiler.visit_FLOAT_VECTOR = visit_FLOAT_VECTOR

0 comments on commit 5c30efc

Please sign in to comment.