diff --git a/target_cratedb/connector.py b/target_cratedb/connector.py index a9d60f3..ec62cab 100644 --- a/target_cratedb/connector.py +++ b/target_cratedb/connector.py @@ -137,7 +137,7 @@ def pick_individual_type(jsonschema_type: dict): if "type" in storage_properties and storage_properties["type"] == "vector": # On PostgreSQL/pgvector, use the corresponding type definition # from its SQLAlchemy dialect. - return FloatVector(storage_properties["dim"]) + return FloatVector(dimensions=storage_properties["dim"]) # Discover/translate inner types. inner_type = resolve_array_inner_type(jsonschema_type) diff --git a/target_cratedb/sqlalchemy/vector.py b/target_cratedb/sqlalchemy/vector.py index e47f8df..1fc1287 100644 --- a/target_cratedb/sqlalchemy/vector.py +++ b/target_cratedb/sqlalchemy/vector.py @@ -6,8 +6,8 @@ import sqlalchemy as sa from crate.client.sqlalchemy.compiler import CrateTypeCompiler from crate.client.sqlalchemy.dialect import TYPES_MAP +from sqlalchemy import TypeDecorator from sqlalchemy.sql import sqltypes -from sqlalchemy.sql.type_api import TypeEngine __all__ = ["FloatVector"] @@ -41,7 +41,8 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: return value -class FloatVector(TypeEngine[t.Sequence[t.Any]]): +class FloatVector(TypeDecorator[t.Sequence[float]]): + """ An improved implementation of the `FloatVector` data type for CrateDB, compared to the previous implementation on behalf of the LangChain adapter. @@ -96,7 +97,7 @@ class FloatVector(TypeEngine[t.Sequence[t.Any]]): implementations correspondingly. """ - cache_ok = True + cache_ok = False __visit_name__ = "FLOAT_VECTOR" @@ -104,24 +105,17 @@ class FloatVector(TypeEngine[t.Sequence[t.Any]]): zero_indexes = False - def __init__(self, dim: t.Optional[int] = None, as_tuple: bool = False) -> None: - self.dim = dim - self.as_tuple = as_tuple - - @property - def hashable(self): - return self.as_tuple + impl = sa.ARRAY - @property - def python_type(self): - return list + def __init__(self, dimensions: int = None): + super().__init__(sa.FLOAT, dimensions=dimensions) def as_generic(self): return sqltypes.ARRAY def bind_processor(self, dialect: sa.Dialect) -> t.Callable: def process(value: t.Iterable) -> t.Optional[t.List]: - return to_db(value, self.dim) + return to_db(value, self.dimensions) return process @@ -131,27 +125,16 @@ def process(value: t.Any) -> t.Optional[npt.ArrayLike]: return process - """ - CrateDB currently only supports the similarity function `VectorSimilarityFunction.EUCLIDEAN`. - -- https://github.com/crate/crate/blob/1ca5c6dbb2/server/src/main/java/io/crate/types/FloatVectorType.java#L55 - - On the other hand, pgvector use a comparator to apply different similarity functions as operators, - see `pgvector.sqlalchemy.Vector.comparator_factory`. - - <->: l2/euclidean_distance - <#>: max_inner_product - <=>: cosine_distance - - TODO: Discuss. - """ # noqa: E501 - # Accompanies the type definition for reverse type lookups. TYPES_MAP["float_vector"] = FloatVector def visit_FLOAT_VECTOR(self, type_, **kw): - return f"FLOAT_VECTOR({type_.dim})" + dimensions = type_.dimensions + if dimensions is None: + raise ValueError("FloatVector must be initialized with dimension size") + return f"FLOAT_VECTOR({dimensions})" CrateTypeCompiler.visit_FLOAT_VECTOR = visit_FLOAT_VECTOR