diff --git a/src/sqlalchemy_cratedb/type/vector.py b/src/sqlalchemy_cratedb/type/vector.py index 9cde7ea2..a46c406a 100644 --- a/src/sqlalchemy_cratedb/type/vector.py +++ b/src/sqlalchemy_cratedb/type/vector.py @@ -44,7 +44,7 @@ __all__ = ["FloatVector"] -def from_db(value: t.Iterable) -> t.Optional[npt.ArrayLike]: +def from_db(value: t.Iterable) -> t.Optional["npt.ArrayLike"]: import numpy as np # from `pgvector.utils` @@ -146,13 +146,13 @@ def __init__(self, dimensions: int = None): def as_generic(self): return sa.ARRAY - def bind_processor(self, dialect: sa.Dialect) -> t.Callable: + def bind_processor(self, dialect: sa.engine.Dialect) -> t.Callable: def process(value: t.Iterable) -> t.Optional[t.List]: return to_db(value, self.dimensions) return process - def result_processor(self, dialect: sa.Dialect, coltype: t.Any) -> t.Callable: + def result_processor(self, dialect: sa.engine.Dialect, coltype: t.Any) -> t.Callable: def process(value: t.Any) -> t.Optional[npt.ArrayLike]: return from_db(value)