@@ -25,6 +25,7 @@
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.TypeUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.types.ArrayType$;
import org.apache.spark.sql.types.BinaryType$;
import org.apache.spark.sql.types.BooleanType$;
Expand Down Expand Up @@ -69,6 +70,22 @@ public DataType struct(Types.StructType struct, List<DataType> fieldResults) {
if (field.doc() != null) {
sparkField = sparkField.withComment(field.doc());
}

// Expose Iceberg's write default as Spark's CURRENT_DEFAULT and its initial default as
// EXISTS_DEFAULT, rendered as Spark SQL literal strings on the StructField metadata
if (field.writeDefault() != null) {
Object writeDefault = SparkUtil.internalToSpark(field.type(), field.writeDefault());
sparkField =
sparkField.withCurrentDefaultValue(Literal$.MODULE$.create(writeDefault, type).sql());
}

if (field.initialDefault() != null) {
Object initialDefault = SparkUtil.internalToSpark(field.type(), field.initialDefault());
sparkField =
sparkField.withExistenceDefaultValue(
Literal$.MODULE$.create(initialDefault, type).sql());
}

sparkFields.add(sparkField);
}

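For context, a minimal standalone sketch (not part of this diff) that mirrors the assertions in the TestSparkSchemaUtil hunk below: it converts an Iceberg schema carrying a write default and reads back the CURRENT_DEFAULT metadata key that Spark's default-column resolution consumes. The field name and default value are illustrative.

import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructType;

public class DefaultValueConversionSketch {
  public static void main(String[] args) {
    // Iceberg schema with a single optional string field carrying a write default
    Schema schema =
        new Schema(
            Types.NestedField.optional("data")
                .withId(1)
                .ofType(Types.StringType.get())
                .withWriteDefault(Literal.of("n/a"))
                .build());

    StructType sparkSchema = SparkSchemaUtil.convert(schema);
    Metadata metadata = sparkSchema.fields()[0].metadata();

    // Prints 'n/a', the Spark SQL literal stored under CURRENT_DEFAULT
    System.out.println(
        metadata.getString(
            ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()));
  }
}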
@@ -21,15 +21,27 @@
import static org.apache.iceberg.types.Types.NestedField.optional;
import static org.assertj.core.api.Assertions.assertThat;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.List;
import java.util.TimeZone;
import java.util.stream.Stream;
import org.apache.iceberg.MetadataColumns;
import org.apache.iceberg.Schema;
import org.apache.iceberg.expressions.Literal;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.spark.sql.catalyst.expressions.AttributeReference;
import org.apache.spark.sql.catalyst.expressions.MetadataAttribute;
import org.apache.spark.sql.catalyst.types.DataTypeUtils;
import org.apache.spark.sql.catalyst.util.ResolveDefaultColumnsUtils$;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.params.ParameterizedTest;
import org.junit.jupiter.params.provider.Arguments;
import org.junit.jupiter.params.provider.MethodSource;

public class TestSparkSchemaUtil {
private static final Schema TEST_SCHEMA =
@@ -80,4 +92,181 @@ public void testSchemaConversionWithMetaDataColumnSchema() {
}
}
}

@Test
public void testSchemaConversionWithOnlyWriteDefault() {
Schema schema =
new Schema(
Types.NestedField.optional("field")
.withId(1)
.ofType(Types.StringType.get())
.withWriteDefault(Literal.of("write_only"))
.build());

StructType sparkSchema = SparkSchemaUtil.convert(schema);
Metadata metadata = sparkSchema.fields()[0].metadata();

assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field with only write default should have CURRENT_DEFAULT metadata")
.isTrue();
assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field with only write default should not have EXISTS_DEFAULT metadata")
.isFalse();
assertThat(
metadata.getString(
ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
.as("Spark metadata CURRENT_DEFAULT should contain correctly formatted literal")
.isEqualTo("'write_only'");
}

@Test
public void testSchemaConversionWithOnlyInitialDefault() {
Schema schema =
new Schema(
Types.NestedField.optional("field")
.withId(1)
.ofType(Types.IntegerType.get())
.withInitialDefault(Literal.of(42))
.build());

StructType sparkSchema = SparkSchemaUtil.convert(schema);
Metadata metadata = sparkSchema.fields()[0].metadata();

assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field with only initial default should not have CURRENT_DEFAULT metadata")
.isFalse();
assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field with only initial default should have EXISTS_DEFAULT metadata")
.isTrue();
assertThat(
metadata.getString(
ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
.as("Spark metadata EXISTS_DEFAULT should contain correctly formatted literal")
.isEqualTo("42");
}

@ParameterizedTest(name = "{0} with writeDefault={1}, initialDefault={2}")
@MethodSource("schemaConversionWithDefaultsTestCases")
public void testSchemaConversionWithDefaultsForPrimitiveTypes(
Type type,
Literal<?> writeDefault,
Literal<?> initialDefault,
String expectedCurrentDefaultValue,
String expectedExistsDefaultValue) {
TimeZone systemTimeZone = TimeZone.getDefault();
try {
TimeZone.setDefault(TimeZone.getTimeZone("UTC"));
Schema schema =
new Schema(
Types.NestedField.optional("field")
.withId(1)
.ofType(type)
.withWriteDefault(writeDefault)
.withInitialDefault(initialDefault)
.build());

StructType sparkSchema = SparkSchemaUtil.convert(schema);
StructField defaultField = sparkSchema.fields()[0];
Metadata metadata = defaultField.metadata();

assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field of type %s should have CURRENT_DEFAULT metadata", type)
.isTrue();
assertThat(
metadata.contains(
ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
.as("Field of type %s should have EXISTS_DEFAULT metadata", type)
.isTrue();
assertThat(
metadata.getString(
ResolveDefaultColumnsUtils$.MODULE$.CURRENT_DEFAULT_COLUMN_METADATA_KEY()))
.as(
"Spark metadata CURRENT_DEFAULT for type %s should contain correctly formatted literal",
type)
.isEqualTo(expectedCurrentDefaultValue);
assertThat(
metadata.getString(
ResolveDefaultColumnsUtils$.MODULE$.EXISTS_DEFAULT_COLUMN_METADATA_KEY()))
.as(
"Spark metadata EXISTS_DEFAULT for type %s should contain correctly formatted literal",
type)
.isEqualTo(expectedExistsDefaultValue);
} finally {
TimeZone.setDefault(systemTimeZone);
}
}

private static Stream<Arguments> schemaConversionWithDefaultsTestCases() {
return Stream.of(
Arguments.of(Types.IntegerType.get(), Literal.of(1), Literal.of(2), "1", "2"),
Arguments.of(
Types.StringType.get(),
Literal.of("write_default"),
Literal.of("initial_default"),
"'write_default'",
"'initial_default'"),
Arguments.of(
Types.UUIDType.get(),
Literal.of("f79c3e09-677c-4bbd-a479-3f349cb785e7").to(Types.UUIDType.get()),
Literal.of("f79c3e09-677c-4bbd-a479-3f349cb685e7").to(Types.UUIDType.get()),
"'f79c3e09-677c-4bbd-a479-3f349cb785e7'",
"'f79c3e09-677c-4bbd-a479-3f349cb685e7'"),
Arguments.of(Types.BooleanType.get(), Literal.of(true), Literal.of(false), "true", "false"),
Arguments.of(Types.IntegerType.get(), Literal.of(42), Literal.of(10), "42", "10"),
Arguments.of(Types.LongType.get(), Literal.of(100L), Literal.of(50L), "100L", "50L"),
Arguments.of(
Types.FloatType.get(),
Literal.of(3.14f),
Literal.of(1.5f),
"CAST('3.14' AS FLOAT)",
"CAST('1.5' AS FLOAT)"),
Arguments.of(
Types.DoubleType.get(), Literal.of(2.718), Literal.of(1.414), "2.718D", "1.414D"),
Arguments.of(
Types.DecimalType.of(10, 2),
Literal.of(new BigDecimal("99.99")),
Literal.of(new BigDecimal("11.11")),
"99.99BD",
"11.11BD"),
Arguments.of(
Types.DateType.get(),
Literal.of("2024-01-01").to(Types.DateType.get()),
Literal.of("2023-01-01").to(Types.DateType.get()),
"DATE '2024-01-01'",
"DATE '2023-01-01'"),
Arguments.of(
Types.TimestampType.withZone(),
Literal.of("2017-11-30T10:30:07.123456+00:00").to(Types.TimestampType.withZone()),
Literal.of("2017-11-29T10:30:07.123456+00:00").to(Types.TimestampType.withZone()),
"TIMESTAMP '2017-11-30 10:30:07.123456'",
"TIMESTAMP '2017-11-29 10:30:07.123456'"),
Arguments.of(
Types.TimestampType.withoutZone(),
Literal.of("2017-11-30T10:30:07.123456").to(Types.TimestampType.withoutZone()),
Literal.of("2017-11-29T10:30:07.123456").to(Types.TimestampType.withoutZone()),
"TIMESTAMP_NTZ '2017-11-30 10:30:07.123456'",
"TIMESTAMP_NTZ '2017-11-29 10:30:07.123456'"),
Arguments.of(
Types.BinaryType.get(),
Literal.of(ByteBuffer.wrap(new byte[] {0x0a, 0x0b})),
Literal.of(ByteBuffer.wrap(new byte[] {0x01, 0x02})),
"X'0A0B'",
"X'0102'"),
Arguments.of(
Types.FixedType.ofLength(4),
Literal.of("test".getBytes()),
Literal.of("init".getBytes()),
"X'74657374'",
"X'696E6974'"));
}
}
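As a usage note, the expected strings in the parameterized cases above come from Spark's Literal.create(value, dataType).sql(), the same rendering the conversion applies. A tiny sketch with illustrative values:

import org.apache.spark.sql.catalyst.expressions.Literal$;
import org.apache.spark.sql.types.IntegerType$;
import org.apache.spark.sql.types.StringType$;

public class LiteralSqlSketch {
  public static void main(String[] args) {
    // Strings render single-quoted: 'write_only'
    System.out.println(Literal$.MODULE$.create("write_only", StringType$.MODULE$).sql());
    // Integers render bare: 42
    System.out.println(Literal$.MODULE$.create(42, IntegerType$.MODULE$).sql());
  }
}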