From d5319c80d3883425734062640ebd8562cf8525fb Mon Sep 17 00:00:00 2001 From: Phil Date: Wed, 4 Sep 2024 12:36:15 +0200 Subject: [PATCH] fix(get_view_names): Use proper schema (#1082) * Fix get view names func * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci * wip * Use information_schema.tables * [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --------- Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com> Co-authored-by: Elliana May --- duckdb_engine/__init__.py | 27 ++++++++++++++++++++++----- duckdb_engine/tests/test_basic.py | 3 +++ 2 files changed, 25 insertions(+), 5 deletions(-) diff --git a/duckdb_engine/__init__.py b/duckdb_engine/__init__.py index a42e1ae1..a34e9a44 100644 --- a/duckdb_engine/__init__.py +++ b/duckdb_engine/__init__.py @@ -320,12 +320,29 @@ def get_view_names( include: Optional[Any] = None, **kw: Any, ) -> Any: - s = "SELECT table_name FROM information_schema.tables WHERE table_type='VIEW' and table_schema=:schema_name" - rs = connection.execute( - text(s), {"schema_name": schema if schema is not None else "main"} - ) + s = """ + SELECT table_name + FROM information_schema.tables + WHERE + table_type='VIEW' + AND table_schema = :schema_name + """ + params = {} + database_name = None + + if schema is not None: + database_name, schema = self.identifier_preparer._separate(schema) + else: + schema = "main" - return [row[0] for row in rs] + params.update({"schema_name": schema}) + + if database_name is not None: + s += "AND table_catalog = :database_name\n" + params.update({"database_name": database_name}) + + rs = connection.execute(text(s), params) + return [view for (view,) in rs] @cache # type: ignore[call-arg] def get_schema_names(self, connection: "Connection", **kw: "Any"): # type: ignore[no-untyped-def] diff --git a/duckdb_engine/tests/test_basic.py b/duckdb_engine/tests/test_basic.py index 24e52cbd..63f5411d 100644 --- a/duckdb_engine/tests/test_basic.py +++ b/duckdb_engine/tests/test_basic.py @@ -248,6 +248,9 @@ def test_get_views(conn: Connection, dialect: Dialect) -> None: views = dialect.get_view_names(conn, schema="scheme") assert views == ["schema_test"] + views = dialect.get_view_names(conn, schema="memory.scheme") + assert views == ["schema_test"] + assert dialect.has_table(conn, table_name="test") assert dialect.has_table(conn, table_name="schema_test", schema="scheme")