From 1070b373fcb626b02c2dbebf81db0458e7aedc58 Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Thu, 21 Dec 2023 15:28:50 +0100 Subject: [PATCH] Vector: Fix type checking and compatibility with SQLAlchemy 1.x --- src/sqlalchemy_cratedb/type/vector.py | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 9cde7ea2..b78131cc 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -25,6 +25,8 @@ - The type implementation might want to be accompanied by corresponding support for the `KNN_MATCH` function, similar to what the dialect already offers for fulltext search through its `Match` predicate. +- After dropping support for SQLAlchemy 1.3, use + `class FloatVector(sa.TypeDecorator[t.Sequence[float]]):` ## Origin This module is based on the corresponding pgvector implementation @@ -44,7 +46,7 @@ __all__ = ["FloatVector"] -def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: +def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]: import numpy as np # from `pgvector.utils` @@ -77,8 +79,7 @@ def to_db(value: t.Any, dim: t.Optional[int] = None) -> t.Optional[t.List]: return value -class FloatVector(sa.TypeDecorator[t.Sequence[float]]): - +class FloatVector(sa.TypeDecorator): """ An improved implementation of the `FloatVector` data type for CrateDB, compared to the previous implementation on behalf of the LangChain adapter. @@ -146,14 +147,14 @@ def __init__(self, dimensions: int = None): def as_generic(self): return sa.ARRAY - def bind_processor(self, dialect: sa.Dialect) -> t.Callable: + def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable: def process(value: t.Iterable) -> t.Optional[t.List]: return to_db(value, self.dimensions) return process - def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: - def process(value: t.Any) -> t.Optional[npt.ArrayLike]: + def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable: + def process(value: t.Any) -> t.Optional["npt.ArrayLike"]: return from_db(value) return process