Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[DNM] Implement dialect "tidb+pymysql"/"tidb+mysqldb" for sqlalchemy #63

Closed
wants to merge 12 commits into from
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,7 @@ jobs:
runs-on: ubuntu-latest
services:
tidb:
image: wangdi4zm/tind:v7.5.3-vector-index
image: wangdi4zm/tind:v8.4.0-vector-index
ports:
- 4000:4000
steps:
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -112,6 +112,9 @@ venv/
ENV/
env.bak/
venv.bak/
bin/
pyvenv.cfg
share/

# Spyder project settings
.spyderproject
Expand All @@ -137,6 +140,7 @@ dmypy.json
# Cython debug symbols
cython_debug/

.vscode/
.idea/
django_tests_dir

Expand Down
34 changes: 19 additions & 15 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -35,33 +35,37 @@ Learn how to connect to TiDB Serverless in the [TiDB Cloud documentation](https:
Define table with vector field

```python
from sqlalchemy import Column, Integer, create_engine
from sqlalchemy.orm import declarative_base
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Column, Integer, create_engine, func
from sqlalchemy.orm import Session
from tidb_vector.sqlalchemy import VectorType, VectorIndex, get_declarative_base

engine = create_engine('mysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test')
Base = declarative_base()
engine = create_engine('tidb+pymysql://****.root:******@gateway01.xxxxxx.shared.aws.tidbcloud.com:4000/test')
Base = get_declarative_base()

class Test(Base):
__tablename__ = 'test'
class Document(Base):
__tablename__ = 'sqlalchemy_demo_documents'
id = Column(Integer, primary_key=True)
embedding = Column(VectorType(3))

# or add hnsw index when creating table
class TestWithIndex(Base):
__tablename__ = 'test_with_index'
# or add hnsw index
class DocumentWithIndex(Base):
__tablename__ = 'sqlalchemy_demo_documents_with_index'
id = Column(Integer, primary_key=True)
embedding = Column(VectorType(3), comment="hnsw(distance=l2)")
embedding = Column(VectorType(3))
__table_args__ = (
VectorIndex('idx_l2', text('(vec_l2_distance(embedding))')),
)

Base.metadata.create_all(engine)
```

Insert vector data

```python
test = Test(embedding=[1, 2, 3])
session.add(test)
session.commit()
with Session(engine) as session:
test = Test(embedding=[1, 2, 3])
session.add(test)
session.commit()
```

Get the nearest neighbors
Expand Down Expand Up @@ -252,4 +256,4 @@ There are some examples to show how to use the tidb-vector-python to interact wi
for more examples, see the [examples](./examples) directory.

## Contributing
Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.
Please feel free to reach out to the maintainers if you have any questions or need help with the project. Before contributing, please read the [CONTRIBUTING.md](./CONTRIBUTING.md) file.
2 changes: 1 addition & 1 deletion examples/orm-sqlalchemy-quickstart/.env.example
Original file line number Diff line number Diff line change
@@ -1 +1 @@
TIDB_DATABASE_URL=mysql+pymysql://<USERNAME>:<PASSWORD>@<HOST>:4000/<DATABASE>?ssl_ca=<CA>&ssl_verify_cert=true&ssl_verify_identity=true
TIDB_DATABASE_URL=tidb+pymysql://<USERNAME>:<PASSWORD>@<HOST>:4000/<DATABASE>?ssl_ca=<CA>&ssl_verify_cert=true&ssl_verify_identity=true
4 changes: 2 additions & 2 deletions examples/orm-sqlalchemy-quickstart/README.md
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ cp .env.example .env
Copy the `HOST`, `PORT`, `USERNAME`, `PASSWORD`, `DATABASE`, and `CA` parameters from the TiDB Cloud console (see [Prerequisites](../README.md#prerequisites)), and then replace the placeholders in the `.env` file.

```bash
TIDB_DATABASE_URL=mysql+pymysql://<USERNAME>:<PASSWORD>@<HOST>:4000/<DATABASE>?ssl_ca=<CA>&ssl_verify_cert=true&ssl_verify_identity=true
TIDB_DATABASE_URL=tidb+pymysql://<USERNAME>:<PASSWORD>@<HOST>:4000/<DATABASE>?ssl_ca=<CA>&ssl_verify_cert=true&ssl_verify_identity=true
```

### Run this example
Expand All @@ -59,4 +59,4 @@ Get documents within a certain distance:
document: fish
- distance: 0.12712843905603044
document: dog
```
```
17 changes: 9 additions & 8 deletions examples/orm-sqlalchemy-quickstart/sqlalchemy-quickstart.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
import os
import dotenv

from sqlalchemy import Column, Integer, create_engine, Text
from sqlalchemy.orm import declarative_base, Session
from tidb_vector.sqlalchemy import VectorType
from sqlalchemy import Column, Integer, create_engine, Text, text
from sqlalchemy.orm import Session
from tidb_vector.sqlalchemy import VectorType, VectorIndex, get_declarative_base

dotenv.load_dotenv()

Expand All @@ -12,7 +12,7 @@
engine = create_engine(tidb_connection_string)

# Step 2: Define a table with a vector column.
Base = declarative_base()
Base = get_declarative_base()


class Document(Base):
Expand All @@ -22,18 +22,19 @@ class Document(Base):
embedding = Column(VectorType(3))


# Or add HNSW index when creating table.
# Or add HNSW index
class DocumentWithIndex(Base):
__tablename__ = 'sqlalchemy_demo_documents_with_index'
id = Column(Integer, primary_key=True)
content = Column(Text)
embedding = Column(VectorType(3), comment="hnsw(distance=cosine)")

embedding = Column(VectorType(3))
__table_args__ = (
VectorIndex('idx_cos', text('(vec_cosine_distance(embedding))')),
)

Base.metadata.drop_all(engine)
Base.metadata.create_all(engine)


# Step 3: Insert embeddings into the table.
with Session(engine) as session:
session.add(Document(content="dog", embedding=[1, 2, 1]))
Expand Down
1 change: 1 addition & 0 deletions tests/integrations/test_utils.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test TiDB Vector Search functionality."""

from __future__ import annotations

from tidb_vector.integrations.utils import extract_info_from_column_definition
Expand Down
1 change: 1 addition & 0 deletions tests/integrations/test_vector_client.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
"""Test TiDB Vector Search functionality."""

from __future__ import annotations

from typing import List, Tuple
Expand Down
207 changes: 197 additions & 10 deletions tests/sqlalchemy/test_sqlalchemy.py
Original file line number Diff line number Diff line change
@@ -1,27 +1,35 @@
import pytest
import numpy as np
from sqlalchemy import URL, create_engine, Column, Integer, select
from sqlalchemy.orm import declarative_base, sessionmaker
import sqlalchemy
from sqlalchemy import URL, create_engine, Column, Integer, select, Index
from sqlalchemy.orm import sessionmaker
from sqlalchemy.exc import OperationalError
from tidb_vector.sqlalchemy import VectorType
from tidb_vector.sqlalchemy import (
VectorType,
VectorIndex,
TiFlashReplica,
get_declarative_base,
)
from ..config import TestConfig


database_name = "test"
db_url = URL(
"mysql+pymysql",
"tidb+pymysql",
username=TestConfig.TIDB_USER,
password=TestConfig.TIDB_PASSWORD,
host=TestConfig.TIDB_HOST,
port=TestConfig.TIDB_PORT,
database="test",
query={"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {},
database=database_name,
query=(
{"ssl_verify_cert": True, "ssl_verify_identity": True}
if TestConfig.TIDB_SSL
else {}
),
)

engine = create_engine(db_url)
Session = sessionmaker(bind=engine)
Base = declarative_base()
Base = get_declarative_base()


class Item1Model(Base):
Expand Down Expand Up @@ -303,3 +311,182 @@ def test_negative_inner_product(self):
)
assert len(items) == 2
assert items[1].distance == -14.0


class TestSQLAlchemyDDL:
def setup_class(self):
Item2Model.__table__.drop(bind=engine, checkfirst=True)
Item2Model.__table__.create(bind=engine)

def teardown_class(self):
Item2Model.__table__.drop(bind=engine, checkfirst=True)

def setup_method(self):
with Session() as session:
session.query(Item2Model).delete()
session.commit()

@staticmethod
def check_indexes(table, expect_indexes_name):
indexes = table.indexes
indexes_name = [index.name for index in indexes]
assert len(indexes) == len(expect_indexes_name)
for i in expect_indexes_name:
assert i in indexes_name

def test_alter_tiflash_replica(self):
# Add tiflash replica
replica = TiFlashReplica(Item2Model.__table__, num=1)
replica.create(engine)
sql = sqlalchemy.text(
f"""
SELECT TABLE_SCHEMA,TABLE_NAME,REPLICA_COUNT
FROM INFORMATION_SCHEMA.TIFLASH_REPLICA
WHERE TABLE_SCHEMA="{database_name}" AND TABLE_NAME="{Item2Model.__tablename__}"
"""
)
with Session() as session:
rs = session.execute(sql)
for r in rs:
assert r.TABLE_SCHEMA == database_name
assert r.TABLE_NAME == Item2Model.__tablename__
assert r.REPLICA_COUNT == 1
assert Item2Model.__table__.info["has_tiflash_replica"]

# Drop tiflash replica
replica.drop(engine)
with Session() as session:
rs = session.execute(sql)
for r in rs:
assert r.TABLE_SCHEMA == database_name
assert r.TABLE_NAME == Item2Model.__tablename__
assert r.REPLICA_COUNT == 0
assert not Item2Model.__table__.info.get("has_tiflash_replica", False)

def test_query_with_index(self):
# indexes
l2_index = VectorIndex(
"idx_embedding_l2",
sqlalchemy.func.vec_l2_distance(Item2Model.__table__.c.embedding),
)
l2_index.create(engine)
cos_index = VectorIndex(
"idx_embedding_cos",
sqlalchemy.func.vec_cosine_distance(Item2Model.__table__.c.embedding),
)
cos_index.create(engine)

self.check_indexes(
Item2Model.__table__, ["idx_embedding_l2", "idx_embedding_cos"]
)

with Session() as session:
session.add_all(
[Item2Model(embedding=[1, 2, 3]), Item2Model(embedding=[1, 2, 3.2])]
)
session.commit()

# l2 distance
result_l2 = session.scalars(
select(Item2Model).filter(
Item2Model.embedding.l2_distance([1, 2, 3.1]) < 0.2
)
).all()
assert len(result_l2) == 2

distance_l2 = Item2Model.embedding.l2_distance([1, 2, 3])
items_l2 = (
session.query(Item2Model.id, distance_l2.label("distance"))
.order_by(distance_l2)
.limit(5)
.all()
)
assert len(items_l2) == 2
assert items_l2[0].distance == 0.0

# cosine distance
result_cos = session.scalars(
select(Item2Model).filter(
Item2Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2
)
).all()
assert len(result_cos) == 2

distance_cos = Item2Model.embedding.cosine_distance([1, 2, 3])
items_cos = (
session.query(Item2Model.id, distance_cos.label("distance"))
.order_by(distance_cos)
.limit(5)
.all()
)
assert len(items_cos) == 2
assert items_cos[0].distance == 0.0

# drop indexes
l2_index.drop(engine)
cos_index.drop(engine)

def test_query_with_inited_index(self):
class Item3Model(Base):
__tablename__ = "sqlalchemy_item3"
id = Column(Integer, primary_key=True)
embedding = Column(VectorType(dim=3))
__table_args__ = (
Index("idx_id", "id"),
VectorIndex(
"idx_embedding_l2", sqlalchemy.text("(vec_l2_distance(embedding))")
),
VectorIndex(
"idx_embedding_cos",
sqlalchemy.text("(vec_cosine_distance(embedding))"),
),
)

Item3Model.__table__.drop(bind=engine, checkfirst=True)
Item3Model.__table__.create(bind=engine)

self.check_indexes(
Item3Model.__table__, ["idx_id", "idx_embedding_l2", "idx_embedding_cos"]
)

with Session() as session:
session.add_all(
[Item3Model(embedding=[1, 2, 3]), Item3Model(embedding=[1, 2, 3.2])]
)
session.commit()

# l2 distance
result_l2 = session.scalars(
select(Item3Model).filter(
Item3Model.embedding.l2_distance([1, 2, 3.1]) < 0.2
)
).all()
assert len(result_l2) == 2

distance_l2 = Item3Model.embedding.l2_distance([1, 2, 3])
items_l2 = (
session.query(Item3Model.id, distance_l2.label("distance"))
.order_by(distance_l2)
.limit(5)
.all()
)
assert len(items_l2) == 2
assert items_l2[0].distance == 0.0

# cosine distance
result_cos = session.scalars(
select(Item3Model).filter(
Item3Model.embedding.cosine_distance([1, 2, 3.1]) < 0.2
)
).all()
assert len(result_cos) == 2

distance_cos = Item3Model.embedding.cosine_distance([1, 2, 3])
items_cos = (
session.query(Item3Model.id, distance_cos.label("distance"))
.order_by(distance_cos)
.limit(5)
.all()
)
assert len(items_cos) == 2
assert items_cos[0].distance == 0.0
Loading
Loading