Skip to content

Commit

Permalink
address PR review comment
Browse files Browse the repository at this point in the history
  • Loading branch information
MikeWallis42 committed Jan 13, 2025
1 parent 697d2f3 commit 36e8b31
Showing 1 changed file with 17 additions and 15 deletions.
32 changes: 17 additions & 15 deletions sqlglot/dialects/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -1570,23 +1570,25 @@ def normalize(identifier: t.Optional[exp.Identifier]) -> t.Optional[str]:
targets.add(normalize(alias.this))

for when in expression.args["whens"].expressions:
# only remove the target names from the THEN clause
# theyre still valid in the <condition> part of WHEN MATCHED / WHEN NOT MATCHED
# ref: https://github.com/TobikoData/sqlmesh/issues/2934
then = when.args.get("then")
# only remove the target table names from certain parts of WHEN MATCHED / WHEN NOT MATCHED
# they are still valid in the <condition>, the right hand side of each UPDATE and the VALUES part
# (not the column list) of the INSERT
then: exp.Insert | exp.Update | None = when.args.get("then")
if then:
then.transform(
lambda node: (
exp.column(node.this)
if isinstance(then, exp.Update):
for equals in then.find_all(exp.EQ):
equal_lhs = equals.this
if (
isinstance(node, exp.Column)
and normalize(node.args.get("table")) in targets
and not isinstance(node.parent, exp.Func)
)
else node
),
copy=False,
)
isinstance(equal_lhs, exp.Column)
and normalize(equal_lhs.args.get("table")) in targets
):
equal_lhs.replace(exp.column(equal_lhs.this))
if isinstance(then, exp.Insert):
column_list = then.args.get("this")
if column_list is not None:
for column in column_list.find_all(exp.Column):
if normalize(column.args.get("table")) in targets:
column.replace(exp.column(column.this))

return self.merge_sql(expression)

Expand Down

0 comments on commit 36e8b31

Please sign in to comment.