
Commit a99dc4f

Spark 4.0: Add schema conversion support for default values (#14407)
1 parent d1a518f · commit a99dc4f

File tree

3 files changed: +440 -0 lines changed

spark/v4.0/spark/src/main/java/org/apache/iceberg/spark/TypeToSparkType.java

Lines changed: 17 additions & 0 deletions
@@ -25,6 +25,7 @@
 import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.TypeUtil;
 import org.apache.iceberg.types.Types;
+import org.apache.spark.sql.catalyst.expressions.Literal$;
 import org.apache.spark.sql.types.ArrayType$;
 import org.apache.spark.sql.types.BinaryType$;
 import org.apache.spark.sql.types.BooleanType$;
@@ -69,6 +70,22 @@ public DataType struct(Types.StructType struct, List<DataType> fieldResults) {
       if (field.doc() != null) {
         sparkField = sparkField.withComment(field.doc());
       }
+
+      // Convert both write and initial default values to Spark SQL string literal representations
+      // on the StructField metadata
+      if (field.writeDefault() != null) {
+        Object writeDefault = SparkUtil.internalToSpark(field.type(), field.writeDefault());
+        sparkField =
+            sparkField.withCurrentDefaultValue(Literal$.MODULE$.create(writeDefault, type).sql());
+      }
+
+      if (field.initialDefault() != null) {
+        Object initialDefault = SparkUtil.internalToSpark(field.type(), field.initialDefault());
+        sparkField =
+            sparkField.withExistenceDefaultValue(
+                Literal$.MODULE$.create(initialDefault, type).sql());
+      }
+
       sparkFields.add(sparkField);
     }
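
With this change, a converted Spark StructField carries the Iceberg write default as CURRENT_DEFAULT metadata and the initial default as EXISTS_DEFAULT metadata, each rendered as a Spark SQL literal string. A minimal sketch of the effect, using the SparkSchemaUtil.convert entry point exercised by the tests below (the field name and literal values are illustrative):

    Schema schema =
        new Schema(
            Types.NestedField.optional("data")
                .withId(1)
                .ofType(Types.StringType.get())
                .withWriteDefault(Literal.of("hello")) // surfaces as CURRENT_DEFAULT
                .withInitialDefault(Literal.of("backfill")) // surfaces as EXISTS_DEFAULT
                .build());

    StructType sparkSchema = SparkSchemaUtil.convert(schema);

    // The field metadata now carries both defaults as Spark SQL literals,
    // e.g. {"CURRENT_DEFAULT":"'hello'","EXISTS_DEFAULT":"'backfill'"}
    System.out.println(sparkSchema.fields()[0].metadata().json());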

spark/v4.0/spark/src/test/java/org/apache/iceberg/spark/TestSparkSchemaUtil.java

Lines changed: 189 additions & 0 deletions
@@ -21,15 +21,27 @@
 import static org.apache.iceberg.types.Types.NestedField.optional;
 import static org.assertj.core.api.Assertions.assertThat;
 
+import java.math.BigDecimal;
+import java.nio.ByteBuffer;
 import java.util.List;
+import java.util.TimeZone;
+import java.util.stream.Stream;
 import org.apache.iceberg.MetadataColumns;
 import org.apache.iceberg.Schema;
+import org.apache.iceberg.expressions.Literal;
+import org.apache.iceberg.types.Type;
 import org.apache.iceberg.types.Types;
 import org.apache.spark.sql.catalyst.expressions.AttributeReference;
 import org.apache.spark.sql.catalyst.expressions.MetadataAttribute;
 import org.apache.spark.sql.catalyst.types.DataTypeUtils;
+import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils$;
+import org.apache.spark.sql.types.Metadata;
+import org.apache.spark.sql.types.StructField;
 import org.apache.spark.sql.types.StructType;
 import org.junit.jupiter.api.Test;
+import org.junit.jupiter.params.ParameterizedTest;
+import org.junit.jupiter.params.provider.Arguments;
+import org.junit.jupiter.params.provider.MethodSource;
 
 public class TestSparkSchemaUtil {
   private static final Schema TEST_SCHEMA =
@@ -80,4 +92,181 @@ public void testSchemaConversionWithMetaDataColumnSchema() {
       }
     }
   }
+
+  @Test
+  public void testSchemaConversionWithOnlyWriteDefault() {
+    Schema schema =
+        new Schema(
+            Types.NestedField.optional("field")
+                .withId(1)
+                .ofType(Types.StringType.get())
+                .withWriteDefault(Literal.of("write_only"))
+                .build());
+
+    StructType sparkSchema = SparkSchemaUtil.convert(schema);
+    Metadata metadata = sparkSchema.fields()[0].metadata();
+
+    assertThat(
+            metadata.contains(
+                ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Field with only write default should have CURRENT_DEFAULT metadata")
+        .isTrue();
+    assertThat(
+            metadata.contains(
+                ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Field with only write default should not have EXISTS_DEFAULT metadata")
+        .isFalse();
+    assertThat(
+            metadata.getString(
+                ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Spark metadata CURRENT_DEFAULT should contain correctly formatted literal")
+        .isEqualTo("'write_only'");
+  }
+
+  @Test
+  public void testSchemaConversionWithOnlyInitialDefault() {
+    Schema schema =
+        new Schema(
+            Types.NestedField.optional("field")
+                .withId(1)
+                .ofType(Types.IntegerType.get())
+                .withInitialDefault(Literal.of(42))
+                .build());
+
+    StructType sparkSchema = SparkSchemaUtil.convert(schema);
+    Metadata metadata = sparkSchema.fields()[0].metadata();
+
+    assertThat(
+            metadata.contains(
+                ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Field with only initial default should not have CURRENT_DEFAULT metadata")
+        .isFalse();
+    assertThat(
+            metadata.contains(
+                ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Field with only initial default should have EXISTS_DEFAULT metadata")
+        .isTrue();
+    assertThat(
+            metadata.getString(
+                ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
+        .as("Spark metadata EXISTS_DEFAULT should contain correctly formatted literal")
+        .isEqualTo("42");
+  }
+
+  @ParameterizedTest(name = "{0} with writeDefault={1}, initialDefault={2}")
+  @MethodSource("schemaConversionWithDefaultsTestCases")
+  public void testSchemaConversionWithDefaultsForPrimitiveTypes(
+      Type type,
+      Literal<?> writeDefault,
+      Literal<?> initialDefault,
+      String expectedCurrentDefaultValue,
+      String expectedExistsDefaultValue) {
+    TimeZone systemTimeZone = TimeZone.getDefault();
+    try {
+      TimeZone.setDefault(TimeZone.getTimeZone("UTC"));
+      Schema schema =
+          new Schema(
+              Types.NestedField.optional("field")
+                  .withId(1)
+                  .ofType(type)
+                  .withWriteDefault(writeDefault)
+                  .withInitialDefault(initialDefault)
+                  .build());
+
+      StructType sparkSchema = SparkSchemaUtil.convert(schema);
+      StructField defaultField = sparkSchema.fields()[0];
+      Metadata metadata = defaultField.metadata();
+
+      assertThat(
+              metadata.contains(
+                  ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
+          .as("Field of type %s should have CURRENT_DEFAULT metadata", type)
+          .isTrue();
+      assertThat(
+              metadata.contains(
+                  ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
+          .as("Field of type %s should have EXISTS_DEFAULT metadata", type)
+          .isTrue();
+      assertThat(
+              metadata.getString(
+                  ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
+          .as(
+              "Spark metadata CURRENT_DEFAULT for type %s should contain correctly formatted literal",
+              type)
+          .isEqualTo(expectedCurrentDefaultValue);
+      assertThat(
+              metadata.getString(
+                  ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
+          .as(
+              "Spark metadata EXISTS_DEFAULT for type %s should contain correctly formatted literal",
+              type)
+          .isEqualTo(expectedExistsDefaultValue);
+    } finally {
+      TimeZone.setDefault(systemTimeZone);
+    }
+  }
+
+  private static Stream<Arguments> schemaConversionWithDefaultsTestCases() {
+    return Stream.of(
+        Arguments.of(Types.IntegerType.get(), Literal.of(1), Literal.of(2), "1", "2"),
+        Arguments.of(
+            Types.StringType.get(),
+            Literal.of("write_default"),
+            Literal.of("initial_default"),
+            "'write_default'",
+            "'initial_default'"),
+        Arguments.of(
+            Types.UUIDType.get(),
+            Literal.of("f79c3e09-677c-4bbd-a479-3f349cb785e7").to(Types.UUIDType.get()),
+            Literal.of("f79c3e09-677c-4bbd-a479-3f349cb685e7").to(Types.UUIDType.get()),
+            "'f79c3e09-677c-4bbd-a479-3f349cb785e7'",
+            "'f79c3e09-677c-4bbd-a479-3f349cb685e7'"),
+        Arguments.of(Types.BooleanType.get(), Literal.of(true), Literal.of(false), "true", "false"),
+        Arguments.of(Types.IntegerType.get(), Literal.of(42), Literal.of(10), "42", "10"),
+        Arguments.of(Types.LongType.get(), Literal.of(100L), Literal.of(50L), "100L", "50L"),
+        Arguments.of(
+            Types.FloatType.get(),
+            Literal.of(3.14f),
+            Literal.of(1.5f),
+            "CAST('3.14' AS FLOAT)",
+            "CAST('1.5' AS FLOAT)"),
+        Arguments.of(
+            Types.DoubleType.get(), Literal.of(2.718), Literal.of(1.414), "2.718D", "1.414D"),
+        Arguments.of(
+            Types.DecimalType.of(10, 2),
+            Literal.of(new BigDecimal("99.99")),
+            Literal.of(new BigDecimal("11.11")),
+            "99.99BD",
+            "11.11BD"),
+        Arguments.of(
+            Types.DateType.get(),
+            Literal.of("2024-01-01").to(Types.DateType.get()),
+            Literal.of("2023-01-01").to(Types.DateType.get()),
+            "DATE '2024-01-01'",
+            "DATE '2023-01-01'"),
+        Arguments.of(
+            Types.TimestampType.withZone(),
+            Literal.of("2017-11-30T10:30:07.123456+00:00").to(Types.TimestampType.withZone()),
+            Literal.of("2017-11-29T10:30:07.123456+00:00").to(Types.TimestampType.withZone()),
+            "TIMESTAMP '2017-11-30 10:30:07.123456'",
+            "TIMESTAMP '2017-11-29 10:30:07.123456'"),
+        Arguments.of(
+            Types.TimestampType.withoutZone(),
+            Literal.of("2017-11-30T10:30:07.123456").to(Types.TimestampType.withoutZone()),
+            Literal.of("2017-11-29T10:30:07.123456").to(Types.TimestampType.withoutZone()),
+            "TIMESTAMP_NTZ '2017-11-30 10:30:07.123456'",
+            "TIMESTAMP_NTZ '2017-11-29 10:30:07.123456'"),
+        Arguments.of(
+            Types.BinaryType.get(),
+            Literal.of(ByteBuffer.wrap(new byte[] {0x0a, 0x0b})),
+            Literal.of(ByteBuffer.wrap(new byte[] {0x01, 0x02})),
+            "X'0A0B'",
+            "X'0102'"),
+        Arguments.of(
+            Types.FixedType.ofLength(4),
+            Literal.of("test".getBytes()),
+            Literal.of("init".getBytes()),
+            "X'74657374'",
+            "X'696E6974'"));
+  }
 }
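
Two notes on the tests above. The parameterized test pins the JVM default time zone to UTC because Spark renders TIMESTAMP literals in the session time zone (which falls back to the JVM default), so the expected TIMESTAMP strings only hold under UTC. Semantically, Spark's CURRENT_DEFAULT is the value substituted for newly written rows that omit the column, while EXISTS_DEFAULT is the value returned for rows written before the column existed, which matches Iceberg's write and initial defaults. A minimal sketch, under the same assumptions as the tests, of reading the rendered defaults back off a converted field:

    // `schema` as in the earlier sketch; both keys hold Spark SQL literal strings.
    StructField field = SparkSchemaUtil.convert(schema).fields()[0];

    // Write default: applied to new rows that omit the column.
    String currentDefault =
        field
            .metadata()
            .getString(ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY());

    // Initial (existence) default: read back for rows that predate the column.
    String existsDefault =
        field
            .metadata()
            .getString(ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY());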
