We have a multi-client architecture in which the clients are identified by a client ID (the cid column in the examples below). We now have the requirement that users must be able to write their own SQL statements for reporting. The task is to ensure that every statement is restricted to the user's client ID, so that one user cannot access another user's data. For this purpose I've successfully managed to inject/modify a condition:

import sqlglot

sqls = [
    "SELECT * FROM sysusertab",
    "SELECT * FROM sysusertab WHERE cid = 42",
    "SELECT * FROM sysusertab WHERE 42 = cid",
    "SELECT * FROM sysusertab WHERE cid = 100",
    "SELECT * FROM sysusertab WHERE user = 'steve'",
    "SELECT * FROM sysusertab WHERE user = 'steve' OR user = 'nick'",
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.username = userdetails.username",
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.cid = userdetails.cid AND sysusertab.username = userdetails.username",
]

expected = [
    "SELECT * FROM sysusertab WHERE cid = 100",  # add condition
    "SELECT * FROM sysusertab WHERE cid = 100",  # change condition
    "SELECT * FROM sysusertab WHERE cid = 100",  # change condition
    "SELECT * FROM sysusertab WHERE cid = 100",  # leave as-is
    "SELECT * FROM sysusertab WHERE cid = 100 AND user = 'steve'",  # add condition (preferably at front)
    "SELECT * FROM sysusertab WHERE cid = 100 AND (user = 'steve' OR user = 'nick')",  # add condition (preferably at front)
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.cid = userdetails.cid AND sysusertab.username = userdetails.username WHERE sysusertab.cid = 100",  # add join condition and add condition for at least one table
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.cid = userdetails.cid AND sysusertab.username = userdetails.username WHERE sysusertab.cid = 100",  # add join condition and add condition for at least one table
]

for sql in sqls:
    expression = sqlglot.parse_one(sql)
    cid_condi = sqlglot.condition("cid=100")
    found = False
    # Replace any existing condition that involves a "cid" column.
    for node in expression.find_all(sqlglot.exp.Condition):
        for e in (node.this, node.expression):
            if isinstance(e, sqlglot.exp.Column) and e.name == "cid":
                node.replace(cid_condi)
                found = True
    for node in expression.find_all(sqlglot.exp.Join):
        # TODO: make sure tables are joined on column "cid"
        # alternative: make sure all tables have a "cid=x" condition
        pass
    # No "cid" condition was found, so add one to the WHERE clause.
    if not found:
        expression = expression.where(cid_condi)
    print(expression.sql())

Output:

As you can see, my code makes the existing joins worse: among other things, it removes a valid join on cid. The generated code doesn't have to be exactly the same as the one in the expected list. Any help is greatly appreciated.
Replies: 1 comment 6 replies
I would replace all the tables with a subquery that has the filter.

from sqlglot import exp, parse_one
from sqlglot.optimizer.scope import traverse_scope


def find_tables(expression: exp.Expression) -> list[exp.Table]:
    """
    Find references to physical tables in an expression.

    This excludes sqlglot Table expressions that are references to CTEs.
    """
    return [
        source
        for scope in traverse_scope(expression)
        for source in scope.sources.values()
        if isinstance(source, exp.Table)
    ]


def filter_by_cid(sql: str, cid: int) -> str:
    expression = parse_one(sql)
    for table in find_tables(expression):
        table.replace(
            exp.select("*").from_(table.copy()).where(f"cid = {cid}").subquery(table.alias_or_name)
        )
    return expression.sql()

sqls = [
    "SELECT * FROM sysusertab",
    "SELECT * FROM sysusertab WHERE cid = 42",
    "SELECT * FROM sysusertab WHERE 42 = cid",
    "SELECT * FROM sysusertab WHERE cid = 100",
    "SELECT * FROM sysusertab WHERE user = 'steve'",
    "SELECT * FROM sysusertab WHERE user = 'steve' OR user = 'nick'",
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.username = userdetails.username",
    "SELECT * FROM sysusertab JOIN userdetails ON sysusertab.cid = userdetails.cid AND sysusertab.username = userdetails.username",
]

for sql in sqls:
    print(filter_by_cid(sql, 100))

Output:

SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab WHERE cid = 42
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab WHERE 42 = cid
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab WHERE cid = 100
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab WHERE user = 'steve'
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab WHERE user = 'steve' OR user = 'nick'
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab JOIN (SELECT * FROM userdetails WHERE cid = 100) AS userdetails ON sysusertab.username = userdetails.username
SELECT * FROM (SELECT * FROM sysusertab WHERE cid = 100) AS sysusertab JOIN (SELECT * FROM userdetails WHERE cid = 100) AS userdetails ON sysusertab.cid = userdetails.cid AND sysusertab.username = userdetails.username

This seems more foolproof to me. If you're worried about the queries being a little more verbose, you can use the optimizer. That will handle merging the subqueries and deduplicating any cid predicates.