From e2d7ae80f90553efb4f21ccfc89ddb3d75ce5146 Mon Sep 17 00:00:00 2001
From: Claire McGinty
Date: Wed, 6 Mar 2024 08:34:25 -0500
Subject: [PATCH] Fix for Parquet-Avro

---
Notes (ignored by git am):
  * ParquetBucketMetadata now overrides keyClassMatches and
    keyClassSecondaryMatches: a requested read type of String is accepted
    when the stored (primary / secondary) key class is CharSequence; every
    other combination is delegated to the superclass check.
  * SmbVersionParityTest's write/read round trip is factored into a
    testRoundtrip helper and exercised for both AvroSortedBucketIO and
    ParquetAvroSortedBucketIO.

 .../extensions/smb/ParquetBucketMetadata.java | 18 ++++++
 .../scio/smb/SmbVersionParityTest.scala       | 61 +++++++++++++------
 2 files changed, 61 insertions(+), 18 deletions(-)

diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
index 4fcd701704..2016c1258c 100644
--- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
+++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetBucketMetadata.java
@@ -228,6 +228,24 @@ static K extractKey(Method[] keyGetters, Object value) {
     return key;
   }
 
+  @Override
+  boolean keyClassMatches(Class requestedReadType) {
+    if (requestedReadType == String.class && getKeyClass() == CharSequence.class) {
+      return true;
+    } else {
+      return super.keyClassMatches(requestedReadType);
+    }
+  }
+
+  @Override
+  boolean keyClassSecondaryMatches(Class requestedReadType) {
+    if (requestedReadType == String.class && getKeyClassSecondary() == CharSequence.class) {
+      return true;
+    } else {
+      return super.keyClassSecondaryMatches(requestedReadType);
+    }
+  }
+
   ////////////////////////////////////////////////////////////////////////////////
   // Logic for dealing with Avro records vs Scala case classes
   ////////////////////////////////////////////////////////////////////////////////
diff --git a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
index 2849cf5f87..83c311de6a 100644
--- a/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
+++ b/scio-smb/src/test/scala/com/spotify/scio/smb/SmbVersionParityTest.scala
@@ -3,16 +3,20 @@ package com.spotify.scio.smb
 import com.spotify.scio.avro.{Account, AccountStatus}
 import com.spotify.scio.ScioContext
 import com.spotify.scio.testing.PipelineSpec
-import org.apache.beam.sdk.extensions.smb.AvroSortedBucketIO
+import org.apache.beam.sdk.extensions.smb.{
+  AvroSortedBucketIO,
+  ParquetAvroSortedBucketIO,
+  SortedBucketIO
+}
 import org.apache.beam.sdk.values.TupleTag
 
 import java.nio.file.Files
 
 class SmbVersionParityTest extends PipelineSpec {
-  "SortedBucketSource" should "be able to read CharSequence-keyed sources written before 0.14" in {
-    val output = Files.createTempDirectory("smb-version-test").toFile
-    output.deleteOnExit()
-
+  private def testRoundtrip(
+    write: SortedBucketIO.Write[CharSequence, _, Account],
+    read: SortedBucketIO.Read[Account]
+  ): Unit = {
     val accounts = (1 to 10).map { i =>
       Account
         .newBuilder()
@@ -27,25 +31,14 @@ class SmbVersionParityTest extends PipelineSpec {
     {
       val sc = ScioContext()
       sc.parallelize(accounts)
-        .saveAsSortedBucket(
-          AvroSortedBucketIO
-            .write(classOf[CharSequence], "name", classOf[Account])
-            .to(output.getAbsolutePath)
-            .withNumBuckets(1)
-            .withNumShards(1)
-        )
+        .saveAsSortedBucket(write)
       sc.run()
     }
 
     // Read data
     val sc = ScioContext()
     val tap = sc
-      .sortMergeGroupByKey(
-        classOf[String],
-        AvroSortedBucketIO
-          .read(new TupleTag[Account], classOf[Account])
-          .from(output.getAbsolutePath)
-      )
+      .sortMergeGroupByKey(classOf[String], read)
       .materialize
     tap
       .get(sc.run().waitUntilDone())
@@ -53,4 +46,36 @@ class SmbVersionParityTest extends PipelineSpec {
       .value
       .toSeq should contain theSameElementsAs accounts
   }
+
+  "SortedBucketSource" should "be able to read CharSequence-keyed Avro sources written before 0.14" in {
+    val output = Files.createTempDirectory("smb-version-test-avro").toFile
+    output.deleteOnExit()
+
+    testRoundtrip(
+      AvroSortedBucketIO
+        .write(classOf[CharSequence], "name", classOf[Account])
+        .to(output.getAbsolutePath)
+        .withNumBuckets(1)
+        .withNumShards(1),
+      AvroSortedBucketIO
+        .read(new TupleTag[Account], classOf[Account])
+        .from(output.getAbsolutePath)
+    )
+  }
+
+  it should "be able to read CharSequence-keyed Parquet sources written before 0.14" in {
+    val output = Files.createTempDirectory("smb-version-test-parquet").toFile
+    output.deleteOnExit()
+
+    testRoundtrip(
+      ParquetAvroSortedBucketIO
+        .write(classOf[CharSequence], "name", classOf[Account])
+        .to(output.getAbsolutePath)
+        .withNumBuckets(1)
+        .withNumShards(1),
+      ParquetAvroSortedBucketIO
+        .read(new TupleTag[Account], classOf[Account])
+        .from(output.getAbsolutePath)
+    )
+  }
 }