def _array_sql_with_struct_inheritance(self, expression: exp.Array) -> str:
    """Render an array for DuckDB after struct field name inheritance.

    The array is first passed through `transforms.inherit_struct_field_names`
    and the result is then handed to `inline_array_unless_query`, which
    produces the final SQL string.

    Args:
        expression: The array expression to generate SQL for.

    Returns:
        The SQL string returned by `inline_array_unless_query`.
    """
    array = t.cast(exp.Array, transforms.inherit_struct_field_names(expression))
    return inline_array_unless_query(self, array)
def inherit_struct_field_names(expression: exp.Expression) -> exp.Expression:
    """
    Inherit struct field names from the first struct in an array to subsequent unnamed structs.

    BigQuery (and some other dialects) allow a shorthand where the first STRUCT in an array
    defines the field names and the subsequent STRUCTs inherit them by position:

    Example:
        ARRAY[
          STRUCT('Alice' AS name, 85 AS score),  -- defines names
          STRUCT('Bob', 92),                     -- inherits names
          STRUCT('Diana', 95)                    -- inherits names
        ]

    This transform makes the field names explicit on all structs by adding PropertyEQ nodes.
    Every array found in the tree is processed independently, and the tree is modified
    in place.

    Inheritance is applied to a struct only when all of the following hold:
      - the first element of the array is a struct whose fields are *all* named, so that
        names can be matched to values by position without ambiguity;
      - the struct has the same number of fields as the first struct;
      - none of the struct's own (top-level) fields is already named.

    Args:
        expression: The expression tree to transform

    Returns:
        The same expression, with field names inherited in all eligible structs
    """
    for array in expression.find_all(exp.Array):
        items = array.expressions
        first_item = items[0] if items else None
        if not isinstance(first_item, exp.Struct):
            continue

        # Names are inherited positionally, so every field of the first struct must be
        # named; otherwise index i of `template` would not line up with value i below.
        template = first_item.expressions
        if not template or not all(isinstance(field, exp.PropertyEQ) for field in template):
            continue

        for struct in items[1:]:
            if not isinstance(struct, exp.Struct):
                continue

            values = struct.expressions

            # Skip if struct has a different number of fields
            if len(values) != len(template):
                continue

            # Skip if the struct already names any of its own fields. Only direct
            # children are inspected: a named struct nested inside one of the values
            # must not prevent the outer struct from inheriting names.
            if any(isinstance(value, exp.PropertyEQ) for value in values):
                continue

            named_values = []
            for named, value in zip(template, values):
                key = named.this
                # Copy the original key node so that identifier quoting is preserved;
                # fall back to a plain identifier if the key is not an expression node.
                if isinstance(key, exp.Expression):
                    key = key.copy()
                else:
                    key = exp.Identifier(this=named.name)

                # Create PropertyEQ: field_name := value
                named_values.append(exp.PropertyEQ(this=key, expression=value))

            struct.set("expressions", named_values)

    return expression
b/tests/dialects/test_bigquery.py @@ -2338,6 +2338,20 @@ def test_inline_constructor(self): "duckdb": "SELECT CAST(ROW(1, ROW('c_str')) AS STRUCT(a BIGINT, b STRUCT(c TEXT)))", }, ) + self.validate_all( + "SELECT MAX_BY(name, score) FROM table1", + write={ + "bigquery": "SELECT MAX_BY(name, score) FROM table1", + "duckdb": "SELECT ARG_MAX(name, score) FROM table1", + }, + ) + self.validate_all( + "SELECT MIN_BY(product, price) FROM table1", + write={ + "bigquery": "SELECT MIN_BY(product, price) FROM table1", + "duckdb": "SELECT ARG_MIN(product, price) FROM table1", + }, + ) def test_convert(self): for value, expected in [ @@ -2387,6 +2401,34 @@ def test_unnest(self): }, ) + self.validate_all( + "SELECT * FROM UNNEST([STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)])", + write={ + "bigquery": "SELECT * FROM UNNEST([STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)])", + "duckdb": "SELECT * FROM (SELECT UNNEST([{'name': 'Alice', 'scores': {'math': 85, 'english': 90}}, {'name': 'Bob', 'scores': {'math': 92, 'english': 88}}], max_depth => 2))", + "snowflake": "SELECT * FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('name', 'Alice', 'scores', OBJECT_CONSTRUCT('math', 85, 'english', 90)), OBJECT_CONSTRUCT('name', 'Bob', 'scores', OBJECT_CONSTRUCT('math', 92, 'english', 88))])) AS _t0(seq, key, path, index, value, this)", + "presto": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', CAST(ROW(85, 90) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER))), CAST(ROW('Bob', CAST(ROW(92, 88) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER)))])", + "trino": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', CAST(ROW(85, 90) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, 
english INTEGER))), CAST(ROW('Bob', CAST(ROW(92, 88) AS ROW(math INTEGER, english INTEGER))) AS ROW(name VARCHAR, scores ROW(math INTEGER, english INTEGER)))])", + "spark2": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)))", + "databricks": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, STRUCT(85 AS math, 90 AS english) AS scores), STRUCT('Bob' AS name, STRUCT(92 AS math, 88 AS english) AS scores)))", + "hive": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice', STRUCT(85, 90)), STRUCT('Bob', STRUCT(92, 88))))", + }, + ) + + self.validate_all( + "SELECT * FROM UNNEST([STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95)])", + write={ + "bigquery": "SELECT * FROM UNNEST([STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95)])", + "duckdb": "SELECT * FROM (SELECT UNNEST([{'name': 'Alice', 'score': 85}, {'name': 'Bob', 'score': 92}, {'name': 'Diana', 'score': 95}], max_depth => 2))", + "snowflake": "SELECT * FROM TABLE(FLATTEN(INPUT => [OBJECT_CONSTRUCT('name', 'Alice', 'score', 85), OBJECT_CONSTRUCT('name', 'Bob', 'score', 92), OBJECT_CONSTRUCT('name', 'Diana', 'score', 95)])) AS _t0(seq, key, path, index, value, this)", + "presto": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', 85) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Bob', 92) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Diana', 95) AS ROW(name VARCHAR, score INTEGER))])", + "trino": "SELECT * FROM UNNEST(ARRAY[CAST(ROW('Alice', 85) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Bob', 92) AS ROW(name VARCHAR, score INTEGER)), CAST(ROW('Diana', 95) AS ROW(name VARCHAR, score INTEGER))])", + "spark2": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Diana' AS name, 95 AS score)))", + "databricks": "SELECT * FROM EXPLODE(ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS 
def test_inherit_struct_field_names(self):
    """Run `inherit_struct_field_names` over a table of (input SQL, expected SQL) pairs."""
    cases = [
        # Basic case: field names inherited from first struct
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Diana', 95))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Diana' AS name, 95 AS score))",
        ),
        # Single struct in array: no inheritance needed
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score))",
        ),
        # Empty array: no change
        (
            "SELECT ARRAY()",
            "SELECT ARRAY()",
        ),
        # First struct has no field names: no inheritance
        (
            "SELECT ARRAY(STRUCT('Alice', 85), STRUCT('Bob', 92))",
            "SELECT ARRAY(STRUCT('Alice', 85), STRUCT('Bob', 92))",
        ),
        # Mismatched field counts: skip inheritance
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob'))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob'))",
        ),
        # Struct already has field names: don't override
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS fullname, 92 AS points))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS fullname, 92 AS points))",
        ),
        # Mixed: some structs inherit, some already have names
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92), STRUCT('Carol' AS name, 88 AS score), STRUCT('Diana', 95))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score), STRUCT('Carol' AS name, 88 AS score), STRUCT('Diana' AS name, 95 AS score))",
        ),
        # Non-struct elements: no change
        (
            "SELECT ARRAY(1, 2, 3)",
            "SELECT ARRAY(1, 2, 3)",
        ),
        # Multiple arrays: each processed independently
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob', 92)), ARRAY(STRUCT('X' AS col), STRUCT('Y'))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85 AS score), STRUCT('Bob' AS name, 92 AS score)), ARRAY(STRUCT('X' AS col), STRUCT('Y' AS col))",
        ),
        # Partially named first struct: the SQL is expected to come back unchanged
        (
            "SELECT ARRAY(STRUCT('Alice' AS name, 85), STRUCT('Bob', 92))",
            "SELECT ARRAY(STRUCT('Alice' AS name, 85), STRUCT('Bob', 92))",
        ),
    ]

    # Cases run in declaration order, one `validate` call per pair.
    for sql, expected in cases:
        self.validate(inherit_struct_field_names, sql, expected)