Skip to content

Commit

Permalink
Remove the schema prefix when we do joins
Browse files Browse the repository at this point in the history
Summary

On the columns that are part of the clause for a join, remove the schema
prefix by creating an alias based on the table and setting the `table`
attribute of the columns in the join clause to the alias instead of the
table.
The alias is given the same name as the table, since that also works in
plain SQL.

We only do that if:

1. there is a schema prefix on the table (ie. `commons` or `production`)
2. the "table" is not actually already an alias
3. the "table" is not actually a subquery

Testing

I created unit tests for the visit_join function; they are all passing.
I also created a wheel package and uploaded it to `superset`. I used to
have issues with queries when superset wanted to do a self join to limit
the number of series returned in a query.
It was complaining:
```
Relation name `commons` not found. If you are trying to access a nested
field within an object
```
With the fix the query is working fine.
  • Loading branch information
mpatou committed Mar 4, 2024
1 parent 754d356 commit c577823
Show file tree
Hide file tree
Showing 4 changed files with 104 additions and 12 deletions.
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
run: >-
python3 -m
pip install
build pytest sqlalchemy
build pytest sqlalchemy rockset
--user
- name: Test
run: if [ $(ls -1 test/*.py |grep -v __init__.py |wc -l) -gt 0 ]; then pytest test; else true; fi
Expand Down
15 changes: 6 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="rockset-sqlalchemy",
version="1.0.0",
version="1.0.1",
author="Rockset",
author_email="[email protected]",
keywords=["Rockset", "rockset-client"],
Expand All @@ -15,16 +15,13 @@
entry_points={
"sqlalchemy.dialects": [
"rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect"
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
]
},
install_requires=[
"rockset>=1.0.0",
"sqlalchemy>=1.4.0"
],
install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"],
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
)
41 changes: 39 additions & 2 deletions src/rockset_sqlalchemy/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, elements, selectable
from sqlalchemy.sql.operators import custom_op, json_getitem_op, json_path_getitem_op

from .types import Array
Expand Down Expand Up @@ -49,3 +48,41 @@ def _element_at(self, b):

def get_from_hint_text(self, table, text):
return text

def _alter_column_table_in_clause(self, obj):
    """Re-point schema-qualified table columns of a join clause at an alias.

    Walks ``obj`` (a boolean clause list or a binary expression) and, for
    every operand that is a column of a plain table with a schema, replaces
    the column's ``table`` attribute **in place** with an alias of that
    table carrying the table's own name.  Anything else in the clause is
    left untouched.

    :param obj: the ON clause of a join, or one of its sub-clauses.
    :return: ``None``; the columns reachable from ``obj`` are mutated.
    """
    if isinstance(obj, elements.BooleanClauseList):
        for clause in obj.clauses:
            self._alter_column_table_in_clause(clause)
        return
    if not isinstance(obj, elements.BinaryExpression):
        return

    nested_types = (elements.BooleanClauseList, elements.BinaryExpression)
    for el in (obj.left, obj.right):
        # An operand can itself be an expression (e.g. ``a.x + 1 == b.y``);
        # visit the columns inside it instead of treating it as a column.
        if isinstance(el, nested_types):
            self._alter_column_table_in_clause(el)
            continue

        # Operands such as bound literals or function calls are
        # ColumnElements but are not guaranteed to expose a ``table``
        # attribute, so look it up defensively instead of assuming it.
        table = getattr(el, "table", None)

        # If the element that we are visiting is not a column, or its table
        # is not a selectable.TableClause (ie. it's a subquery), or the type
        # of its name is actually a truncated label (ie. it's an alias),
        # then we don't need to do anything.
        if (
            not isinstance(el, elements.ColumnElement)
            or not isinstance(table, selectable.TableClause)
            or isinstance(table.name, elements._truncated_label)
        ):
            continue

        # No effective schema, no issue: nothing needs to be aliased.
        effective_schema = self.preparer.schema_for_object(table)
        if not effective_schema:
            continue
        # Same when the schema renders to an empty prefix.
        if not self.preparer.quote_schema(effective_schema):
            continue

        # It's a real table with a schema prefix, so we need to act:
        # otherwise we would emit something like
        #   aliasA.col1 = "schema"."tablename".col2
        # and this is not working with our engine.
        #
        # We just create an alias with the same name as the table. This
        # sounds "dumb" but it's what we need to do to stop sqlalchemy from
        # prefixing the table with a schema.
        el.table = table.alias(table.name)

def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
    """Compile a JOIN after rewriting the columns of its ON clause.

    The ON clause is first handed to ``_alter_column_table_in_clause``;
    the join is then compiled by the parent class with the arguments
    passed through unchanged.
    """
    on_clause = join.onclause
    self._alter_column_table_in_clause(on_clause)
    compiled = super().visit_join(join, asfrom, from_linter, **kwargs)
    return compiled
58 changes: 58 additions & 0 deletions test/test_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys

from sqlalchemy.dialects import registry
from sqlalchemy.sql import and_, column, table
from sqlalchemy.testing import AssertsCompiledSQL


class TestSQL(AssertsCompiledSQL):
    """Compile-level checks of how the ``rockset`` dialect renders JOINs."""

    __dialect__ = "rockset"

    def setup_method(self):
        # Make the in-repo package importable; guard the insert so sys.path
        # does not gain one more "./src" entry for every test that runs.
        if "./src" not in sys.path:
            sys.path.insert(0, "./src")

        registry.register("rockset", "rockset_sqlalchemy.sqlalchemy", "RocksetDialect")

    def test_inner_join_table_on_clause_w_schema(self):
        t1 = table("t1", column("x"), schema="s1")
        t2 = table("t2", column("y"), schema="s2")
        nd = and_(*[column("x") == t2.c.y])
        # "x" is not part of a table so it gets no prefix at all, while
        # t2.c.y must be rendered without the "s2" schema prefix.
        self.assert_compile(
            t1.join(t2, nd),
            '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"',
        )
        # Same join built and compiled a second time: the rendered SQL must
        # not change on re-compilation.
        nd = and_(*[column("x") == t2.c.y])
        self.assert_compile(
            t1.join(t2, nd),
            '"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"',
        )
        t3 = t2.alias("t3")
        col = column("y")
        col.table = t3
        nd = and_(*[t1.c.x == col])
        # x belongs to the schema-qualified t1 and y to the alias t3:
        # neither side of the ON clause may carry a schema prefix.
        self.assert_compile(
            t1.join(t2, nd),
            '"s1"."t1" JOIN "s2"."t2" ON "t1"."x" = "t3"."y"',
        )

    def test_inner_join_table_on_clause(self):
        # No schema on either table: the ON clause is rendered as written.
        t1 = table("t1", column("x"))
        t2 = table("t2", column("y"))
        nd = and_(*[column("x") == t2.c.y])
        self.assert_compile(
            t1.join(t2.alias("t3"), nd),
            '"t1" JOIN "t2" AS "t3" ON "x" = "t2"."y"',
        )

    def test_inner_join_no_table_on_clause(self):
        # Neither column of the ON clause is attached to a table.
        t1 = table("t1", column("x"))
        t2 = table("t2", column("y"))
        nd = and_(*[column("x") == column("y")])
        self.assert_compile(
            t1.join(t2.alias("t3"), nd),
            '"t1" JOIN "t2" AS "t3" ON "x" = "y"',
        )

0 comments on commit c577823

Please sign in to comment.