Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove the schema prefix when we do joins #10

Closed
wants to merge 1 commit into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@ jobs:
run: >-
python3 -m
pip install
build pytest sqlalchemy
build pytest sqlalchemy rockset
--user
- name: Test
run: if [ $(ls -1 test/*.py |grep -v __init__.py |wc -l) -gt 0 ]; then pytest test; else true; fi
Expand Down
15 changes: 6 additions & 9 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

setup(
name="rockset-sqlalchemy",
version="1.0.0",
version="1.0.1",
author="Rockset",
author_email="[email protected]",
keywords=["Rockset", "rockset-client"],
Expand All @@ -15,16 +15,13 @@
entry_points={
"sqlalchemy.dialects": [
"rockset_sqlalchemy = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect"
"rockset = rockset_sqlalchemy.sqlalchemy:RocksetDialect",
]
},
install_requires=[
"rockset>=1.0.0",
"sqlalchemy>=1.4.0"
],
install_requires=["rockset>=1.0.0", "sqlalchemy>=1.4.0"],
classifiers=[
'Programming Language :: Python :: 3',
'License :: OSI Approved :: Apache Software License',
'Operating System :: OS Independent',
"Programming Language :: Python :: 3",
"License :: OSI Approved :: Apache Software License",
"Operating System :: OS Independent",
],
)
41 changes: 39 additions & 2 deletions src/rockset_sqlalchemy/sqlalchemy/compiler.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
import sqlalchemy as sa
from sqlalchemy import func
from sqlalchemy.sql import compiler
from sqlalchemy.sql import compiler, elements, selectable
from sqlalchemy.sql.operators import custom_op, json_getitem_op, json_path_getitem_op

from .types import Array
Expand Down Expand Up @@ -49,3 +48,41 @@ def _element_at(self, b):

def get_from_hint_text(self, table, text):
return text

def _alter_column_table_in_clause(self, obj):
if isinstance(obj, elements.BooleanClauseList):
for clause in obj.clauses:
self._alter_column_table_in_clause(clause)
elif isinstance(obj, elements.BinaryExpression):
for el in [obj.left, obj.right]:
# If the element that we are visiting is not a column or it's table is not
# a selectable.TableClause (ie. it's a subquery)
# or the type of it's name is actually a truncated label (ie. it's an alias)
# then we don't need to do anything
if (
not isinstance(el, elements.ColumnElement)
or not isinstance(el.table, selectable.TableClause)
or isinstance(el.table.name, elements._truncated_label)
):
continue
effective_schema = self.preparer.schema_for_object(el.table)
# no effective_schema no issues ! we don't need to alias things
if not effective_schema:
continue
schema_prefix = self.preparer.quote_schema(effective_schema)
# no schema prefix same
if not schema_prefix or schema_prefix == "":
continue

# Ok it's a real table it has schema prefix we need to do something
# otherwise we would have something like aliasA.col1 = "schema"."tablename".col2
# and this is not working with our engine ...

# we just create an alias with the same name as the table,
# this sounds "dumb" but it's what we need to do to stop sqlalchemy from
# the table with a schema
el.table = el.table.alias(el.table.name)

def visit_join(self, join, asfrom=False, from_linter=None, **kwargs):
self._alter_column_table_in_clause(join.onclause)
return super().visit_join(join, asfrom, from_linter, **kwargs)
58 changes: 58 additions & 0 deletions test/test_compiler.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,58 @@
import sys

from sqlalchemy.dialects import registry
from sqlalchemy.sql import and_, column, table
from sqlalchemy.testing import AssertsCompiledSQL


class TestSQL(AssertsCompiledSQL):
__dialect__ = "rockset"

def setup_method(self):

sys.path.insert(0, "./src")

registry.register("rockset", "rockset_sqlalchemy.sqlalchemy", "RocksetDialect")
pass

def test_inner_join_table_on_clause_w_schema(self):
t1 = table("t1", column("x"), schema="s1")
t2 = table("t2", column("y"), schema="s2")
nd = and_(*[column("x") == t2.c.y])
# the column in the join condition is not part of a table so it shouldn't be prefixed
self.assert_compile(
t1.join(t2, nd),
'"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"',
)
nd = and_(*[column("x") == t2.c.y])
self.assert_compile(
t1.join(t2, nd),
'"s1"."t1" JOIN "s2"."t2" ON "x" = "t2"."y"',
)
t3 = t2.alias("t3")
col = column("y")
col.table = t3
nd = and_(*[t1.c.x == col])
# the column in the join condition is not part of a table so it shouldn't be prefixed
self.assert_compile(
t1.join(t2, nd),
'"s1"."t1" JOIN "s2"."t2" ON "t1"."x" = "t3"."y"',
)

def test_inner_join_table_on_clause(self):
t1 = table("t1", column("x"))
t2 = table("t2", column("y"))
nd = and_(*[column("x") == t2.c.y])
self.assert_compile(
t1.join(t2.alias("t3"), nd),
'"t1" JOIN "t2" AS "t3" ON "x" = "t2"."y"',
)

def test_inner_join_no_table_on_clause(self):
t1 = table("t1", column("x"))
t2 = table("t2", column("y"))
nd = and_(*[column("x") == column("y")])
self.assert_compile(
t1.join(t2.alias("t3"), nd),
'"t1" JOIN "t2" AS "t3" ON "x" = "y"',
)