diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 5f3da81..7aad6b0 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -18,7 +18,7 @@ jobs: run: >- python3 -m pip install - build pytest sqlalchemy + build pytest sqlalchemy rockset --user - name: Test run: if [ $(ls -1 test/*.py |grep -v __init__.py |wc -l) -gt 0 ]; then pytest test; else true; fi diff --git a/setup.py b/setup.py index a3ec982..15be918 100644 --- a/setup.py +++ b/setup.py @@ -2,7 +2,7 @@ setup( name="rockset-sqlalchemy", - version="1.0.0", + version="1.0.1", author="Rockset", author_email="support@rockset.com", keywords=["Rockset", "rockset-client"], @@ -15,16 +15,13 @@ entry_points={ "sqlalchemy.dialects": [ "rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect", - "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect" + "rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect", ] }, - install_requires=[ - "rockset>=1.0.0", - "sqlalchemy>=1.4.0" - ], + install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"], classifiers=[ - 'Programming Language :: Python :: 3', - 'License :: OSI Approved :: Apache Software License', - 'Operating System :: OS Independent', + "Programming Language :: Python :: 3", + "License :: OSI Approved :: Apache Software License", + "Operating System :: OS Independent", ], ) diff --git a/src/rockset_sqlalchemy/sqlalchemy/compiler.py b/src/rockset_sqlalchemy/sqlalchemy/compiler.py index 046dc22..4d362d6 100644 --- a/src/rockset_sqlalchemy/sqlalchemy/compiler.py +++ b/src/rockset_sqlalchemy/sqlalchemy/compiler.py @@ -1,6 +1,5 @@ -import sqlalchemy as sa from sqlalchemy import func -from sqlalchemy.sql import compiler +from sqlalchemy.sql import compiler, elements, selectable from sqlalchemy.sql.operators import custom_op, json_getitem_op, json_path_getitem_op from .types import Array @@ -49,3 +48,41 @@ def _element_at(self, b): def get_from_hint_text(self, table, text): return text + + def _alter_column_table_in_clause(self, obj): + if isinstance(obj, elements.BooleanClauseList): + for clause in obj.clauses: + self._alter_column_table_in_clause(clause) + elif isinstance(obj, elements.BinaryExpression): + for el in [obj.left, obj.right]: + # If the element that we are visiting is not a column or it's table is not + # a selectable.TableClause (ie. it's a subquery) + # or the type of it's name is actually a truncated label (ie. it's an alias) + # then we don't need to do anything + if ( + not isinstance(el, elements.ColumnElement) + or not isinstance(el.table, selectable.TableClause) + or isinstance(el.table.name, elements._truncated_label) + ): + continue + effective_schema = self.preparer.schema_for_object(el.table) + # no effective_schema no issues ! we don't need to alias things + if not effective_schema: + continue + schema_prefix = self.preparer.quote_schema(effective_schema) + # no schema prefix same + if not schema_prefix or schema_prefix == "": + continue + + # Ok it's a real table it has schema prefix we need to do something + # otherwise we would have something like aliasA.col1 = "schema"."tablename".col2 + # and this is not working with our engine ... + + # we just create an alias with the same name as the table, + # this sounds "dumb" but it's what we need to do to stop sqlalchemy from + # the table with a schema + el.table = el.table.alias(el.table.name) + + def visit_join(self, join, asfrom=False, from_linter=None, **kwargs): + self._alter_column_table_in_clause(join.onclause) + return super().visit_join(join, asfrom, from_linter, **kwargs) diff --git a/test/test_compiler.py b/test/test_compiler.py new file mode 100644 index 0000000..8609b4c --- /dev/null +++ b/test/test_compiler.py @@ -0,0 +1,58 @@ +import sys + +from sqlalchemy.dialects import registry +from sqlalchemy.sql import and_, column, table +from sqlalchemy.testing import AssertsCompiledSQL + + +class TestSQL(AssertsCompiledSQL): + __dialect__ = "rockset" + + def setup_method(self): + + sys.path.insert(0, "./src") + + registry.register("rockset", "rockset_sqlalchemy.sqlalchemy", "RocksetDialect") + pass + + def test_inner_join_table_on_clause_w_schema(self): + t1 = table("t1", column("x"), schema="s1") + t2 = table("t2", column("y"), schema="s2") + nd = and_(*[column("x") == t2.c.y]) + # the column in the join condition is not part of a table so it shouldn't be prefixed + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"', + ) + nd = and_(*[column("x") == t2.c.y]) + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"', + ) + t3 = t2.alias("t3") + col = column("y") + col.table = t3 + nd = and_(*[t1.c.x == col]) + # the column in the join condition is not part of a table so it shouldn't be prefixed + self.assert_compile( + t1.join(t2, nd), + '"s1"."t1" JOIN "s2"."t2" ON "t1"."x" = "t3"."y"', + ) + + def test_inner_join_table_on_clause(self): + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + nd = and_(*[column("x") == t2.c.y]) + self.assert_compile( + t1.join(t2.alias("t3"), nd), + '"t1" JOIN "t2" AS "t3" ON "x" = "t2"."y"', + ) + + def test_inner_join_no_table_on_clause(self): + t1 = table("t1", column("x")) + t2 = table("t2", column("y")) + nd = and_(*[column("x") == column("y")]) + self.assert_compile( + t1.join(t2.alias("t3"), nd), + '"t1" JOIN "t2" AS "t3" ON "x" = "y"', + )