(scio-smb) Support mixed FileOperations per BucketedInput (#5064)

clairemcginty authored Jan 8, 2024
1 parent 07745c4 commit f77249a

Showing 20 changed files with 459 additions and 205 deletions.
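
At a glance: each directory of a BucketedInput now carries its own filename suffix and FileOperations, so a single sorted-bucket read can span, say, an Avro-written partition and a Parquet-written one, provided their BucketMetadata declare each other compatible (see compatibleMetadataTypes() below). A minimal sketch of the new per-directory mapping that BucketMetadataUtil consumes; the paths are placeholders, and the AvroFileOperations.of(schema) factory shape is an assumption not confirmed by this diff:

import java.util.HashMap;
import java.util.Map;
import org.apache.avro.Schema;
import org.apache.avro.SchemaBuilder;
import org.apache.avro.generic.GenericRecord;
import org.apache.beam.sdk.extensions.smb.AvroFileOperations;
import org.apache.beam.sdk.extensions.smb.FileOperations;
import org.apache.beam.sdk.extensions.smb.ParquetAvroFileOperations;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.values.KV;
import org.apache.hadoop.conf.Configuration;
import org.apache.parquet.hadoop.metadata.CompressionCodecName;

class MixedInputSketch {
  // Sketch only: one logical input, two directories, two file formats.
  static Map<ResourceId, KV<String, FileOperations<GenericRecord>>> mixedDirectories() {
    Schema schema =
        SchemaBuilder.record("Event").fields().requiredString("userId").endRecord();
    Map<ResourceId, KV<String, FileOperations<GenericRecord>>> dirs = new HashMap<>();
    dirs.put(
        FileSystems.matchNewResource("gs://my-bucket/events-avro", true), // placeholder path
        KV.of(".avro", AvroFileOperations.of(schema))); // assumed factory shape
    dirs.put(
        FileSystems.matchNewResource("gs://my-bucket/events-parquet", true), // placeholder path
        KV.of(
            ".parquet",
            ParquetAvroFileOperations.of(schema, CompressionCodecName.ZSTD, new Configuration())));
    return dirs; // the shape consumed by BucketMetadataUtil#getPrimaryKeyedSourceMetadata below
  }
}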

AvroBucketMetadata.java
@@ -26,6 +26,8 @@
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import java.util.Map;
+import java.util.Objects;
+import java.util.Set;
 import java.util.concurrent.atomic.AtomicReference;
 import javax.annotation.Nullable;
 import org.apache.avro.Schema;
@@ -34,6 +36,7 @@
 import org.apache.beam.sdk.coders.Coder;
 import org.apache.beam.sdk.transforms.display.DisplayData;
 import org.apache.beam.sdk.transforms.display.DisplayData.Builder;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
 
 /**
  * {@link org.apache.beam.sdk.extensions.smb.BucketMetadata} for Avro {@link IndexedRecord} records.
@@ -131,6 +134,21 @@ public Map<Class<?>, Coder<?>> coderOverrides() {
     return AvroUtils.coderOverrides();
   }
 
+  @Override
+  int hashPrimaryKeyMetadata() {
+    return Objects.hash(keyField, getKeyClass());
+  }
+
+  @Override
+  int hashSecondaryKeyMetadata() {
+    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
+  }
+
+  @Override
+  public Set<Class<? extends BucketMetadata>> compatibleMetadataTypes() {
+    return ImmutableSet.of(ParquetBucketMetadata.class);
+  }
+
   @Override
   public K1 extractKeyPrimary(V value) {
     int[] path = keyPath.get();
@@ -175,28 +193,4 @@ public void populateDisplayData(Builder builder) {
     if (keyFieldSecondary != null)
       builder.add(DisplayData.item("keyFieldSecondary", keyFieldSecondary));
   }
-
-  @Override
-  public boolean isPartitionCompatibleForPrimaryKey(BucketMetadata o) {
-    if (o == null || getClass() != o.getClass()) return false;
-    AvroBucketMetadata<?, ?, ?> that = (AvroBucketMetadata<?, ?, ?>) o;
-    return getKeyClass() == that.getKeyClass() && keyField.equals(that.keyField);
-  }
-
-  @Override
-  public boolean isPartitionCompatibleForPrimaryAndSecondaryKey(BucketMetadata o) {
-    if (o == null || getClass() != o.getClass()) return false;
-    AvroBucketMetadata<?, ?, ?> that = (AvroBucketMetadata<?, ?, ?>) o;
-    boolean allSecondaryPresent =
-        getKeyClassSecondary() != null
-            && that.getKeyClassSecondary() != null
-            && keyFieldSecondary != null
-            && that.keyFieldSecondary != null;
-    // you messed up
-    if (!allSecondaryPresent) return false;
-    return getKeyClass() == that.getKeyClass()
-        && getKeyClassSecondary() == that.getKeyClassSecondary()
-        && keyField.equals(that.keyField)
-        && keyFieldSecondary.equals(that.keyFieldSecondary);
-  }
 }
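
Note on the hash-based replacement above: the hash is computed from the key field and key class only, so it is independent of the metadata class declaring it. That is what lets an Avro-keyed source and a Parquet-keyed source agree; a tiny illustration (the "userId" field is hypothetical, and the Parquet side is assumed analogous, per compatibleMetadataTypes() pointing at ParquetBucketMetadata):

import java.util.Objects;

class KeyHashDemo {
  public static void main(String[] args) {
    // AvroBucketMetadata.hashPrimaryKeyMetadata() reduces to this:
    int avroSide = Objects.hash("userId", String.class);
    // ...and an equally keyed ParquetBucketMetadata is assumed to produce the same value:
    int parquetSide = Objects.hash("userId", String.class);
    System.out.println(avroSide == parquetSide); // true
  }
}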

AvroSortedBucketIO.java
@@ -36,6 +36,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 
 /** API for reading and writing Avro sorted-bucket files. */
 public class AvroSortedBucketIO {
@@ -191,6 +192,11 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> transformOutput(
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
+    abstract String getFilenameSuffix();
+
     @Nullable
     abstract Schema getSchema();
 
@@ -242,7 +248,7 @@ public Read<T> withPredicate(Predicate<T> predicate) {
     }
 
     @Override
-    protected SortedBucketSource.BucketedInput<T> toBucketedInput(
+    public SortedBucketSource.BucketedInput<T> toBucketedInput(
         final SortedBucketSource.Keying keying) {
       @SuppressWarnings("unchecked")
      final AvroFileOperations<T> fileOperations =

BucketMetadata.java
@@ -290,14 +290,47 @@ <K> byte[] encodeKeyBytes(K key, Coder<K> coder) {
   }
 
   // Checks for complete equality between BucketMetadatas originating from the same BucketedInput
-  public abstract boolean isPartitionCompatibleForPrimaryKey(BucketMetadata other);
+  public boolean isPartitionCompatibleForPrimaryKey(BucketMetadata other) {
+    return isIntraPartitionCompatibleWith(other, false);
+  }
+
+  public boolean isPartitionCompatibleForPrimaryAndSecondaryKey(BucketMetadata other) {
+    return isIntraPartitionCompatibleWith(other, true);
+  }
 
-  public abstract boolean isPartitionCompatibleForPrimaryAndSecondaryKey(BucketMetadata other);
+  private <MetadataT extends BucketMetadata> boolean isIntraPartitionCompatibleWith(
+      MetadataT other, boolean checkSecondaryKeys) {
+    if (other == null) {
+      return false;
+    }
+    final Class<? extends BucketMetadata> otherClass = other.getClass();
+    final Set<Class<? extends BucketMetadata>> compatibleTypes = compatibleMetadataTypes();
+
+    if (compatibleTypes.isEmpty() && other.getClass() != this.getClass()) {
+      return false;
+    } else if (this.getKeyClass() != other.getKeyClass()
+        && !(compatibleTypes.contains(otherClass)
+            && (other.compatibleMetadataTypes().contains(this.getClass())))) {
+      return false;
+    }
+
+    return (this.hashPrimaryKeyMetadata() == other.hashPrimaryKeyMetadata()
+        && (!checkSecondaryKeys
+            || this.hashSecondaryKeyMetadata() == other.hashSecondaryKeyMetadata()));
+  }
+
+  public Set<Class<? extends BucketMetadata>> compatibleMetadataTypes() {
+    return new HashSet<>();
+  }
 
   public abstract K1 extractKeyPrimary(V value);
 
   public abstract K2 extractKeySecondary(V value);
 
+  abstract int hashPrimaryKeyMetadata();
+
+  abstract int hashSecondaryKeyMetadata();
+
   public SortedBucketIO.ComparableKeyBytes primaryComparableKeyBytes(V value) {
     return new SortedBucketIO.ComparableKeyBytes(getKeyBytesPrimary(value), null);
   }
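
The rule implemented by isIntraPartitionCompatibleWith: metadata of a different class is considered only when both sides list each other in compatibleMetadataTypes(), and agreement then rests on the key-metadata hashes. Below is a stripped-down, self-contained model of the primary-key path (class and field names are hypothetical, and the real check also compares key classes directly):

import java.util.Objects;
import java.util.Set;

abstract class Meta {
  final String keyField;

  Meta(String keyField) {
    this.keyField = keyField;
  }

  Set<Class<? extends Meta>> compatibleTypes() {
    return Set.of();
  }

  int keyHash() {
    return Objects.hash(keyField, String.class);
  }

  // Same shape as isIntraPartitionCompatibleWith(other, false) above, simplified.
  boolean compatibleWith(Meta other) {
    if (other == null) return false;
    boolean sameClass = getClass() == other.getClass();
    boolean mutuallyDeclared =
        compatibleTypes().contains(other.getClass())
            && other.compatibleTypes().contains(getClass());
    if (!sameClass && !mutuallyDeclared) return false;
    return keyHash() == other.keyHash();
  }
}

class AvroMeta extends Meta {
  AvroMeta(String keyField) {
    super(keyField);
  }

  @Override
  Set<Class<? extends Meta>> compatibleTypes() {
    return Set.of(ParquetMeta.class); // mirrors AvroBucketMetadata -> ParquetBucketMetadata
  }
}

class ParquetMeta extends Meta {
  ParquetMeta(String keyField) {
    super(keyField);
  }

  @Override
  Set<Class<? extends Meta>> compatibleTypes() {
    return Set.of(AvroMeta.class); // assumed symmetric declaration on the Parquet side
  }
}

class CompatibilityDemo {
  public static void main(String[] args) {
    System.out.println(new AvroMeta("userId").compatibleWith(new ParquetMeta("userId"))); // true
    System.out.println(new AvroMeta("userId").compatibleWith(new ParquetMeta("eventId"))); // false
  }
}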

BucketMetadataUtil.java
@@ -23,8 +23,8 @@
 import java.util.concurrent.ConcurrentHashMap;
 import java.util.function.BiFunction;
 import org.apache.beam.sdk.extensions.smb.SMBFilenamePolicy.FileAssignment;
-import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
+import org.apache.beam.sdk.values.KV;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.annotations.VisibleForTesting;
 import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;
 
@@ -69,15 +69,14 @@ int leastNumBuckets() {
     this.batchSize = batchSize;
   }
 
-  private <V> Map<ResourceId, BucketMetadata<?, ?, V>> fetchMetadata(List<String> directories) {
+  private <V> Map<ResourceId, BucketMetadata<?, ?, V>> fetchMetadata(List<ResourceId> directories) {
     final int total = directories.size();
     final Map<ResourceId, BucketMetadata<?, ?, V>> metadata = new ConcurrentHashMap<>();
     int start = 0;
     while (start < total) {
       directories.stream()
           .skip(start)
           .limit(batchSize)
-          .map(dir -> FileSystems.matchNewResource(dir, true))
           .parallel()
           .forEach(dir -> metadata.put(dir, BucketMetadata.get(dir)));
       start += batchSize;
@@ -86,11 +85,11 @@ int leastNumBuckets() {
   }
 
   private <V> SourceMetadata<V> getSourceMetadata(
-      List<String> directories,
-      String filenameSuffix,
+      Map<ResourceId, KV<String, FileOperations<V>>> directories,
       BiFunction<BucketMetadata<?, ?, V>, BucketMetadata<?, ?, V>, Boolean>
           compatibilityCompareFn) {
-    final Map<ResourceId, BucketMetadata<?, ?, V>> bucketMetadatas = fetchMetadata(directories);
+    final Map<ResourceId, BucketMetadata<?, ?, V>> bucketMetadatas =
+        fetchMetadata(new ArrayList<>(directories.keySet()));
     Preconditions.checkState(!bucketMetadatas.isEmpty(), "Failed to find metadata");
 
     Map<ResourceId, SourceMetadataValue<V>> mapping = new HashMap<>();
@@ -107,24 +106,22 @@ private <V> SourceMetadata<V> getSourceMetadata(
               metadata,
               first.getValue());
           final FileAssignment fileAssignment =
-              new SMBFilenamePolicy(dir, metadata.getFilenamePrefix(), filenameSuffix)
+              new SMBFilenamePolicy(
+                      dir, metadata.getFilenamePrefix(), directories.get(dir).getKey())
                   .forDestination();
           mapping.put(dir, new SourceMetadataValue<>(metadata, fileAssignment));
         });
     return new SourceMetadata<>(mapping);
   }
 
   public <V> SourceMetadata<V> getPrimaryKeyedSourceMetadata(
-      List<String> directories, String filenameSuffix) {
-    return getSourceMetadata(
-        directories, filenameSuffix, BucketMetadata::isPartitionCompatibleForPrimaryKey);
+      Map<ResourceId, KV<String, FileOperations<V>>> directories) {
+    return getSourceMetadata(directories, BucketMetadata::isPartitionCompatibleForPrimaryKey);
   }
 
   public <V> SourceMetadata<V> getPrimaryAndSecondaryKeyedSourceMetadata(
-      List<String> directories, String filenameSuffix) {
+      Map<ResourceId, KV<String, FileOperations<V>>> directories) {
     return getSourceMetadata(
-        directories,
-        filenameSuffix,
-        BucketMetadata::isPartitionCompatibleForPrimaryAndSecondaryKey);
+        directories, BucketMetadata::isPartitionCompatibleForPrimaryAndSecondaryKey);
   }
 }
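
fetchMetadata keeps its batched, parallel shape: each round fetches at most batchSize directories through a parallel stream, bounding concurrent filesystem reads. The same pattern in isolation (fetchOne is a hypothetical stand-in for BucketMetadata.get):

import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

class BatchedFetchDemo {
  // Stand-in for BucketMetadata.get(dir): any per-directory remote read.
  static String fetchOne(String dir) {
    return dir + "/metadata.json";
  }

  public static void main(String[] args) {
    List<String> dirs = List.of("a", "b", "c", "d", "e");
    int batchSize = 2; // caps how many fetches run concurrently per round
    Map<String, String> metadata = new ConcurrentHashMap<>();
    for (int start = 0; start < dirs.size(); start += batchSize) {
      dirs.stream().skip(start).limit(batchSize).parallel()
          .forEach(d -> metadata.put(d, fetchOne(d)));
    }
    System.out.println(metadata); // all five entries, fetched two at a time
  }
}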

JsonBucketMetadata.java
@@ -26,7 +26,7 @@
 import com.fasterxml.jackson.annotation.JsonInclude;
 import com.fasterxml.jackson.annotation.JsonProperty;
 import com.google.api.services.bigquery.model.TableRow;
-import java.util.Arrays;
+import java.util.Objects;
 import javax.annotation.Nullable;
 import org.apache.beam.sdk.coders.CannotProvideCoderException;
 import org.apache.beam.sdk.transforms.display.DisplayData;
@@ -148,32 +148,12 @@ public void populateDisplayData(Builder builder) {
   }
 
   @Override
-  public boolean isPartitionCompatibleForPrimaryKey(final BucketMetadata o) {
-    if (o == null || getClass() != o.getClass()) return false;
-    JsonBucketMetadata<?, ?> that = (JsonBucketMetadata<?, ?>) o;
-    return getKeyClass() == that.getKeyClass()
-        && keyField.equals(that.keyField)
-        && Arrays.equals(keyPath, that.keyPath);
+  int hashPrimaryKeyMetadata() {
+    return Objects.hash(keyField, getKeyClass());
   }
 
   @Override
-  public boolean isPartitionCompatibleForPrimaryAndSecondaryKey(final BucketMetadata o) {
-    if (o == null || getClass() != o.getClass()) return false;
-    JsonBucketMetadata<?, ?> that = (JsonBucketMetadata<?, ?>) o;
-    boolean allSecondaryPresent =
-        getKeyClassSecondary() != null
-            && that.getKeyClassSecondary() != null
-            && keyFieldSecondary != null
-            && that.keyFieldSecondary != null
-            && keyPathSecondary != null
-            && that.keyPathSecondary != null;
-    // you messed up
-    if (!allSecondaryPresent) return false;
-    return getKeyClass() == that.getKeyClass()
-        && getKeyClassSecondary() == that.getKeyClassSecondary()
-        && keyField.equals(that.keyField)
-        && keyFieldSecondary.equals(that.keyFieldSecondary)
-        && Arrays.equals(keyPath, that.keyPath)
-        && Arrays.equals(keyPathSecondary, that.keyPathSecondary);
+  int hashSecondaryKeyMetadata() {
+    return Objects.hash(keyFieldSecondary, getKeyClassSecondary());
   }
 }

JsonSortedBucketIO.java
@@ -33,6 +33,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 
 /** API for reading and writing BigQuery {@link TableRow} JSON sorted-bucket files. */
 public class JsonSortedBucketIO {
@@ -107,6 +108,10 @@ public static <K1, K2> TransformOutput<K1, K2> transformOutput(
    */
   @AutoValue
   public abstract static class Read extends SortedBucketIO.Read<TableRow> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
+    abstract String getFilenameSuffix();
 
     abstract Compression getCompression();
 
@@ -152,7 +157,7 @@ public Read withPredicate(Predicate<TableRow> predicate) {
     }
 
     @Override
-    protected BucketedInput<TableRow> toBucketedInput(final SortedBucketSource.Keying keying) {
+    public BucketedInput<TableRow> toBucketedInput(final SortedBucketSource.Keying keying) {
       return BucketedInput.of(
           keying,
          getTupleTag(),

ParquetAvroFileOperations.java
@@ -52,18 +52,22 @@
  */
 public class ParquetAvroFileOperations<ValueT> extends FileOperations<ValueT> {
   static final CompressionCodecName DEFAULT_COMPRESSION = CompressionCodecName.ZSTD;
+
+  private final Class<ValueT> recordClass;
   private final SerializableSchemaSupplier schemaSupplier;
   private final CompressionCodecName compression;
   private final SerializableConfiguration conf;
   private final FilterPredicate predicate;
 
   private ParquetAvroFileOperations(
       Schema schema,
+      Class<ValueT> recordClass,
       CompressionCodecName compression,
       Configuration conf,
       FilterPredicate predicate) {
     super(Compression.UNCOMPRESSED, MimeTypes.BINARY);
     this.schemaSupplier = new SerializableSchemaSupplier(schema);
+    this.recordClass = recordClass;
     this.compression = compression;
     this.conf = new SerializableConfiguration(conf);
     this.predicate = predicate;
@@ -80,7 +84,7 @@ public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
 
   public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
       Schema schema, CompressionCodecName compression, Configuration conf) {
-    return new ParquetAvroFileOperations<>(schema, compression, conf, null);
+    return new ParquetAvroFileOperations<>(schema, null, compression, conf, null);
   }
 
   public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
@@ -90,7 +94,7 @@ public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
 
   public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
       Schema schema, FilterPredicate predicate, Configuration conf) {
-    return new ParquetAvroFileOperations<>(schema, DEFAULT_COMPRESSION, conf, predicate);
+    return new ParquetAvroFileOperations<>(schema, null, DEFAULT_COMPRESSION, conf, predicate);
   }
 
   public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(Class<V> recordClass) {
@@ -106,7 +110,7 @@ public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
       Class<V> recordClass, CompressionCodecName compression, Configuration conf) {
     // Use reflection to get SR schema
     final Schema schema = new ReflectData(recordClass.getClassLoader()).getSchema(recordClass);
-    return new ParquetAvroFileOperations<>(schema, compression, conf, null);
+    return new ParquetAvroFileOperations<>(schema, recordClass, compression, conf, null);
   }
 
   public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
@@ -118,7 +122,8 @@ public static <V extends IndexedRecord> ParquetAvroFileOperations<V> of(
       Class<V> recordClass, FilterPredicate predicate, Configuration conf) {
     // Use reflection to get SR schema
     final Schema schema = new ReflectData(recordClass.getClassLoader()).getSchema(recordClass);
-    return new ParquetAvroFileOperations<>(schema, DEFAULT_COMPRESSION, conf, predicate);
+    return new ParquetAvroFileOperations<>(
+        schema, recordClass, DEFAULT_COMPRESSION, conf, predicate);
   }
 
   @Override
@@ -141,7 +146,9 @@ protected FileIO.Sink<ValueT> createSink() {
   @SuppressWarnings("unchecked")
   @Override
   public Coder<ValueT> getCoder() {
-    return (AvroCoder<ValueT>) AvroCoder.of(getSchema());
+    return recordClass == null
+        ? (AvroCoder<ValueT>) AvroCoder.of(getSchema())
+        : AvroCoder.reflect(recordClass);
   }
 
   Schema getSchema() {
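
Why getCoder() branches: AvroCoder.of(schema) yields a GenericRecord coder, which is wrong for operations constructed from a POJO class via ReflectData; with a recordClass present, AvroCoder.reflect round-trips the POJO type itself. A hedged illustration (UserEvent is a hypothetical POJO, and AvroCoder's package differs across Beam versions):

// Older Beam: org.apache.beam.sdk.coders.AvroCoder
// Newer Beam: org.apache.beam.sdk.extensions.avro.coders.AvroCoder
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.avro.coders.AvroCoder;

class UserEvent { // hypothetical reflect-mapped POJO
  String userId;
  long timestamp;
}

class ReflectCoderDemo {
  public static void main(String[] args) {
    Coder<UserEvent> coder = AvroCoder.reflect(UserEvent.class);
    System.out.println(coder); // encodes and decodes UserEvent directly, not GenericRecord
  }
}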

ParquetAvroSortedBucketIO.java
@@ -36,6 +36,7 @@
 import org.apache.beam.sdk.io.FileSystems;
 import org.apache.beam.sdk.io.fs.ResourceId;
 import org.apache.beam.sdk.values.TupleTag;
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
 import org.apache.hadoop.conf.Configuration;
 import org.apache.parquet.filter2.predicate.FilterPredicate;
 import org.apache.parquet.hadoop.metadata.CompressionCodecName;
@@ -199,6 +200,11 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> transformOutput(
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
+    @Nullable
+    abstract ImmutableList<String> getInputDirectories();
+
+    abstract String getFilenameSuffix();
+
     @Nullable
     abstract Schema getSchema();
 
@@ -269,7 +275,7 @@ public Read<T> withConfiguration(Configuration configuration) {
     }
 
     @Override
-    protected BucketedInput<T> toBucketedInput(final SortedBucketSource.Keying keying) {
+    public BucketedInput<T> toBucketedInput(final SortedBucketSource.Keying keying) {
       final Schema schema =
           getRecordClass() == null
              ? getSchema()