Spark 3.5: Support default values in Parquet reader #11803

Merged 1 commit on Dec 18, 2024
@@ -18,36 +18,47 @@
*/
package org.apache.iceberg.spark;

import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.sql.Date;
import java.sql.Timestamp;
import java.util.List;
import java.util.Map;
import java.util.function.BiFunction;
import java.util.function.Function;
import java.util.stream.Collectors;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;
import org.apache.hadoop.conf.Configuration;
import org.apache.iceberg.PartitionField;
import org.apache.iceberg.PartitionSpec;
import org.apache.iceberg.StructLike;
import org.apache.iceberg.relocated.com.google.common.base.Joiner;
import org.apache.iceberg.relocated.com.google.common.base.Preconditions;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.transforms.Transform;
import org.apache.iceberg.transforms.UnknownTransform;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.Pair;
import org.apache.spark.SparkEnv;
import org.apache.spark.scheduler.ExecutorCacheTaskLocation;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.catalyst.expressions.BoundReference;
import org.apache.spark.sql.catalyst.expressions.EqualTo;
import org.apache.spark.sql.catalyst.expressions.Expression;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.catalyst.expressions.Literal;
import org.apache.spark.sql.connector.expressions.NamedReference;
import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.storage.BlockManager;
import org.apache.spark.storage.BlockManagerId;
import org.apache.spark.storage.BlockManagerMaster;
import org.apache.spark.unsafe.types.UTF8String;
import org.joda.time.DateTime;
import scala.collection.JavaConverters;
import scala.collection.Seq;
@@ -268,4 +279,59 @@ private static <T> List<T> toJavaList(Seq<T> seq) {
private static String toExecutorLocation(BlockManagerId id) {
return ExecutorCacheTaskLocation.apply(id.host(), id.executorId()).toString();
}

/**
* Converts a value to pass into Spark from Iceberg's internal object model.
*
* @param type an Iceberg type
* @param value a value that is an instance of {@link Type.TypeID#javaClass()}
* @return the value converted for Spark
*/
public static Object convertConstant(Type type, Object value) {
Reviewer (Contributor):
I know that this is a copy of the other one, but I would also expect UUID to be here.

@rdblue (Contributor, Author), Dec 18, 2024:
Good point, I'll update it in my follow up that is fixing a few types and adding a test for each primitive. I need to make some changes on top of this one, so it makes sense to do all of the primitive type fixes at the same time.

@rdblue (Contributor, Author):
The follow up will be here: #11811

if (value == null) {
return null;
}

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
case STRUCT:
Types.StructType structType = (Types.StructType) type;

if (structType.fields().isEmpty()) {
return new GenericInternalRow();
}

List<Types.NestedField> fields = structType.fields();
Object[] values = new Object[fields.size()];
StructLike struct = (StructLike) value;

for (int index = 0; index < fields.size(); index++) {
Types.NestedField field = fields.get(index);
Type fieldType = field.type();
values[index] =
convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
}

return new GenericInternalRow(values);
default:
}

return value;
}
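For context on the UUID gap flagged in the review thread above, a hypothetical extra branch (not part of this PR, and only a sketch pending the follow-up) could fall back to Spark's string representation, assuming Spark models Iceberg UUIDs as strings:

case UUID:
  // Hypothetical sketch: Spark has no native UUID type, so expose the value
  // as its string form, mirroring the STRING branch above.
  return UTF8String.fromString(value.toString());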
}
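As a rough usage illustration (not part of the diff, values are made up): constants produced by this helper are in Spark's internal representation, so strings come back as UTF8String, decimals as Spark Decimal, and binary as byte[].

// Hypothetical sketch of calling SparkUtil.convertConstant with Iceberg types.
Object s = SparkUtil.convertConstant(Types.StringType.get(), "n/a");                      // UTF8String "n/a"
Object d = SparkUtil.convertConstant(Types.DecimalType.of(9, 2), new BigDecimal("1.50")); // Decimal 1.50
Object b = SparkUtil.convertConstant(Types.BinaryType.get(), ByteBuffer.wrap(new byte[] {1, 2})); // byte[] {1, 2}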
@@ -44,6 +44,7 @@
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableMap;
import org.apache.iceberg.relocated.com.google.common.collect.Lists;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Type.TypeID;
import org.apache.iceberg.types.Types;
import org.apache.iceberg.util.UUIDUtil;
@@ -165,6 +166,7 @@ public ParquetValueReader<?> struct(
int defaultMaxDefinitionLevel = type.getMaxDefinitionLevel(currentPath());
for (Types.NestedField field : expectedFields) {
int id = field.fieldId();
ParquetValueReader<?> reader = readersById.get(id);
if (idToConstant.containsKey(id)) {
// containsKey is used because the constant may be null
int fieldMaxDefinitionLevel =
@@ -178,15 +180,21 @@
} else if (id == MetadataColumns.IS_DELETED.fieldId()) {
reorderedFields.add(ParquetValueReaders.constant(false));
types.add(null);
} else if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else if (field.initialDefault() != null) {
reorderedFields.add(
ParquetValueReaders.constant(
SparkUtil.convertConstant(field.type(), field.initialDefault()),
maxDefinitionLevelsById.getOrDefault(id, defaultMaxDefinitionLevel)));
types.add(typesById.get(id));
} else if (field.isOptional()) {
reorderedFields.add(ParquetValueReaders.nulls());
types.add(null);
} else {
ParquetValueReader<?> reader = readersById.get(id);
if (reader != null) {
reorderedFields.add(reader);
types.add(typesById.get(id));
} else {
reorderedFields.add(ParquetValueReaders.nulls());
types.add(null);
}
throw new IllegalArgumentException(
String.format("Missing required field: %s", field.name()));
}
}

@@ -250,7 +258,7 @@ public ParquetValueReader<?> primitive(
if (expected != null && expected.typeId() == Types.LongType.get().typeId()) {
return new IntAsLongReader(desc);
} else {
return new UnboxedReader(desc);
return new UnboxedReader<>(desc);
}
case DATE:
case INT_64:
@@ -20,17 +20,13 @@

import java.io.Closeable;
import java.io.IOException;
import java.math.BigDecimal;
import java.nio.ByteBuffer;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import java.util.stream.Stream;
import org.apache.avro.generic.GenericData;
import org.apache.avro.util.Utf8;
import org.apache.iceberg.ContentFile;
import org.apache.iceberg.ContentScanTask;
import org.apache.iceberg.DeleteFile;
@@ -53,16 +49,11 @@
import org.apache.iceberg.mapping.NameMappingParser;
import org.apache.iceberg.spark.SparkExecutorCache;
import org.apache.iceberg.spark.SparkSchemaUtil;
import org.apache.iceberg.types.Type;
import org.apache.iceberg.types.Types.NestedField;
import org.apache.iceberg.spark.SparkUtil;
import org.apache.iceberg.types.Types.StructType;
import org.apache.iceberg.util.ByteBuffers;
import org.apache.iceberg.util.PartitionUtil;
import org.apache.spark.rdd.InputFileBlockHolder;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.catalyst.expressions.GenericInternalRow;
import org.apache.spark.sql.types.Decimal;
import org.apache.spark.unsafe.types.UTF8String;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

@@ -193,59 +184,12 @@ private Map<String, InputFile> inputFiles() {
protected Map<Integer, ?> constantsMap(ContentScanTask<?> task, Schema readSchema) {
if (readSchema.findField(MetadataColumns.PARTITION_COLUMN_ID) != null) {
StructType partitionType = Partitioning.partitionType(table);
return PartitionUtil.constantsMap(task, partitionType, BaseReader::convertConstant);
return PartitionUtil.constantsMap(task, partitionType, SparkUtil::convertConstant);
} else {
return PartitionUtil.constantsMap(task, BaseReader::convertConstant);
return PartitionUtil.constantsMap(task, SparkUtil::convertConstant);
}
}

protected static Object convertConstant(Type type, Object value) {
if (value == null) {
return null;
}

switch (type.typeId()) {
case DECIMAL:
return Decimal.apply((BigDecimal) value);
case STRING:
if (value instanceof Utf8) {
Utf8 utf8 = (Utf8) value;
return UTF8String.fromBytes(utf8.getBytes(), 0, utf8.getByteLength());
}
return UTF8String.fromString(value.toString());
case FIXED:
if (value instanceof byte[]) {
return value;
} else if (value instanceof GenericData.Fixed) {
return ((GenericData.Fixed) value).bytes();
}
return ByteBuffers.toByteArray((ByteBuffer) value);
case BINARY:
return ByteBuffers.toByteArray((ByteBuffer) value);
case STRUCT:
StructType structType = (StructType) type;

if (structType.fields().isEmpty()) {
return new GenericInternalRow();
}

List<NestedField> fields = structType.fields();
Object[] values = new Object[fields.size()];
StructLike struct = (StructLike) value;

for (int index = 0; index < fields.size(); index++) {
NestedField field = fields.get(index);
Type fieldType = field.type();
values[index] =
convertConstant(fieldType, struct.get(index, fieldType.typeId().javaClass()));
}

return new GenericInternalRow(values);
default:
}
return value;
}

protected class SparkDeleteFilter extends DeleteFilter<InternalRow> {
private final InternalRowWrapper asStructLike;

@@ -42,6 +42,7 @@
import java.util.stream.Collectors;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;
import org.apache.avro.Schema.Field;
import org.apache.avro.generic.GenericData;
import org.apache.avro.generic.GenericData.Record;
import org.apache.iceberg.DataFile;
@@ -246,7 +247,7 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)
assertThat(expected).as("Should expect a Collection").isInstanceOf(Collection.class);
assertThat(actual).as("Should be a Seq").isInstanceOf(Seq.class);
List<?> asList = seqAsJavaListConverter((Seq<?>) actual).asJava();
assertEqualsSafe(type.asNestedType().asListType(), (Collection) expected, asList);
assertEqualsSafe(type.asNestedType().asListType(), (Collection<?>) expected, asList);
break;
case MAP:
assertThat(expected).as("Should expect a Collection").isInstanceOf(Map.class);
@@ -263,11 +264,20 @@ private static void assertEqualsSafe(Type type, Object expected, Object actual)

public static void assertEqualsUnsafe(Types.StructType struct, Record rec, InternalRow row) {
List<Types.NestedField> fields = struct.fields();
for (int i = 0; i < fields.size(); i += 1) {
Type fieldType = fields.get(i).type();

Object expectedValue = rec.get(i);
Object actualValue = row.isNullAt(i) ? null : row.get(i, convert(fieldType));
for (int readPos = 0; readPos < fields.size(); readPos += 1) {
Types.NestedField field = fields.get(readPos);
Field writeField = rec.getSchema().getField(field.name());

Type fieldType = field.type();
Object actualValue = row.isNullAt(readPos) ? null : row.get(readPos, convert(fieldType));

Object expectedValue;
if (writeField != null) {
int writePos = writeField.pos();
expectedValue = rec.get(writePos);
} else {
expectedValue = field.initialDefault();
}

assertEqualsUnsafe(fieldType, expectedValue, actualValue);
}