Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 7 additions & 2 deletions sqlglot/dialects/duckdb.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,6 @@
no_datetime_sql,
encode_decode_sql,
build_formatted_time,
inline_array_unless_query,
no_comment_column_constraint_sql,
no_time_sql,
no_timestamp_sql,
Expand All @@ -39,6 +38,7 @@
explode_to_unnest_sql,
no_make_interval_sql,
groupconcat_sql,
inline_array_unless_query,
regexp_replace_global_modifier,
)
from sqlglot.generator import unsupported_args
Expand Down Expand Up @@ -682,10 +682,15 @@ class Generator(generator.Generator):
SUPPORTS_LIKE_QUANTIFIERS = False
SET_ASSIGNMENT_REQUIRES_VARIABLE_KEYWORD = True

def _array_sql_with_struct_inheritance(self, expression: exp.Array) -> str:
"""Generate DuckDB array SQL with struct field name inheritance preprocessing."""
transformed = transforms.inherit_struct_field_names(expression)
return inline_array_unless_query(self, t.cast(exp.Array, transformed))

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.ApproxDistinct: approx_count_distinct_sql,
exp.Array: inline_array_unless_query,
exp.Array: lambda self, e: self._array_sql_with_struct_inheritance(e),
exp.ArrayFilter: rename_func("LIST_FILTER"),
exp.ArrayRemove: remove_from_array_using_filter,
exp.ArraySort: _array_sort_sql,
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/hive.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,6 +576,7 @@ class Generator(generator.Generator):
exp.ApproxDistinct: approx_count_distinct_sql,
exp.ArgMax: arg_max_or_min_no_count("MAX_BY"),
exp.ArgMin: arg_max_or_min_no_count("MIN_BY"),
exp.Array: transforms.preprocess([transforms.inherit_struct_field_names]),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayToString: lambda self, e: self.func("CONCAT_WS", e.expression, e.this),
exp.ArraySort: _array_sort_sql,
Expand Down
7 changes: 6 additions & 1 deletion sqlglot/dialects/presto.py
Original file line number Diff line number Diff line change
Expand Up @@ -426,13 +426,18 @@ class Generator(generator.Generator):
exp.DataType.Type.TIMETZ: "TIME",
}

def _array_sql_with_struct_inheritance(self, expression: exp.Array) -> str:
"""Generate ARRAY[...] SQL with struct field name inheritance preprocessing."""
transformed = transforms.inherit_struct_field_names(expression)
return f"ARRAY[{self.expressions(transformed, flat=True)}]"

TRANSFORMS = {
**generator.Generator.TRANSFORMS,
exp.AnyValue: rename_func("ARBITRARY"),
exp.ApproxQuantile: rename_func("APPROX_PERCENTILE"),
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: lambda self, e: f"ARRAY[{self.expressions(e, flat=True)}]",
exp.Array: lambda self, e: self._array_sql_with_struct_inheritance(e),
exp.ArrayAny: rename_func("ANY_MATCH"),
exp.ArrayConcat: rename_func("CONCAT"),
exp.ArrayContains: rename_func("CONTAINS"),
Expand Down
1 change: 1 addition & 0 deletions sqlglot/dialects/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -1525,6 +1525,7 @@ class Generator(generator.Generator):
exp.ApproxDistinct: rename_func("APPROX_COUNT_DISTINCT"),
exp.ArgMax: rename_func("MAX_BY"),
exp.ArgMin: rename_func("MIN_BY"),
exp.Array: transforms.preprocess([transforms.inherit_struct_field_names]),
exp.ArrayConcat: lambda self, e: self.arrayconcat_sql(e, name="ARRAY_CAT"),
exp.ArrayContains: lambda self, e: self.func(
"ARRAY_CONTAINS",
Expand Down
67 changes: 67 additions & 0 deletions sqlglot/transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -999,3 +999,70 @@ def _inline_inherited_window(window: exp.Expression) -> None:
_inline_inherited_window(window)

return expression


def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression:
"""
Inherit struct field names from the first struct in an array to subsequent unnamed structs.
BigQuery (and some other dialects) allow shorthand where the first STRUCT in an array
defines field names and subsequent STRUCTs inherit them:
Example:
ARRAY[
STRUCT('Alice' AS name, 85 AS score), -- defines names
STRUCT('Bob', 92), -- inherits names
STRUCT('Diana', 95) -- inherits names
]
This transform makes the field names explicit on all structs by adding PropertyEQ nodes.
Args:
expression: The expression tree to transform
Returns:
The modified expression with field names inherited in all structs
"""
for array in expression.find_all(exp.Array):
if not array.expressions:
continue

# Check if first element is a Struct with field names
first_item = array.expressions[0]
if not isinstance(first_item, exp.Struct):
continue

# Get field names from first struct (PropertyEQ nodes)
property_eqs = [e for e in first_item.expressions if isinstance(e, exp.PropertyEQ)]
if not property_eqs:
continue

field_names = [pe.name for pe in property_eqs]

# Apply field names to subsequent structs that don't have them
for struct in array.expressions[1:]:
if not isinstance(struct, exp.Struct):
continue

# Skip if struct already has field names
if struct.find(exp.PropertyEQ):
continue

# Skip if struct has different number of fields
if len(struct.expressions) != len(field_names):
continue

# Convert unnamed expressions to PropertyEQ with inherited names
new_expressions = []
for i, expr in enumerate(struct.expressions):
if not isinstance(expr, exp.PropertyEQ):
# Create PropertyEQ: field_name := value
new_expressions.append(
exp.PropertyEQ(this=exp.Identifier(this=field_names[i]), expression=expr)
)
else:
new_expressions.append(expr)

struct.set("expressions", new_expressions)

return expression
42 changes: 42 additions & 0 deletions tests/dialects/test_bigquery.py
Original file line number Diff line number Diff line change
Expand Up @@ -2338,6 +2338,20 @@ def test_inline_constructor(self):
"duckdb": "SELECT CAST(ROW(1, ROW('c_str')) AS STRUCT(a BIGINT, b STRUCT(c TEXT)))",
},
)
self.validate_all(
"SELECT MAX_BY(name, score) FROM table1",
write={
"bigquery": "SELECT MAX_BY(name, score) FROM table1",
"duckdb": "SELECT ARG_MAX(name, score) FROM table1",
},
)
self.validate_all(
"SELECT MIN_BY(product, price) FROM table1",
write={
"bigquery": "SELECT MIN_BY(product, price) FROM table1",
"duckdb": "SELECT ARG_MIN(product, price) FROM table1",
},
)

def test_convert(self):
for value, expected in [
Expand Down Expand Up @@ -2387,6 +2401,34 @@ def test_unnest(self):
},
)

self.validate_all(
"SELECT * FROM UNNEST([STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)])",
write={
"bigquery": "SELECT * FROM UNNEST([STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)])",
"duckdb": "SELECT * FROM (SELECT UNNEST([{'name': 'Alice', 'scores': {'math': 85, 'english': 90}}, {'name': 'Bob', 'scores': {'math': 92, 'english': 88}}], max_depth => 2))",
"snowflake": "SELECT * FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('name', 'Alice', 'scores', OBJECT_CONSTRUCT('math', 85, 'english', 90)), OBJECT_CONSTRUCT('name', 'Bob', 'scores', OBJECT_CONSTRUCT('math', 92, 'english', 88))])) AS _t0(seq, key, path, index, value, this)",
"presto": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', CAST(ROW(85, 90) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER))), CAST(ROW('Bob', CAST(ROW(92, 88) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER)))])",
"trino": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', CAST(ROW(85, 90) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER))), CAST(ROW('Bob', CAST(ROW(92, 88) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER)))])",
"spark2": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)))",
"databricks": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)))",
"hive": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice', STRUCT(85, 90)), STRUCT('Bob', STRUCT(92, 88))))",
},
)

self.validate_all(
"SELECT * FROM UNNEST([STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95)])",
write={
"bigquery": "SELECT * FROM UNNEST([STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95)])",
"duckdb": "SELECT * FROM (SELECT UNNEST([{'name': 'Alice', 'score': 85}, {'name': 'Bob', 'score': 92}, {'name': 'Diana', 'score': 95}], max_depth => 2))",
"snowflake": "SELECT * FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('name', 'Alice', 'score', 85), OBJECT_CONSTRUCT('name', 'Bob', 'score', 92), OBJECT_CONSTRUCT('name', 'Diana', 'score', 95)])) AS _t0(seq, key, path, index, value, this)",
"presto": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', 85) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Bob', 92) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Diana', 95) AS ROW(name VARCHAR, score INTEGER))])",
"trino": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', 85) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Bob', 92) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Diana', 95) AS ROW(name VARCHAR, score INTEGER))])",
"spark2": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Diana' AS name, 95 AS score)))",
"databricks": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Diana' AS name, 95 AS score)))",
"hive": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice', 85), STRUCT('Bob', 92), STRUCT('Diana', 95)))",
},
)

def test_range_type(self):
for type, value in (
("RANGE<DATE>", "'[2020-01-01, 2020-12-31)'"),
Expand Down
72 changes: 72 additions & 0 deletions tests/test_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
eliminate_join_marks,
eliminate_qualify,
eliminate_window_clause,
inherit_struct_field_names,
remove_precision_parameterized_types,
)

Expand Down Expand Up @@ -289,3 +290,74 @@ def test_eliminate_window_clause(self):
"SELECT LAST_VALUE(c) OVER (a) AS c2 FROM (SELECT LAST_VALUE(i) OVER (a) AS c FROM p WINDOW a AS (PARTITION BY x)) AS q(c) WINDOW a AS (PARTITION BY y)",
"SELECT LAST_VALUE(c) OVER (PARTITION BY y) AS c2 FROM (SELECT LAST_VALUE(i) OVER (PARTITION BY x) AS c FROM p) AS q(c)",
)

def test_inherit_struct_field_names(self):
# Basic case: field names inherited from first struct
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Diana' AS name, 95 AS score))",
)

# Single struct in array: no inheritance needed
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score))",
)

# Empty array: no change
self.validate(
inherit_struct_field_names,
"SELECT ARRAY()",
"SELECT ARRAY()",
)

# First struct has no field names: no inheritance
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice', 85), STRUCT('Bob', 92))",
"SELECT ARRAY(STRUCT('Alice', 85), STRUCT('Bob', 92))",
)

# Mismatched field counts: skip inheritance
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob'))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob'))",
)

# Struct already has field names: don't override
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS fullname, 92 AS points))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS fullname, 92 AS points))",
)

# Mixed: some structs inherit, some already have names
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Carol' AS name, 88 AS score), STRUCT('Diana', 95))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Carol' AS name, 88 AS score), STRUCT('Diana' AS name, 95 AS score))",
)

# Non-struct elements: no change
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(1, 2, 3)",
"SELECT ARRAY(1, 2, 3)",
)

# Multiple arrays: each processed independently
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92)), ARRAY(STRUCT('X' AS col), STRUCT('Y'))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score)), ARRAY(STRUCT('X' AS col), STRUCT('Y' AS col))",
)

# Partial field names in first struct: inherit only the named ones
self.validate(
inherit_struct_field_names,
"SELECT ARRAY(STRUCT('Alice' AS name, 85), STRUCT('Bob', 92))",
"SELECT ARRAY(STRUCT('Alice' AS name, 85), STRUCT('Bob', 92))",
)