Skip to content

Commit

Permalink
Drop vector index, drop tiflash replica
Browse files Browse the repository at this point in the history
Signed-off-by: JaySon-Huang <[email protected]>
  • Loading branch information
JaySon-Huang committed Oct 23, 2024
1 parent f7fb8c4 commit 7a90d31
Show file tree
Hide file tree
Showing 5 changed files with 128 additions and 51 deletions.
1 change: 1 addition & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -114,6 +114,7 @@ env.bak/
venv.bak/
bin/
pyvenv.cfg
share/

# Spyder project settings
.spyderproject
Expand Down
78 changes: 57 additions & 21 deletions tests/sqlalchemy/test_sqlalchemy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,22 @@
from sqlalchemy import URL, create_engine, Column, Integer, select
from sqlalchemy.orm import declarative_base, sessionmaker
from sqlalchemy.exc import OperationalError
from tidb_vector.sqlalchemy import VectorType, VectorIndex
from tidb_vector.sqlalchemy import VectorType, VectorIndex, TiFlashReplica
from ..config import TestConfig


database_name = "test"
db_url = URL(
"tidb+pymysql",
username=TestConfig.TIDB_USER,
password=TestConfig.TIDB_PASSWORD,
host=TestConfig.TIDB_HOST,
port=TestConfig.TIDB_PORT,
database="test",
query={"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {},
database=database_name,
query=(
{"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {}
),
)

engine = create_engine(db_url)
Expand Down Expand Up @@ -319,23 +321,53 @@ def setup_method(self):
session.query(Item2Model).delete()
session.commit()

def test_l2_distance(self):
def test_alter_tiflash_replica(self):
# Add tiflash replica
replica = TiFlashReplica(Item2Model.__table__, num=1)
replica.create(engine)
sql = sqlalchemy.text(
f"""
SELECT TABLE_SCHEMA,TABLE_NAME,REPLICA_COUNT
FROM INFORMATION_SCHEMA.TIFLASH_REPLICA
WHERE TABLE_SCHEMA="{database_name}" AND TABLE_NAME="{Item2Model.__tablename__}"
"""
)
with Session() as session:
# indexes
l2_index = VectorIndex(
"idx_embedding_l2",
sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding),
)
l2_index.create(engine)
cos_index = VectorIndex(
"idx_embedding_cos",
sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding),
)
cos_index.create(engine)
rs = session.execute(sql)
for r in rs:
assert r.TABLE_SCHEMA == database_name
assert r.TABLE_NAME == Item2Model.__tablename__
assert r.REPLICA_COUNT == 1
assert Item2Model.__table__.info["has_tiflash_replica"] == True

# Drop tiflash replica
replica.drop(engine)
with Session() as session:
rs = session.execute(sql)
for r in rs:
assert r.TABLE_SCHEMA == database_name
assert r.TABLE_NAME == Item2Model.__tablename__
assert r.REPLICA_COUNT == 0
assert Item2Model.__table__.info.get("has_tiflash_replica", False) == False

def test_query_with_index(self):
# indexes
l2_index = VectorIndex(
"idx_embedding_l2",
sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding),
)
l2_index.create(engine)
cos_index = VectorIndex(
"idx_embedding_cos",
sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding),
)
cos_index.create(engine)

item1 = Item2Model(embedding=[1, 2, 3])
item2 = Item2Model(embedding=[1, 2, 3.2])
session.add_all([item1, item2])
with Session() as session:

session.add_all(
[Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])]
)
session.commit()

# l2 distance
Expand Down Expand Up @@ -373,3 +405,7 @@ def test_l2_distance(self):
)
assert len(items_cos) == 2
assert items_cos[0].distance == 0.0

# drop indexes
l2_index.drop(engine)
cos_index.drop(engine)
10 changes: 6 additions & 4 deletions tidb_vector/sqlalchemy/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@

from tidb_vector.constants import MAX_DIMENSION_LENGTH, MIN_DIMENSION_LENGTH
from tidb_vector.utils import decode_vector, encode_vector
from .ddl import VectorIndex
from .ddl import VectorIndex, TiFlashReplica

_registry.register(
"tidb.mysqldb",
Expand All @@ -29,7 +29,7 @@ def __init__(self, dim=None):
if dim is not None and not isinstance(dim, int):
raise ValueError("expected dimension to be an integer or None")

# tidb vector dimention length has limitation
# tidb vector dimension length has limitation
if dim is not None and (
dim < MIN_DIMENSION_LENGTH or dim > MAX_DIMENSION_LENGTH
):
Expand Down Expand Up @@ -94,7 +94,9 @@ def negative_inner_product(self, other):
"negative_inner_product"
)


__all__ = (
'VectorType',
'VectorIndex',
"VectorType",
"VectorIndex",
"TiFlashReplica",
)
6 changes: 3 additions & 3 deletions tidb_vector/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,14 @@
from sqlalchemy.dialects.mysql.base import MySQLDDLCompiler
from sqlalchemy.sql import elements, operators, functions

from .ddl import CreateTiFlashReplica, CreateVectorIndex
from .ddl import AlterTiFlashReplica, CreateVectorIndex


class TiDBDDLCompiler(MySQLDDLCompiler):
def visit_tiflash_replica(self, replica: CreateTiFlashReplica, **kw):
def visit_alter_tiflash_replica(self, replica: AlterTiFlashReplica, **kw):
# from IPython import embed;embed()
return "ALTER TABLE {} SET TIFLASH REPLICA {}".format(
replica.element.inner_table.name, replica.element.replica_num
replica.element.inner_table.name, replica.new_num
)

def visit_create_vector_index(self, create: CreateVectorIndex, **kw):
Expand Down
84 changes: 61 additions & 23 deletions tidb_vector/sqlalchemy/ddl.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import sqlalchemy
from sqlalchemy.sql.ddl import SchemaGenerator as SchemaGeneratorBase
from sqlalchemy.sql.ddl import SchemaDropper as SchemaDropperBase
from sqlalchemy.sql.base import DialectKWArgs
from sqlalchemy.sql.schema import (
ColumnCollectionMixin,
Expand All @@ -13,11 +14,13 @@
import sqlalchemy.exc as exc


class TiFlashReplica(DialectKWArgs):
class TiFlashReplica(DialectKWArgs, SchemaItem):
"""Represent the tiflash replica table attribute"""

__visit_name__ = "tiflash_replica"

inner_table: Optional[Table]

@property
def bind(self):
return self.metadata.bind
Expand All @@ -26,28 +29,25 @@ def bind(self):
def metadata(self):
return self.inner_table.metadata

def __init__(
self, inner_table: sqlalchemy.sql.schema.Table, num=1, *args, **kwargs
) -> None:
def __init__(self, inner_table: Table, num=1, *args, **kwargs) -> None:
super().__init__()
self.inner_table = inner_table
self.replica_num = num
# set the metadata to the inner_table
self.inner_table.info["has_tiflash_replica"] = True

def create(self, bind=None):
"""Issue a ``SET TIFLASH REPLICA`` statement"""
"""Issue an ``ALTER TABLE ... SET TIFLASH REPLICA {num}`` statement"""
if bind is None:
bind = self.bind
bind._run_ddl_visitor(TiDBSchemaGenerator, self)

def drop(self, bind=None):
"""Issue a ``SET TIFLASH REPLICA`` statement"""
"""Issue an ``ALTER TABLE ... SET TIFLASH REPLICA 0`` statement"""
if bind is None:
bind = self.bind
# TODO: implement drop tiflash replica
# bind._run_ddl_visitor()
raise NotImplementedError()
bind._run_ddl_visitor(TiDBSchemaDropper, self)
self.replica_num = 0


class VectorIndex(DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, SchemaItem):
Expand All @@ -57,6 +57,14 @@ class VectorIndex(DialectKWArgs, ColumnCollectionMixin, HasConditionalDDL, Schem
expressions: Sequence[Union[str, ColumnElement[Any]]]
_table_bound_expressions: Sequence[ColumnElement[Any]]

@property
def bind(self):
return self.metadata.bind

@property
def metadata(self):
return self.table.metadata

def __init__(
self,
name: Optional[str],
Expand Down Expand Up @@ -115,25 +123,28 @@ def create(self, bind, checkfirst: bool = False) -> None:
:class:`.VectorIndex`, using the given
:class:`.Connection` or :class:`.Engine`` for connectivity.
"""
if bind is None:
bind = self.bind
bind._run_ddl_visitor(TiDBSchemaGenerator, self, checkfirst=checkfirst)

def drop(self, bind, checkfirst: bool = False) -> None:
"""Issue a ``DROP`` statement for this
:class:`.VectorIndex`, using the given
:class:`.Connection` or :class:`.Engine` for connectivity.
"""
# bind._run_ddl_visitor(,
# self, checkfirst=checkfirst)
raise NotImplementedError()
if bind is None:
bind = self.bind
bind._run_ddl_visitor(TiDBSchemaDropper, self, checkfirst=checkfirst)


class CreateTiFlashReplica(sqlalchemy.sql.ddl._CreateDropBase):
class AlterTiFlashReplica(sqlalchemy.sql.ddl._CreateDropBase):
"""Represent an ``ALTER TABLE ... SET TIFLASH REPLICA ...`` statement."""

__visit_name__: str = "tiflash_replica"
__visit_name__: str = "alter_tiflash_replica"

def __init__(self, element):
super(CreateTiFlashReplica, self).__init__(element)
def __init__(self, element, new_num):
super().__init__(element)
self.new_num = new_num


class CreateVectorIndex(sqlalchemy.sql.ddl.CreateIndex):
Expand All @@ -146,17 +157,19 @@ def __init__(self, element, if_not_exists=False):


class TiDBSchemaGenerator(SchemaGeneratorBase):
def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
super(TiDBSchemaGenerator, self).__init__(
dialect, connection, checkfirst, tables, **kwargs
)
"""Building logical CREATE ... statements."""

def __init__(self, dialect, connection, checkfirst=False, tables=None):
super().__init__(dialect, connection, checkfirst, tables)

def visit_tiflash_replica(self, replica, **kwargs):
"""
replica: TiFlashReplica
"""
with self.with_ddl_events(replica):
self.connection.execute(CreateTiFlashReplica(replica, **kwargs))
AlterTiFlashReplica(
replica, new_num=replica.replica_num, **kwargs
)._invoke_with(self.connection)

def visit_vector_index(self, index):
"""
Expand All @@ -167,7 +180,32 @@ def visit_vector_index(self, index):
with self.with_ddl_events(index):
# Automatically add tiflash replica if not exist
if not index.table.info.get("has_tiflash_replica", False):
replica = TiFlashReplica(index.table, 1)
CreateTiFlashReplica(replica)._invoke_with(self.connection)
num_replica = 1 # by default 1 replica
replica = TiFlashReplica(index.table, num_replica)
AlterTiFlashReplica(replica, new_num=num_replica)._invoke_with(
self.connection
)
# Create the vector index
CreateVectorIndex(index)._invoke_with(self.connection)


class TiDBSchemaDropper(SchemaDropperBase):
    """Building logical DROP ... statements.

    Visitor counterpart of ``TiDBSchemaGenerator``: it handles the
    ``tiflash_replica`` and ``vector_index`` visit names and emits the
    statements that undo what the generator created.
    """

    def __init__(self, dialect, connection, checkfirst=False, tables=None, **kwargs):
        super().__init__(dialect, connection, checkfirst, tables, **kwargs)

    def visit_tiflash_replica(self, replica, **kwargs):
        """Drop the TiFlash replica of ``replica.inner_table``.

        replica: TiFlashReplica
        """
        with self.with_ddl_events(replica):
            # Dropping a replica is expressed as altering the count to 0.
            AlterTiFlashReplica(replica, new_num=0)._invoke_with(self.connection)
            # Reset the table.info marker. ``pop`` with a default keeps a
            # repeated drop (or a second replica object bound to the same
            # table) from raising KeyError after the DDL already succeeded.
            replica.inner_table.info.pop("has_tiflash_replica", None)

    def visit_vector_index(self, index, if_exists=False):
        """Drop a vector index.

        index: VectorIndex
        if_exists: forwarded to ``DropIndex``.
        """
        # Defer to the base dropper's check on whether the index may be dropped.
        if not self._can_drop_index(index):
            return
        with self.with_ddl_events(index):
            # Drop vector index is the same as dropping normal index
            sqlalchemy.sql.ddl.DropIndex(index, if_exists)._invoke_with(self.connection)

0 comments on commit 7a90d31

Please sign in to comment.