diff --git a/sqlglot/optimizer/scope.py b/sqlglot/optimizer/scope.py index 4077964ee5..398a22faac 100644 --- a/sqlglot/optimizer/scope.py +++ b/sqlglot/optimizer/scope.py @@ -358,7 +358,7 @@ def references(self) -> t.List[t.Tuple[str, exp.Expression]]: for expression in itertools.chain(self.derived_tables, self.udtfs): self._references.append( ( - expression.alias, + _get_source_alias(expression), expression if expression.args.get("pivots") else expression.unnest(), ) ) @@ -785,7 +785,7 @@ def _traverse_tables(scope): # This shouldn't be a problem once qualify_columns runs, as it adds aliases on everything. # Until then, this means that only a single, unaliased derived table is allowed (rather, # the latest one wins. - sources[expression.alias] = child_scope + sources[_get_source_alias(expression)] = child_scope # append the final child_scope yielded if child_scope: @@ -825,7 +825,7 @@ def _traverse_udtfs(scope): ): yield child_scope top = child_scope - sources[expression.alias] = child_scope + sources[_get_source_alias(expression)] = child_scope scope.subquery_scopes.append(top) @@ -915,3 +915,13 @@ def find_in_scope(expression, expression_types, bfs=True): the criteria was found. """ return next(find_all_in_scope(expression, expression_types, bfs=bfs), None) + + +def _get_source_alias(expression): + alias_arg = expression.args.get("alias") + alias_name = expression.alias + + if not alias_name and isinstance(alias_arg, exp.TableAlias) and len(alias_arg.columns) == 1: + alias_name = alias_arg.columns[0].name + + return alias_name diff --git a/tests/test_optimizer.py b/tests/test_optimizer.py index 212869a4ff..512358dcab 100644 --- a/tests/test_optimizer.py +++ b/tests/test_optimizer.py @@ -871,6 +871,10 @@ def test_scope(self): sql = "UPDATE tbl1 SET col = 0" self.assertEqual(len(traverse_scope(parse_one(sql))), 0) + sql = "SELECT * FROM t LEFT JOIN UNNEST(a) AS a1 LEFT JOIN UNNEST(a1.a) AS a2" + scope = build_scope(parse_one(sql, read="bigquery")) + self.assertEqual(set(scope.selected_sources), {"t", "a1", "a2"}) + @patch("sqlglot.optimizer.scope.logger") def test_scope_warning(self, logger): self.assertEqual(len(traverse_scope(parse_one("WITH q AS (@y) SELECT * FROM q"))), 1)