-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Support: Add re-usable patches and poly-fills from application adapters
Sources: MLflow, LangChain, Singer/Meltano, rdflib-sqlalchemy
- Loading branch information
Showing
6 changed files
with
307 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,5 +1,13 @@ | ||
from sqlalchemy_cratedb.support.pandas import insert_bulk | ||
from sqlalchemy_cratedb.support.polyfill import check_uniqueness_factory, refresh_after_dml, \ | ||
patch_autoincrement_timestamp | ||
from sqlalchemy_cratedb.support.util import refresh_table, refresh_dirty | ||
|
||
__all__ = [ | ||
check_uniqueness_factory, | ||
insert_bulk, | ||
patch_autoincrement_timestamp, | ||
refresh_after_dml, | ||
refresh_dirty, | ||
refresh_table, | ||
] |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,128 @@ | ||
import sqlalchemy as sa | ||
from sqlalchemy.event import listen | ||
import typing as t | ||
|
||
from sqlalchemy_cratedb.support.util import refresh_dirty, refresh_table | ||
|
||
|
||
def patch_autoincrement_timestamp(): | ||
""" | ||
Configure SQLAlchemy model columns with an alternative to `autoincrement=True`. | ||
Use the current timestamp instead. | ||
This is used by CrateDB's MLflow adapter. | ||
TODO: Maybe enable through a dialect parameter `crate_polyfill_autoincrement` or such. | ||
""" | ||
import sqlalchemy.sql.schema as schema | ||
|
||
init_dist = schema.Column.__init__ | ||
|
||
def __init__(self, *args, **kwargs): | ||
if "autoincrement" in kwargs: | ||
del kwargs["autoincrement"] | ||
if "default" not in kwargs: | ||
kwargs["default"] = sa.func.now() | ||
init_dist(self, *args, **kwargs) | ||
|
||
schema.Column.__init__ = __init__ # type: ignore[method-assign] | ||
|
||
|
||
def check_uniqueness_factory(sa_entity, *attribute_names): | ||
""" | ||
Run a manual column value uniqueness check on a table, and raise an IntegrityError if applicable. | ||
CrateDB does not support the UNIQUE constraint on columns. This attempts to emulate it. | ||
https://github.com/crate/sqlalchemy-cratedb/issues/76 | ||
This is used by CrateDB's MLflow adapter. | ||
TODO: Maybe enable through a dialect parameter `crate_polyfill_unique` or such. | ||
""" | ||
|
||
# Synthesize a canonical "name" for the constraint, | ||
# composed of all column names involved. | ||
constraint_name: str = "-".join(attribute_names) | ||
|
||
def check_uniqueness(mapper, connection, target): | ||
from sqlalchemy.exc import IntegrityError | ||
|
||
if isinstance(target, sa_entity): | ||
# TODO: How to use `session.query(SqlExperiment)` here? | ||
stmt = mapper.selectable.select() | ||
for attribute_name in attribute_names: | ||
stmt = stmt.filter(getattr(sa_entity, attribute_name) == getattr(target, attribute_name)) | ||
stmt = stmt.compile(bind=connection.engine) | ||
results = connection.execute(stmt) | ||
if results.rowcount > 0: | ||
raise IntegrityError( | ||
statement=stmt, | ||
params=[], | ||
orig=Exception( | ||
f"DuplicateKeyException in table '{target.__tablename__}' " f"on constraint '{constraint_name}'" | ||
), | ||
) | ||
|
||
return check_uniqueness | ||
|
||
|
||
def refresh_after_dml_session(session: sa.orm.Session): | ||
""" | ||
Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE). | ||
CrateDB is eventually consistent, i.e. write operations are not flushed to | ||
disk immediately, so readers may see stale data. In a traditional OLTP-like | ||
application, this is not applicable. | ||
This SQLAlchemy extension makes sure that data is synchronized after each | ||
operation manipulating data. | ||
> `after_{insert,update,delete}` events only apply to the session flush operation | ||
> and do not apply to the ORM DML operations described at ORM-Enabled INSERT, | ||
> UPDATE, and DELETE statements. To intercept ORM DML events, use | ||
> `SessionEvents.do_orm_execute().` | ||
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.MapperEvents.after_insert | ||
> Intercept statement executions that occur on behalf of an ORM Session object. | ||
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.do_orm_execute | ||
> Execute after flush has completed, but before commit has been called. | ||
> -- https://docs.sqlalchemy.org/en/20/orm/events.html#sqlalchemy.orm.SessionEvents.after_flush | ||
This is used by CrateDB's LangChain adapter. | ||
TODO: Maybe enable through a dialect parameter `crate_dml_refresh` or such. | ||
""" # noqa: E501 | ||
listen(session, "after_flush", refresh_dirty) | ||
|
||
|
||
def refresh_after_dml_engine(engine: sa.engine.Engine): | ||
""" | ||
Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE). | ||
This is used by CrateDB's Singer/Meltano and `rdflib-sqlalchemy` adapters. | ||
""" | ||
def receive_after_execute( | ||
conn: sa.engine.Connection, clauseelement, multiparams, params, execution_options, result | ||
): | ||
if isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)): | ||
if not isinstance(clauseelement.table, sa.sql.Join): | ||
full_table_name = f'"{clauseelement.table.name}"' | ||
if clauseelement.table.schema is not None: | ||
full_table_name = f'"{clauseelement.table.schema}".' + full_table_name | ||
refresh_table(conn, full_table_name) | ||
|
||
sa.event.listen(engine, "after_execute", receive_after_execute) | ||
|
||
|
||
def refresh_after_dml(engine_or_session: t.Union[sa.engine.Engine, sa.orm.Session]): | ||
""" | ||
Run `REFRESH TABLE` after each DML operation (INSERT, UPDATE, DELETE). | ||
""" | ||
if isinstance(engine_or_session, sa.engine.Engine): | ||
refresh_after_dml_engine(engine_or_session) | ||
elif isinstance(engine_or_session, (sa.orm.Session, sa.orm.scoping.scoped_session)): | ||
refresh_after_dml_session(engine_or_session) | ||
else: | ||
raise TypeError(f"Unknown type: {type(engine_or_session)}") |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,32 @@ | ||
import itertools | ||
import typing as t | ||
|
||
import sqlalchemy as sa | ||
try: | ||
from sqlalchemy.orm import DeclarativeBase | ||
except: | ||
pass | ||
|
||
|
||
def refresh_table(connection, target: t.Union[str, "DeclarativeBase"]): | ||
""" | ||
Invoke a `REFRESH TABLE` statement. | ||
""" | ||
if hasattr(target, "__tablename__"): | ||
sql = f"REFRESH TABLE {target.__tablename__}" | ||
else: | ||
sql = f"REFRESH TABLE {target}" | ||
connection.execute(sa.text(sql)) | ||
|
||
|
||
def refresh_dirty(session, flush_context=None): | ||
""" | ||
Invoke a `REFRESH TABLE` statement on each table entity flagged as "dirty". | ||
SQLAlchemy event handler for the 'after_flush' event, | ||
invoking `REFRESH TABLE` on each table which has been modified. | ||
""" | ||
dirty_entities = itertools.chain(session.new, session.dirty, session.deleted) | ||
dirty_classes = {entity.__class__ for entity in dirty_entities} | ||
for class_ in dirty_classes: | ||
refresh_table(session, class_) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,137 @@ | ||
import datetime as dt | ||
|
||
import pytest | ||
import sqlalchemy as sa | ||
from sqlalchemy.event import listen | ||
from sqlalchemy.exc import IntegrityError | ||
from sqlalchemy.orm import sessionmaker | ||
|
||
from sqlalchemy_cratedb import SA_VERSION, SA_1_4 | ||
|
||
try: | ||
from sqlalchemy.orm import declarative_base | ||
except ImportError: | ||
from sqlalchemy.ext.declarative import declarative_base | ||
|
||
from sqlalchemy_cratedb.support import check_uniqueness_factory, patch_autoincrement_timestamp, refresh_after_dml | ||
|
||
|
||
@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Test case not supported on SQLAlchemy 1.3 and earlier") | ||
def test_autoincrement_timestamp(cratedb_service): | ||
""" | ||
Validate autoincrement columns using `sa.DateTime` columns. | ||
https://github.com/crate/sqlalchemy-cratedb/issues/77 | ||
""" | ||
patch_autoincrement_timestamp() | ||
|
||
engine = cratedb_service.database.engine | ||
session = sessionmaker(bind=engine)() | ||
Base = declarative_base() | ||
|
||
# Define DDL. | ||
class FooBar(Base): | ||
__tablename__ = 'foobar' | ||
id = sa.Column(sa.String, primary_key=True) | ||
date = sa.Column(sa.DateTime, autoincrement=True) | ||
number = sa.Column(sa.BigInteger, autoincrement=True) | ||
string = sa.Column(sa.String, autoincrement=True) | ||
|
||
Base.metadata.drop_all(engine, checkfirst=True) | ||
Base.metadata.create_all(engine, checkfirst=True) | ||
|
||
# Insert record. | ||
foo_item = FooBar(id="foo") | ||
session.add(foo_item) | ||
session.commit() | ||
session.execute(sa.text("REFRESH TABLE foobar")) | ||
|
||
# Query record. | ||
result = session.execute(sa.select(FooBar.date, FooBar.number, FooBar.string)).mappings().first() | ||
|
||
# Compare outcome. | ||
assert result["date"].year == dt.datetime.now().year | ||
assert result["number"] >= 1718846016235 | ||
assert result["string"] >= "1718846016235" | ||
|
||
|
||
@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier") | ||
def test_check_uniqueness_factory(cratedb_service): | ||
""" | ||
Validate basic synthetic UNIQUE constraints. | ||
https://github.com/crate/sqlalchemy-cratedb/issues/76 | ||
""" | ||
|
||
engine = cratedb_service.database.engine | ||
session = sessionmaker(bind=engine)() | ||
Base = declarative_base() | ||
|
||
# Define DDL. | ||
class FooBar(Base): | ||
__tablename__ = 'foobar' | ||
id = sa.Column(sa.String, primary_key=True) | ||
name = sa.Column(sa.String) | ||
|
||
# Add synthetic UNIQUE constraint on `name` column. | ||
listen(FooBar, "before_insert", check_uniqueness_factory(FooBar, "name")) | ||
|
||
Base.metadata.drop_all(engine, checkfirst=True) | ||
Base.metadata.create_all(engine, checkfirst=True) | ||
|
||
# Insert baseline record. | ||
foo_item = FooBar(id="foo", name="foo") | ||
session.add(foo_item) | ||
session.commit() | ||
session.execute(sa.text("REFRESH TABLE foobar")) | ||
|
||
# Insert second record, violating the uniqueness constraint. | ||
bar_item = FooBar(id="bar", name="foo") | ||
session.add(bar_item) | ||
with pytest.raises(IntegrityError) as ex: | ||
session.commit() | ||
assert ex.match("DuplicateKeyException in table 'foobar' on constraint 'name'") | ||
|
||
|
||
@pytest.mark.skipif(SA_VERSION < SA_1_4, reason="Feature not supported on SQLAlchemy 1.3 and earlier") | ||
@pytest.mark.parametrize("mode", ["engine", "session"]) | ||
def test_refresh_after_dml(cratedb_service, mode): | ||
""" | ||
Validate automatic `REFRESH TABLE` issuing works well. | ||
https://github.com/crate/sqlalchemy-cratedb/issues/83 | ||
""" | ||
engine = cratedb_service.database.engine | ||
session = sessionmaker(bind=engine)() | ||
Base = declarative_base() | ||
|
||
# Enable automatic refresh. | ||
if mode == "engine": | ||
refresh_after_dml(engine) | ||
elif mode == "session": | ||
refresh_after_dml(session) | ||
else: | ||
raise ValueError(f"Unable to enable automatic refresh with mode: {mode}") | ||
|
||
# Define DDL. | ||
class FooBar(Base): | ||
__tablename__ = 'foobar' | ||
id = sa.Column(sa.String, primary_key=True) | ||
|
||
Base.metadata.drop_all(engine, checkfirst=True) | ||
Base.metadata.create_all(engine, checkfirst=True) | ||
|
||
# Insert baseline record. | ||
foo_item = FooBar(id="foo") | ||
session.add(foo_item) | ||
session.commit() | ||
|
||
# Query record. | ||
query = session.query(FooBar.id) | ||
result = query.first() | ||
|
||
# Sanity checks. | ||
assert result is not None, "Database result is empty. Most probably, `REFRESH TABLE` wasn't issued." | ||
|
||
# Compare outcome. | ||
assert result[0] == "foo" |