diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 9633d72..69985d7 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -35,7 +35,7 @@ jobs: runs-on: ubuntu-latest services: tidb: - image: wangdi4zm/tind:v7.5.3-vector-index + image: wangdi4zm/tind:v8.4.0-vector-index ports: - 4000:4000 steps: diff --git a/.gitignore b/.gitignore index e160b92..0e64d13 100644 --- a/.gitignore +++ b/.gitignore @@ -112,6 +112,9 @@ venv/ ENV/ env.bak/ venv.bak/ +bin/ +pyvenv.cfg +share/ # Spyder project settings .spyderproject @@ -137,6 +140,7 @@ dmypy.json # Cython debug symbols cython_debug/ +.vscode/ .idea/ django_tests_dir diff --git a/README.md b/README.md index 4d26f5b..4c315e5 100644 --- a/README.md +++ b/README.md @@ -35,23 +35,26 @@ Learn how to connect to TiDB Serverless in the [TiDB Cloud documentation](https: Define table with vector field ```python -from sqlalchemy import Column, Integer, create_engine -from sqlalchemy.orm import declarative_base -from tidb_vector.sqlalchemy import VectorType +from sqlalchemy import Column, Integer, create_engine, func, text +from sqlalchemy.orm import Session +from tidb_vector.sqlalchemy import VectorType, VectorIndex, get_declarative_base -engine = create_engine('mysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test') -Base = declarative_base() +engine = create_engine('tidb+pymysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test') +Base = get_declarative_base() -class Test(Base): - __tablename__ = 'test' +class Document(Base): + __tablename__ = 'sqlalchemy_demo_documents' id = Column(Integer, primary_key=True) embedding = Column(VectorType(3)) -# or add hnsw index when creating table -class TestWithIndex(Base): - __tablename__ = 'test_with_index' +# or add hnsw index +class DocumentWithIndex(Base): + __tablename__ = 'sqlalchemy_demo_documents_with_index' id = Column(Integer, primary_key=True) - embedding = Column(VectorType(3), comment="hnsw(distance=l2)") +
embedding = Column(VectorType(3)) + __table_args__ = ( + VectorIndex('idx_l2', text('(vec_l2_distance(embedding))')), + ) Base.metadata.create_all(engine) ``` @@ -59,9 +62,10 @@ Base.metadata.create_all(engine) Insert vector data ```python -test = Test(embedding=[1, 2, 3]) -session.add(test) -session.commit() +with Session(engine) as session: + test = Document(embedding=[1, 2, 3]) + session.add(test) + session.commit() ``` Get the nearest neighbors @@ -252,4 +256,4 @@ There are some examples to show how to use the tidb-vector-python to interact wi for more examples, see the [examples](./examples) directory. ## Contributing -Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file. \ No newline at end of file +Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.
diff --git a/examples/orm-sqlalchemy-quickstart/.env.example b/examples/orm-sqlalchemy-quickstart/.env.example index 5c55f22..c57338e 100644 --- a/examples/orm-sqlalchemy-quickstart/.env.example +++ b/examples/orm-sqlalchemy-quickstart/.env.example @@ -1 +1 @@ -TIDB_DATABASE_URL=mysql+pymysql://:@:4000/?ssl_ca=&ssl_verify_cert=true&ssl_verify_identity=true \ No newline at end of file +TIDB_DATABASE_URL=tidb+pymysql://:@:4000/?ssl_ca=&ssl_verify_cert=true&ssl_verify_identity=true diff --git a/examples/orm-sqlalchemy-quickstart/README.md b/examples/orm-sqlalchemy-quickstart/README.md index 40b96aa..8bbe06d 100644 --- a/examples/orm-sqlalchemy-quickstart/README.md +++ b/examples/orm-sqlalchemy-quickstart/README.md @@ -40,7 +40,7 @@ cp .env.example .env Copy the `HOST`, `PORT`, `USERNAME`, `PASSWORD`, `DATABASE`, and `CA` parameters from the TiDB Cloud console (see [Prerequisites](../README.md#prerequisites)), and then replace the placeholders in the `.env` file. ```bash -TIDB_DATABASE_URL=mysql+pymysql://:@:4000/?ssl_ca=&ssl_verify_cert=true&ssl_verify_identity=true +TIDB_DATABASE_URL=tidb+pymysql://:@:4000/?ssl_ca=&ssl_verify_cert=true&ssl_verify_identity=true ``` ### Run this example @@ -59,4 +59,4 @@ Get documents within a certain distance: document: fish - distance: 0.12712843905603044 document: dog -``` \ No newline at end of file +``` diff --git a/examples/orm-sqlalchemy-quickstart/sqlalchemy-quickstart.py b/examples/orm-sqlalchemy-quickstart/sqlalchemy-quickstart.py index 1883a59..4bf5a25 100644 --- a/examples/orm-sqlalchemy-quickstart/sqlalchemy-quickstart.py +++ b/examples/orm-sqlalchemy-quickstart/sqlalchemy-quickstart.py @@ -1,9 +1,9 @@ import os import dotenv -from sqlalchemy import Column, Integer, create_engine, Text -from sqlalchemy.orm import declarative_base, Session -from tidb_vector.sqlalchemy import VectorType +from sqlalchemy import Column, Integer, create_engine, Text, text +from sqlalchemy.orm import Session +from tidb_vector.sqlalchemy import 
VectorType, VectorIndex, get_declarative_base dotenv.load_dotenv() @@ -12,7 +12,7 @@ engine = create_engine(tidb_connection_string) # Step 2: Define a table with a vector column. -Base = declarative_base() +Base = get_declarative_base() class Document(Base): @@ -22,18 +22,19 @@ class Document(Base): embedding = Column(VectorType(3)) -# Or add HNSW index when creating table. +# Or add HNSW index class DocumentWithIndex(Base): __tablename__ = 'sqlalchemy_demo_documents_with_index' id = Column(Integer, primary_key=True) content = Column(Text) - embedding = Column(VectorType(3), comment="hnsw(distance=cosine)") - + embedding = Column(VectorType(3)) + __table_args__ = ( + VectorIndex('idx_cos', text('(vec_cosine_distance(embedding))')), + ) Base.metadata.drop_all(engine) Base.metadata.create_all(engine) - # Step 3: Insert embeddings into the table. with Session(engine) as session: session.add(Document(content="dog", embedding=[1, 2, 1])) diff --git a/tests/integrations/test_utils.py b/tests/integrations/test_utils.py index 502148d..5b8a837 100644 --- a/tests/integrations/test_utils.py +++ b/tests/integrations/test_utils.py @@ -1,4 +1,5 @@ """Test TiDB Vector Search functionality.""" + from __future__ import annotations from tidb_vector.integrations.utils import extract_info_from_column_definition diff --git a/tests/integrations/test_vector_client.py b/tests/integrations/test_vector_client.py index 43e438a..6039696 100644 --- a/tests/integrations/test_vector_client.py +++ b/tests/integrations/test_vector_client.py @@ -1,4 +1,5 @@ """Test TiDB Vector Search functionality.""" + from __future__ import annotations from typing import List, Tuple diff --git a/tests/sqlalchemy/test_sqlalchemy.py b/tests/sqlalchemy/test_sqlalchemy.py index 134e452..5a5db5b 100644 --- a/tests/sqlalchemy/test_sqlalchemy.py +++ b/tests/sqlalchemy/test_sqlalchemy.py @@ -1,27 +1,35 @@ import pytest import numpy as np -from sqlalchemy import URL, create_engine, Column, Integer, select -from 
sqlalchemy.orm import declarative_base, sessionmaker +import sqlalchemy +from sqlalchemy import URL, create_engine, Column, Integer, select, Index +from sqlalchemy.orm import sessionmaker from sqlalchemy.exc import OperationalError -from tidb_vector.sqlalchemy import VectorType +from tidb_vector.sqlalchemy import ( + VectorType, + VectorIndex, + TiFlashReplica, + get_declarative_base, +) from ..config import TestConfig - +database_name = "test" db_url = URL( - "mysql+pymysql", + "tidb+pymysql", username=TestConfig.TIDB_USER, password=TestConfig.TIDB_PASSWORD, host=TestConfig.TIDB_HOST, port=TestConfig.TIDB_PORT, - database="test", - query={"ssl_verify_cert": True, "ssl_verify_identity": True} - if TestConfig.TIDB_SSL - else {}, + database=database_name, + query=( + {"ssl_verify_cert": True, "ssl_verify_identity": True} + if TestConfig.TIDB_SSL + else {} + ), ) engine = create_engine(db_url) Session = sessionmaker(bind=engine) -Base = declarative_base() +Base = get_declarative_base() class Item1Model(Base): @@ -303,3 +311,182 @@ def test_negative_inner_product(self): ) assert len(items) == 2 assert items[1].distance == -14.0 + + +class TestSQLAlchemyDDL: + def setup_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + Item2Model.__table__.create(bind=engine) + + def teardown_class(self): + Item2Model.__table__.drop(bind=engine, checkfirst=True) + + def setup_method(self): + with Session() as session: + session.query(Item2Model).delete() + session.commit() + + @staticmethod + def check_indexes(table, expect_indexes_name): + indexes = table.indexes + indexes_name = [index.name for index in indexes] + assert len(indexes) == len(expect_indexes_name) + for i in expect_indexes_name: + assert i in indexes_name + + def test_alter_tiflash_replica(self): + # Add tiflash replica + replica = TiFlashReplica(Item2Model.__table__, num=1) + replica.create(engine) + sql = sqlalchemy.text( + f""" + SELECT TABLE_SCHEMA,TABLE_NAME,REPLICA_COUNT + FROM 
INFORMATION_SCHEMA.TIFLASH_REPLICA + WHERE TABLE_SCHEMA="{database_name}" AND TABLE_NAME="{Item2Model.__tablename__}" + """ + ) + with Session() as session: + rs = session.execute(sql) + for r in rs: + assert r.TABLE_SCHEMA == database_name + assert r.TABLE_NAME == Item2Model.__tablename__ + assert r.REPLICA_COUNT == 1 + assert Item2Model.__table__.info["has_tiflash_replica"] + + # Drop tiflash replica + replica.drop(engine) + with Session() as session: + rs = session.execute(sql) + for r in rs: + assert r.TABLE_SCHEMA == database_name + assert r.TABLE_NAME == Item2Model.__tablename__ + assert r.REPLICA_COUNT == 0 + assert not Item2Model.__table__.info.get("has_tiflash_replica", False) + + def test_query_with_index(self): + # indexes + l2_index = VectorIndex( + "idx_embedding_l2", + sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding), + ) + l2_index.create(engine) + cos_index = VectorIndex( + "idx_embedding_cos", + sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding), + ) + cos_index.create(engine) + + self.check_indexes( + Item2Model.__table__, ["idx_embedding_l2", "idx_embedding_cos"] + ) + + with Session() as session: + session.add_all( + [Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])] + ) + session.commit() + + # l2 distance + result_l2 = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_l2) == 2 + + distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3]) + items_l2 = ( + session.query(Item2Model.id, distance_l2.label("distance")) + .order_by(distance_l2) + .limit(5) + .all() + ) + assert len(items_l2) == 2 + assert items_l2[0].distance == 0.0 + + # cosine distance + result_cos = session.scalars( + select(Item2Model).filter( + Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_cos) == 2 + + distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3]) + items_cos = ( + 
session.query(Item2Model.id, distance_cos.label("distance")) + .order_by(distance_cos) + .limit(5) + .all() + ) + assert len(items_cos) == 2 + assert items_cos[0].distance == 0.0 + + # drop indexes + l2_index.drop(engine) + cos_index.drop(engine) + + def test_query_with_inited_index(self): + class Item3Model(Base): + __tablename__ = "sqlalchemy_item3" + id = Column(Integer, primary_key=True) + embedding = Column(VectorType(dim=3)) + __table_args__ = ( + Index("idx_id", "id"), + VectorIndex( + "idx_embedding_l2", sqlalchemy.text("(vec_l2_distance(embedding))") + ), + VectorIndex( + "idx_embedding_cos", + sqlalchemy.text("(vec_cosine_distance(embedding))"), + ), + ) + + Item3Model.__table__.drop(bind=engine, checkfirst=True) + Item3Model.__table__.create(bind=engine) + + self.check_indexes( + Item3Model.__table__, ["idx_id", "idx_embedding_l2", "idx_embedding_cos"] + ) + + with Session() as session: + session.add_all( + [Item3Model(embedding=[1, 2, 3]), Item3Model(embedding=[1, 2, 3.2])] + ) + session.commit() + + # l2 distance + result_l2 = session.scalars( + select(Item3Model).filter( + Item3Model.embedding.l2_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_l2) == 2 + + distance_l2 = Item3Model.embedding.l2_distance([1, 2, 3]) + items_l2 = ( + session.query(Item3Model.id, distance_l2.label("distance")) + .order_by(distance_l2) + .limit(5) + .all() + ) + assert len(items_l2) == 2 + assert items_l2[0].distance == 0.0 + + # cosine distance + result_cos = session.scalars( + select(Item3Model).filter( + Item3Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2 + ) + ).all() + assert len(result_cos) == 2 + + distance_cos = Item3Model.embedding.cosine_distance([1, 2, 3]) + items_cos = ( + session.query(Item3Model.id, distance_cos.label("distance")) + .order_by(distance_cos) + .limit(5) + .all() + ) + assert len(items_cos) == 2 + assert items_cos[0].distance == 0.0 diff --git a/tidb_vector/sqlalchemy/__init__.py b/tidb_vector/sqlalchemy/__init__.py index 
6aed129..4ae33f4 100644 --- a/tidb_vector/sqlalchemy/__init__.py +++ b/tidb_vector/sqlalchemy/__init__.py @@ -1,8 +1,27 @@ from sqlalchemy.types import UserDefinedType from sqlalchemy.sql import func +from sqlalchemy.dialects import registry as _registry from tidb_vector.constants import MAX_DIMENSION_LENGTH, MIN_DIMENSION_LENGTH from tidb_vector.utils import decode_vector, encode_vector +from .ddl import VectorIndex, TiFlashReplica, Table, MetaData +from .ext.declarative import get_declarative_base + +_registry.register( + "tidb", + "tidb_vector.sqlalchemy.dialect", + "dialect_pymysql", +) +_registry.register( + "tidb.mysqldb", + "tidb_vector.sqlalchemy.dialect", + "dialect_mysqldb", +) +_registry.register( + "tidb.pymysql", + "tidb_vector.sqlalchemy.dialect", + "dialect_pymysql", +) class VectorType(UserDefinedType): @@ -16,7 +35,7 @@ def __init__(self, dim=None): if dim is not None and not isinstance(dim, int): raise ValueError("expected dimension to be an integer or None") - # tidb vector dimention length has limitation + # tidb vector dimension length has limitation if dim is not None and ( dim < MIN_DIMENSION_LENGTH or dim > MAX_DIMENSION_LENGTH ): @@ -80,3 +99,13 @@ def negative_inner_product(self, other): return func.VEC_NEGATIVE_INNER_PRODUCT(self, formatted_other).label( "negative_inner_product" ) + + +__all__ = ( + "get_declarative_base", + "MetaData", + "Table", + "VectorType", + "VectorIndex", + "TiFlashReplica", +) diff --git a/tidb_vector/sqlalchemy/compiler.py b/tidb_vector/sqlalchemy/compiler.py new file mode 100644 index 0000000..01bb50f --- /dev/null +++ b/tidb_vector/sqlalchemy/compiler.py @@ -0,0 +1,57 @@ +from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler +from sqlalchemy.sql import elements, operators, functions + +from .ddl import AlterTiFlashReplica, CreateVectorIndex + + +class TiDBDDLCompiler(MySQLDDLCompiler): + def visit_alter_tiflash_replica(self, replica: AlterTiFlashReplica, **kw): + # from IPython import embed;embed() + 
return "ALTER TABLE {} SET TIFLASH REPLICA {}".format( + replica.element.inner_table.name, replica.new_num + ) + + def visit_create_vector_index(self, create: CreateVectorIndex, **kw): + """Build the ``CREATE VECTOR INDEX ...`` statement + MySQLDDLCompiler.visit_create_index + """ + index = create.element + self._verify_index_table(index) + preparer = self.preparer + table = preparer.format_table(index.table) + + columns = [ + self.sql_compiler.process( + ( + elements.Grouping(expr) + if ( + isinstance(expr, elements.BinaryExpression) + or ( + isinstance(expr, elements.UnaryExpression) + and expr.modifier + not in (operators.desc_op, operators.asc_op) + ) + or isinstance(expr, functions.FunctionElement) + ) + else expr + ), + include_table=False, + literal_binds=True, + ) + for expr in index.expressions + ] + + name = self._prepared_index_name(index) + + text = "CREATE VECTOR INDEX " + if create.if_not_exists: + text += "IF NOT EXISTS " + text += f"{name} ON {table} " + + text += f"({', '.join(columns)})" + + using = index.kwargs.get(f"{self.dialect.name}_using") + if using is not None: + text += f" USING {preparer.quote(using)}" + + return text diff --git a/tidb_vector/sqlalchemy/ddl.py b/tidb_vector/sqlalchemy/ddl.py new file mode 100644 index 0000000..78c7922 --- /dev/null +++ b/tidb_vector/sqlalchemy/ddl.py @@ -0,0 +1,259 @@ +from typing import Any, Optional, Sequence, Union + +import sqlalchemy +from sqlalchemy.sql.ddl import SchemaGenerator as SchemaGeneratorBase +from sqlalchemy.sql.ddl import SchemaDropper as SchemaDropperBase +from sqlalchemy.sql.base import DialectKWArgs +from sqlalchemy.sql.schema import ( + ColumnCollectionMixin, + HasConditionalDDL, + SchemaItem, + MetaData as MetaDataBase, + Table as TableBase, + ColumnElement, +) +import sqlalchemy.exc as exc + + +class MetaData(MetaDataBase): + """ + A collection of :class:`.Table` objects and their associated schema constructs. 
+ Overwrites the default implementation to use :class:`TiDBSchemaGenerator` and :class:`TiDBSchemaDropper`. + """ + + def create_all(self, bind, tables=None, checkfirst: bool = True) -> None: + bind._run_ddl_visitor( + TiDBSchemaGenerator, self, checkfirst=checkfirst, tables=tables + ) + + def drop_all(self, bind, tables=None, checkfirst: bool = True) -> None: + bind._run_ddl_visitor( + TiDBSchemaDropper, self, checkfirst=checkfirst, tables=tables + ) + + +class Table(TableBase): + """ + Represent a table in a database. + Overwrites the default implementation to use :class:`TiDBSchemaGenerator` and :class:`TiDBSchemaDropper`. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def create(self, bind, checkfirst: bool = False) -> None: + """Issue a ``CREATE`` statement for this + :class:`_schema.Table`, using the given + :class:`.Connection` or :class:`.Engine` + for connectivity. + + Overwrites the default implementation to use :class:`TiDBSchemaGenerator` + """ + bind._run_ddl_visitor(TiDBSchemaGenerator, self, checkfirst=checkfirst) + + def drop(self, bind, checkfirst: bool = False) -> None: + """Issue a ``DROP`` statement for this + :class:`_schema.Table`, using the given + :class:`.Connection` or :class:`.Engine` for connectivity. 
+ + Overwrites the default implementation to use :class:`TiDBSchemaDropper` + + """ + bind._run_ddl_visitor(TiDBSchemaDropper, self, checkfirst=checkfirst) + + +class TiFlashReplica(DialectKWArgs, SchemaItem): + """Represent the tiflash replica table attribute""" + + __visit_name__ = "tiflash_replica" + + inner_table: Optional[Table] + + @property + def bind(self): + return self.metadata.bind + + @property + def metadata(self): + return self.inner_table.metadata + + def __init__(self, inner_table: Table, num=1, *args, **kwargs) -> None: + super().__init__() + self.inner_table = inner_table + self.replica_num = num + # set the metadata to the inner_table + self.inner_table.info["has_tiflash_replica"] = True + + def create(self, bind=None): + """Issue a ``ALTER TABLE ... SET TIFLASH REPLICA {num}`` statement""" + if bind is None: + bind = self.bind + bind._run_ddl_visitor(TiDBSchemaGenerator, self) + + def drop(self, bind=None): + """Issue a ``ALTER TABLE ... SET TIFLASH REPLICA 0`` statement""" + if bind is None: + bind = self.bind + bind._run_ddl_visitor(TiDBSchemaDropper, self) + self.replica_num = 0 + + +class VectorIndex(DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem): + __visit_name__ = "vector_index" + + table: Optional[Table] + expressions: Sequence[Union[str, ColumnElement[Any]]] + _table_bound_expressions: Sequence[ColumnElement[Any]] + + @property + def bind(self): + return self.metadata.bind + + @property + def metadata(self): + return self.table.metadata + + def __init__( + self, + name: Optional[str], + expressions, + _table: Optional[Table] = None, + ) -> None: + super().__init__() + self.table = table = None + if _table is not None: + table = _table + + self.name = name + + self.expressions = [] + # will call _set_parent() if table-bound column + # objects are present + ColumnCollectionMixin.__init__( + self, + expressions, + _column_flag=False, + _gather_expressions=self.expressions, + ) + if table is not None: + 
self._set_parent(table) + + def _set_parent(self, parent, **kw: Any) -> None: + table = parent + assert isinstance(table, Table) + ColumnCollectionMixin._set_parent(self, table) + + if self.table is not None and table is not self.table: + raise exc.ArgumentError( + f"Index '{self.name}' is against table " + f"'{self.table.description}', and " + f"cannot be associated with table '{table.description}'." + ) + self.table = table + table.indexes.add(self) + + expressions = self.expressions + col_expressions = self._col_expressions(table) + assert len(expressions) == len(col_expressions) + + exprs = [] + for expr, colexpr in zip(expressions, col_expressions): + if isinstance(expr, sqlalchemy.sql.ClauseElement): + exprs.append(expr) + elif colexpr is not None: + exprs.append(colexpr) + else: + assert False + self.expressions = self._table_bound_expressions = exprs + + def create(self, bind, checkfirst: bool = False) -> None: + """Issue a ``CREATE`` statement for this + :class:`.VectorIndex`, using the given + :class:`.Connection` or :class:`.Engine` for connectivity. + """ + if bind is None: + bind = self.bind + bind._run_ddl_visitor(TiDBSchemaGenerator, self, checkfirst=checkfirst) + + def drop(self, bind, checkfirst: bool = False) -> None: + """Issue a ``DROP`` statement for this + :class:`.VectorIndex`, using the given + :class:`.Connection` or :class:`.Engine` for connectivity. + """ + if bind is None: + bind = self.bind + bind._run_ddl_visitor(TiDBSchemaDropper, self, checkfirst=checkfirst) + + +class AlterTiFlashReplica(sqlalchemy.sql.ddl._CreateDropBase): + """Represent an ``ALTER TABLE ... SET TIFLASH REPLICA ...`` statement.""" + + __visit_name__: str = "alter_tiflash_replica" + + def __init__(self, element, new_num): + super().__init__(element) + self.new_num = new_num + + +class CreateVectorIndex(sqlalchemy.sql.ddl.CreateIndex): + """Represent a ``CREATE VECTOR INDEX ...
ON ...`` statement.""" + + __visit_name__: str = "create_vector_index" + + def __init__(self, element, if_not_exists=False): + super().__init__(element, if_not_exists) + + +class TiDBSchemaGenerator(SchemaGeneratorBase): + """Building logical CREATE ... statements.""" + + def __init__(self, dialect, connection, checkfirst=False, tables=None): + super().__init__(dialect, connection, checkfirst, tables) + + def visit_tiflash_replica(self, replica, **kwargs): + """ + replica: TiFlashReplica + """ + with self.with_ddl_events(replica): + AlterTiFlashReplica( + replica, new_num=replica.replica_num, **kwargs + )._invoke_with(self.connection) + + def visit_vector_index(self, index, create_ok=False): + """ + index: VectorIndex + """ + if not create_ok and not self._can_create_index(index): + return + with self.with_ddl_events(index): + # Automatically add a tiflash replica if one does not exist + if not index.table.info.get("has_tiflash_replica", False): + num_replica = 1 # by default 1 replica + replica = TiFlashReplica(index.table, num_replica) + AlterTiFlashReplica(replica, new_num=num_replica)._invoke_with( + self.connection + ) + # Create the vector index + CreateVectorIndex(index)._invoke_with(self.connection) + + +class TiDBSchemaDropper(SchemaDropperBase): + """Building logical DROP ...
statements.""" + + def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs): + super().__init__(dialect, connection, checkfirst, tables, **kwargs) + + def visit_tiflash_replica(self, replica, **kwargs): + with self.with_ddl_events(replica): + # set the replica_num to new_num + AlterTiFlashReplica(replica, new_num=0)._invoke_with(self.connection) + # reset the table.info of has_tiflash_replica + del replica.inner_table.info["has_tiflash_replica"] + + def visit_vector_index(self, index, if_exists=False): + # from IPython import embed;embed() + if not self._can_drop_index(index): + return + with self.with_ddl_events(index): + # Drop vector index is the same as dropping normal index + sqlalchemy.sql.ddl.DropIndex(index, if_exists)._invoke_with(self.connection) diff --git a/tidb_vector/sqlalchemy/dialect.py b/tidb_vector/sqlalchemy/dialect.py new file mode 100644 index 0000000..ff89233 --- /dev/null +++ b/tidb_vector/sqlalchemy/dialect.py @@ -0,0 +1,45 @@ +import sqlalchemy.dialects.mysql as sqlalchemy_mysql + +from .compiler import TiDBDDLCompiler + + +class TiDBDialect_mysqldb(sqlalchemy_mysql.mysqldb.MySQLDialect_mysqldb): + name = "tidb" + driver = "mysqldb" + + preparer = sqlalchemy_mysql.base.MySQLIdentifierPreparer + ddl_compiler = TiDBDDLCompiler + statement_compiler = sqlalchemy_mysql.base.MySQLCompiler + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def initialize(self, connection): + super().initialize(connection) + + @classmethod + def import_dbapi(cls): + return __import__("MySQLdb") + + +class TiDBDialect_pymysql(sqlalchemy_mysql.pymysql.MySQLDialect_pymysql): + name = "tidb" + driver = "pymysql" + + preparer = sqlalchemy_mysql.base.MySQLIdentifierPreparer + ddl_compiler = TiDBDDLCompiler + statement_compiler = sqlalchemy_mysql.base.MySQLCompiler + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def initialize(self, connection): + super().initialize(connection) + + 
@classmethod + def import_dbapi(cls): + return __import__("pymysql") + + +dialect_mysqldb = TiDBDialect_mysqldb +dialect_pymysql = TiDBDialect_pymysql diff --git a/tidb_vector/sqlalchemy/ext/__init__.py b/tidb_vector/sqlalchemy/ext/__init__.py new file mode 100644 index 0000000..e69de29 diff --git a/tidb_vector/sqlalchemy/ext/declarative.py b/tidb_vector/sqlalchemy/ext/declarative.py new file mode 100644 index 0000000..d97ca8b --- /dev/null +++ b/tidb_vector/sqlalchemy/ext/declarative.py @@ -0,0 +1,25 @@ +from typing import Dict, Any, Optional + +from sqlalchemy.ext.declarative import DeclarativeMeta +from sqlalchemy.orm import declarative_base + +from ..ddl import Table, MetaData + + +class TiDBDeclarativeMeta(DeclarativeMeta): + def __new__(cls, name: str, bases, class_dict: Dict[str, Any]): + # Overwrite __table_cls__ to make sure all create table use + # the custom `..ddl.Table` class. + if "__table_cls__" not in class_dict: + class_dict["__table_cls__"] = Table + + return DeclarativeMeta.__new__(cls, name, bases, class_dict) + + +def get_declarative_base(metadata: Optional[MetaData] = None): + if metadata is None: + metadata = MetaData() + else: + # ensure the metadata is created with type `..ddl.MetaData` + assert isinstance(metadata, MetaData) + return declarative_base(metadata=metadata, metaclass=TiDBDeclarativeMeta)