From 981266f40ae00b6424902cc03aa835680d576222 Mon Sep 17 00:00:00 2001 From: Matthieu Patou Date: Sat, 2 Mar 2024 02:03:18 +0000 Subject: [PATCH] Remove the schema prefix when we do joins Summary On the columns that are part of the clause for a join, remove the schema prefix by creating an alias based on the table and setting the `table` attribute of the columns in the join clause to be alias and not the table. The name of the alias is picked to be the name of the table as well as it works in plain SQL. We only do that if: 1. there is a schema prefix on the table (ie. `commons` or `production`) 2. the "table" is not actually already an alias 3. the "table" is not actually a subquery Testing I created unit tests for the visit_join function they are all passing. I also created a wheel package and uploaded it in `superset`, I used to have issues with queries when superset wanted to do a self join to limit the number of series returned in a query. It was complaining: ``` Relation name `commons` not found. If you are trying to access a nested field within an object ``` With the fix the query is working fine. --- setup.py | 15 ++--- src/rockset_sqlalchemy/sqlalchemy/compiler.py | 41 ++++++++++++- test/test_compiler.py | 58 +++++++++++++++++++ 3 files changed, 103 insertions(+), 11 deletions(-) create mode 100644 test/test_compiler.py diff --git a/setup.py b/setup.py index a3ec982..15be918 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="rockset-sqlalchemy", - version="1.0.0", + version="1.0.1", author="Rockset", author_email="support@rockset.com", keywords=["Rockset", "rockset-client"], @@ -15,16 +15,13 @@ entry_points={ "sqlalchemy.dialects": [ "rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect", - "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect" + "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect", ] }, - install_requires=[ - "rockset>=1.0.0", - "sqlalchemy>=1.4.0" - ], + install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"], classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", ], ) diff --git a/src/rockset_sqlalchemy/sqlalchemy/compiler.py b/src/rockset_sqlalchemy/sqlalchemy/compiler.py index 046dc22..4d362d6 100644 --- a/src/rockset_sqlalchemy/sqlalchemy/compiler.py +++ b/src/rockset_sqlalchemy/sqlalchemy/compiler.py @@ -1,6 +1,5 @@ -import sqlalchemy as sa from sqlalchemy import func -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, elements, selectable from sqlalchemy.sql.operators import custom_op, json_getitem_op, json_path_getitem_op from .types import Array @@ -49,3 +48,41 @@ def _element_at(self, b): def get_from_hint_text(self, table, text): return text + + def _alter_column_table_in_clause(self, obj): + if isinstance(obj, elements.BooleanClauseList): + for clause in obj.clauses: + self._alter_column_table_in_clause(clause) + elif isinstance(obj, elements.BinaryExpression): + for el in [obj.left, obj.right]: + # If the element that we are visiting is not a column or it's table is not + # a selectable.TableClause (ie. it's a subquery) + # or the type of it's name is actually a truncated label (ie. it's an alias) + # then we don't need to do anything + if ( + not isinstance(el, elements.ColumnElement) + or not isinstance(el.table, selectable.TableClause) + or isinstance(el.table.name, elements._truncated_label) + ): + continue + effective_schema = self.preparer.schema_for_object(el.table) + # no effective_schema no issues ! we don't need to alias things + if not effective_schema: + continue + schema_prefix = self.preparer.quote_schema(effective_schema) + # no schema prefix same + if not schema_prefix or schema_prefix == "": + continue + + # Ok it's a real table it has schema prefix we need to do something + # otherwise we would have something like aliasA.col1 = "schema"."tablename".col2 + # and this is not working with our engine ... + + # we just create an alias with the same name as the table, + # this sounds "dumb" but it's what we need to do to stop sqlalchemy from + # the table with a schema + el.table = el.table.alias(el.table.name) + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + self._alter_column_table_in_clause(join.onclause) + return super().visit_join(join, asfrom, from_linter, **kwargs) diff --git a/test/test_compiler.py b/test/test_compiler.py new file mode 100644 index 0000000..8609b4c --- /dev/null +++ b/test/test_compiler.py @@ -0,0 +1,58 @@ +import sys + +from sqlalchemy.dialects import registry +from sqlalchemy.sql import and_, column, table +from sqlalchemy.testing import AssertsCompiledSQL + + +class TestSQL(AssertsCompiledSQL): + __dialect__ = "rockset" + + def setup_method(self): + + sys.path.insert(0, "./src") + + registry.register("rockset", "rockset_sqlalchemy.sqlalchemy", "RocksetDialect") + pass + + def test_inner_join_table_on_clause_w_schema(self): + t1 = table("t1", column("x"), schema="s1") + t2 = table("t2", column("y"), schema="s2") + nd = and_(*[column("x") == t2.c.y]) + # the column in the join condition is not part of a table so it shouldn't be prefixed + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"', + ) + nd = and_(*[column("x") == t2.c.y]) + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"', + ) + t3 = t2.alias("t3") + col = column("y") + col.table = t3 + nd = and_(*[t1.c.x == col]) + # the column in the join condition is not part of a table so it shouldn't be prefixed + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "t1"."x" = "t3"."y"', + ) + + def test_inner_join_table_on_clause(self): + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + nd = and_(*[column("x") == t2.c.y]) + self.assert_compile( + t1.join(t2.alias("t3"), nd), + '"t1" JOIN "t2" AS "t3" ON "x" = "t2"."y"', + ) + + def test_inner_join_no_table_on_clause(self): + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + nd = and_(*[column("x") == column("y")]) + self.assert_compile( + t1.join(t2.alias("t3"), nd), + '"t1" JOIN "t2" AS "t3" ON "x" = "y"', + )