From 6c728683f3300016d18cc89f317d121ab2ce10fe Mon Sep 17 00:00:00 2001 From: Andreas Motl Date: Mon, 24 Jun 2024 17:07:13 +0200 Subject: [PATCH] Support: Generalize one more use case into `support.util.refresh_table` --- src/sqlalchemy_cratedb/support/polyfill.py | 5 +---- src/sqlalchemy_cratedb/support/util.py | 15 +++++++++++---- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/sqlalchemy_cratedb/support/polyfill.py b/src/sqlalchemy_cratedb/support/polyfill.py index 230af0d..73177e5 100644 --- a/src/sqlalchemy_cratedb/support/polyfill.py +++ b/src/sqlalchemy_cratedb/support/polyfill.py @@ -108,10 +108,7 @@ def receive_after_execute( ): if isinstance(clauseelement, (sa.sql.Insert, sa.sql.Update, sa.sql.Delete)): if not isinstance(clauseelement.table, sa.sql.Join): - full_table_name = f'"{clauseelement.table.name}"' - if clauseelement.table.schema is not None: - full_table_name = f'"{clauseelement.table.schema}".' + full_table_name - refresh_table(conn, full_table_name) + refresh_table(conn, clauseelement.table) sa.event.listen(engine, "after_execute", receive_after_execute) diff --git a/src/sqlalchemy_cratedb/support/util.py b/src/sqlalchemy_cratedb/support/util.py index 1defc93..33cce5f 100644 --- a/src/sqlalchemy_cratedb/support/util.py +++ b/src/sqlalchemy_cratedb/support/util.py @@ -10,14 +10,21 @@ pass -def refresh_table(connection, target: t.Union[str, "DeclarativeBase"]): +def refresh_table(connection, target: t.Union[str, "DeclarativeBase", "sa.sql.selectable.TableClause"]): """ Invoke a `REFRESH TABLE` statement. """ - if hasattr(target, "__tablename__"): - sql = f"REFRESH TABLE {target.__tablename__}" + + if isinstance(target, sa.sql.selectable.TableClause): + full_table_name = f'"{target.name}"' + if target.schema is not None: + full_table_name = f'"{target.schema}".' + full_table_name + elif hasattr(target, "__tablename__"): + full_table_name = target.__tablename__ else: - sql = f"REFRESH TABLE {target}" + full_table_name = target + + sql = f"REFRESH TABLE {full_table_name}" connection.execute(sa.text(sql))