From 434d5588bfe9ff948a06c4945d4913059629a2cb Mon Sep 17 00:00:00 2001 From: Michel Davit Date: Mon, 28 Aug 2023 16:14:29 +0200 Subject: [PATCH] TestIO proposal for SMB --- build.sbt | 2 +- .../scala/com/spotify/scio/io/ScioIO.scala | 4 ++ .../extensions/smb/AvroSortedBucketIO.java | 5 -- .../extensions/smb/JsonSortedBucketIO.java | 4 -- .../smb/ParquetAvroSortedBucketIO.java | 5 -- .../sdk/extensions/smb/SortedBucketIO.java | 7 ++ .../extensions/smb/TensorFlowBucketIO.java | 4 -- .../scala/com/spotify/scio/smb/SMBIO.scala | 8 +++ .../SortMergeBucketScioContextSyntax.scala | 65 +++++++++++-------- .../smb/ParquetTypeSortedBucketIO.scala | 4 ++ .../scio/smb/SortMergeBucketTest.scala | 56 ++++++++++++++++ .../com/spotify/scio/testing/JobTest.scala | 8 ++- 12 files changed, 125 insertions(+), 47 deletions(-) create mode 100644 scio-smb/src/main/scala/com/spotify/scio/smb/SMBIO.scala create mode 100644 scio-smb/src/test/scala/com/spotify/scio/smb/SortMergeBucketTest.scala diff --git a/build.sbt b/build.sbt index ef62dbba51..2d22a94acd 100644 --- a/build.sbt +++ b/build.sbt @@ -1322,7 +1322,7 @@ lazy val `scio-smb`: Project = project .in(file("scio-smb")) .dependsOn( `scio-core`, - `scio-test` % "test;it", + `scio-test` % "test->test;it", `scio-avro` % IntegrationTest ) .configs(IntegrationTest) diff --git a/scio-core/src/main/scala/com/spotify/scio/io/ScioIO.scala b/scio-core/src/main/scala/com/spotify/scio/io/ScioIO.scala index e91104cdc5..18d2062f6f 100644 --- a/scio-core/src/main/scala/com/spotify/scio/io/ScioIO.scala +++ b/scio-core/src/main/scala/com/spotify/scio/io/ScioIO.scala @@ -159,6 +159,10 @@ trait TestIO[T] extends ScioIO[T] { throw new UnsupportedOperationException(s"$this is for testing purpose only") } +trait KeyedIO[K, T] extends ScioIO[(K, T)] { + def keyBy: T => K +} + /** * Special version of [[ScioIO]] for use with [[ScioContext.customInput]] and * [[SCollection.saveAsCustomOutput]]. diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java index a35c704128..18e79fb932 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java @@ -192,11 +192,6 @@ public static TransformOutput tran /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */ @AutoValue public abstract static class Read extends SortedBucketIO.Read { - @Nullable - abstract ImmutableList getInputDirectories(); - - abstract String getFilenameSuffix(); - @Nullable abstract Schema getSchema(); diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java index 443faa1496..49b3da26e6 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java @@ -108,10 +108,6 @@ public static TransformOutput transformOutput( */ @AutoValue public abstract static class Read extends SortedBucketIO.Read { - @Nullable - abstract ImmutableList getInputDirectories(); - - abstract String getFilenameSuffix(); abstract Compression getCompression(); diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java index 688e686d94..e689e60968 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java @@ -200,11 +200,6 @@ public static TransformOutput tran /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */ @AutoValue public abstract static class Read extends SortedBucketIO.Read { - @Nullable - abstract ImmutableList getInputDirectories(); - - abstract String getFilenameSuffix(); - @Nullable abstract Schema getSchema(); diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java index d1bbb584db..06a6dbacbe 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java @@ -493,6 +493,13 @@ public abstract static class TransformOutput implements Serializable /** Represents a single sorted-bucket source written using {@link SortedBucketSink}. */ public abstract static class Read implements Serializable { + + @Nullable + public abstract ImmutableList getInputDirectories(); + + + abstract String getFilenameSuffix(); + public abstract TupleTag getTupleTag(); protected abstract BucketedInput toBucketedInput(SortedBucketSource.Keying keying); diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java index 38e5188951..7e58ea5241 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java @@ -125,10 +125,6 @@ public static TransformOutput transformOutput( */ @AutoValue public abstract static class Read extends SortedBucketIO.Read { - @Nullable - abstract ImmutableList getInputDirectories(); - - abstract String getFilenameSuffix(); abstract Compression getCompression(); diff --git a/scio-smb/src/main/scala/com/spotify/scio/smb/SMBIO.scala b/scio-smb/src/main/scala/com/spotify/scio/smb/SMBIO.scala new file mode 100644 index 0000000000..33d15c55e5 --- /dev/null +++ b/scio-smb/src/main/scala/com/spotify/scio/smb/SMBIO.scala @@ -0,0 +1,8 @@ +package com.spotify.scio.smb + +import com.spotify.scio.io.{KeyedIO, TapOf, TapT, TestIO} + +final case class SMBIO[K, T](id: String, keyBy: T => K) extends KeyedIO[K, T] with TestIO[(K, T)] { + override val tapT: TapT.Aux[(K, T), (K, T)] = TapOf[(K, T)] + override def testId: String = s"SMBIO($id)" +} diff --git a/scio-smb/src/main/scala/com/spotify/scio/smb/syntax/SortMergeBucketScioContextSyntax.scala b/scio-smb/src/main/scala/com/spotify/scio/smb/syntax/SortMergeBucketScioContextSyntax.scala index 4120768c3a..9ff0677dbc 100644 --- a/scio-smb/src/main/scala/com/spotify/scio/smb/syntax/SortMergeBucketScioContextSyntax.scala +++ b/scio-smb/src/main/scala/com/spotify/scio/smb/syntax/SortMergeBucketScioContextSyntax.scala @@ -21,6 +21,8 @@ import com.spotify.scio.ScioContext import com.spotify.scio.annotations.experimental import com.spotify.scio.coders.Coder import com.spotify.scio.io.{ClosedTap, EmptyTap} +import com.spotify.scio.smb.SMBIO +import com.spotify.scio.testing.TestDataManager import com.spotify.scio.values._ import org.apache.beam.sdk.extensions.smb.SortedBucketIO.{AbsCoGbkTransform, Transformable} import org.apache.beam.sdk.extensions.smb.SortedBucketTransform.{BucketItem, MergedBucket} @@ -68,35 +70,44 @@ final class SortedBucketScioContext(@transient private val self: ScioContext) ex lhs: SortedBucketIO.Read[L], rhs: SortedBucketIO.Read[R], targetParallelism: TargetParallelism = TargetParallelism.auto() - ): SCollection[(K, (L, R))] = { - val t = SortedBucketIO.read(keyClass).of(lhs, rhs).withTargetParallelism(targetParallelism) - val (tupleTagA, tupleTagB) = (lhs.getTupleTag, rhs.getTupleTag) - val tfName = self.tfName - - self - .wrap(self.pipeline.apply(s"SMB CoGroupByKey@$tfName", t)) - .withName(tfName) - .applyTransform(ParDo.of(new DoFn[KV[K, CoGbkResult], (K, (L, R))] { - @ProcessElement - private[smb] def processElement( - @Element element: KV[K, CoGbkResult], - out: OutputReceiver[(K, (L, R))] - ): Unit = { - val cgbkResult = element.getValue - val (resA, resB) = (cgbkResult.getAll(tupleTagA), cgbkResult.getAll(tupleTagB)) - val itB = resB.iterator() - val key = element.getKey - - while (itB.hasNext) { - val b = itB.next() - val ai = resA.iterator() - while (ai.hasNext) { - val a = ai.next() - out.output((key, (a, b))) + ): SCollection[(K, (L, R))] = self.requireNotClosed { + if (self.isTest) { + val testInput = TestDataManager.getInput(self.testId.get) + val idLhs = lhs.getInputDirectories.asScala.mkString(",") + val testLhs = testInput[(K, L)](SMBIO(idLhs, null)).toSCollection(self) + val idRhs = rhs.getInputDirectories.asScala.mkString(",") + val testRhs = testInput[(K, R)](SMBIO(idRhs, null)).toSCollection(self) + testLhs.join(testRhs) + } else { + val t = SortedBucketIO.read(keyClass).of(lhs, rhs).withTargetParallelism(targetParallelism) + val (tupleTagA, tupleTagB) = (lhs.getTupleTag, rhs.getTupleTag) + val tfName = self.tfName + + self + .wrap(self.pipeline.apply(s"SMB CoGroupByKey@$tfName", t)) + .withName(tfName) + .applyTransform(ParDo.of(new DoFn[KV[K, CoGbkResult], (K, (L, R))] { + @ProcessElement + private[smb] def processElement( + @Element element: KV[K, CoGbkResult], + out: OutputReceiver[(K, (L, R))] + ): Unit = { + val cgbkResult = element.getValue + val (resA, resB) = (cgbkResult.getAll(tupleTagA), cgbkResult.getAll(tupleTagB)) + val itB = resB.iterator() + val key = element.getKey + + while (itB.hasNext) { + val b = itB.next() + val ai = resA.iterator() + while (ai.hasNext) { + val a = ai.next() + out.output((key, (a, b))) + } } } - } - })) + })) + } } /** Secondary keyed variant. */ diff --git a/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala b/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala index 3c7dc7e1cd..d489936273 100644 --- a/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala +++ b/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala @@ -25,6 +25,7 @@ import org.apache.beam.sdk.extensions.smb.SortedBucketSource.BucketedInput import org.apache.beam.sdk.io.FileSystems import org.apache.beam.sdk.io.fs.ResourceId import org.apache.beam.sdk.values.TupleTag +import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList import org.apache.hadoop.conf.Configuration import org.apache.parquet.filter2.predicate.FilterPredicate import org.apache.parquet.hadoop.metadata.CompressionCodecName @@ -80,6 +81,9 @@ object ParquetTypeSortedBucketIO { def withConfiguration(configuration: Configuration): Read[T] = this.copy(configuration = configuration) + override def getInputDirectories: ImmutableList[String] = + ImmutableList.copyOf(inputDirectories.asJava) + override def getFilenameSuffix: String = filenameSuffix override def getTupleTag: TupleTag[T] = tupleTag override protected def toBucketedInput( diff --git a/scio-smb/src/test/scala/com/spotify/scio/smb/SortMergeBucketTest.scala b/scio-smb/src/test/scala/com/spotify/scio/smb/SortMergeBucketTest.scala new file mode 100644 index 0000000000..8d846f2e27 --- /dev/null +++ b/scio-smb/src/test/scala/com/spotify/scio/smb/SortMergeBucketTest.scala @@ -0,0 +1,56 @@ +package com.spotify.scio.smb + +import com.spotify.scio.ContextAndArgs +import com.spotify.scio.avro.{Account, Address, User} +import com.spotify.scio.io.{CustomIO, TextIO} +import com.spotify.scio.testing.PipelineSpec +import org.apache.beam.sdk.extensions.smb.{AvroSortedBucketIO, TargetParallelism} +import org.apache.beam.sdk.values.TupleTag + +import java.util.Collections + +object SmbJob { + + def main(cmdlineArgs: Array[String]) = { + val (sc, args) = ContextAndArgs(cmdlineArgs) + + sc.sortMergeJoin( + classOf[Integer], + AvroSortedBucketIO + .read(new TupleTag[User]("lhs"), classOf[User]) + .from(args("users")), + AvroSortedBucketIO + .read(new TupleTag[Account]("rhs"), classOf[Account]) + .from(args("accounts")), + TargetParallelism.max() + ).values + .map { case (u, a) => s"${u.getLastName}=${a.getAmount}" } + .saveAsTextFile(args("output")) + + sc.run().waitUntilDone() + } + +} + +class SortMergeBucketTest extends PipelineSpec { + + "SMB" should "be able to mock input and output" in { + val account: Account = new Account(1, "type", "name", 12.5, null) + val address = new Address("street1", "street2", "city", "state", "01234", "Sweden") + val user = + new User(1, "lastname", "firstname", "email@foobar.com", Collections.emptyList(), address) + + JobTest[SmbJob.type] + .args( + "--users=users", + "--accounts=accounts", + "--output=output" + ) + .keyedInput(SMBIO[Integer, User]("users", _.getId), Seq(user)) + // input is also possible but error prone as key must be given manually + .input(SMBIO[Integer, Account]("accounts", _.getId), Seq(account.getId -> account)) + .output(TextIO("output"))(_ should containInAnyOrder(Seq("lastname=12.5"))) + .run() + } + +} diff --git a/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala b/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala index a3efce8875..2efde16008 100644 --- a/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala +++ b/scio-test/src/main/scala/com/spotify/scio/testing/JobTest.scala @@ -19,7 +19,7 @@ package com.spotify.scio.testing import java.lang.reflect.InvocationTargetException import com.spotify.scio.ScioResult -import com.spotify.scio.io.ScioIO +import com.spotify.scio.io.{KeyedIO, ScioIO} import com.spotify.scio.util.ScioUtil import com.spotify.scio.values.SCollection import com.spotify.scio.coders.Coder @@ -115,6 +115,9 @@ object JobTest { def input[T: Coder](io: ScioIO[T], value: Iterable[T]): Builder = input(io, IterableInputSource(value)) + def keyedInput[K: Coder, T: Coder](io: KeyedIO[K, T], value: Iterable[T]): Builder = + input(io, IterableInputSource(value.map(x => io.keyBy(x) -> x))) + /** * Feed an input in the form of a `PTransform[PBegin, PCollection[T]]` to the pipeline being * tested. Note that `PTransform` inputs may not be supported for all `TestIO[T]` types. @@ -122,6 +125,9 @@ object JobTest { def inputStream[T: Coder](io: ScioIO[T], stream: TestStream[T]): Builder = input(io, TestStreamInputSource(stream)) + def keyedInputStream[K: Coder, T: Coder](io: KeyedIO[K, T], stream: Iterable[T]): Builder = + input(io, IterableInputSource(stream.map(x => io.keyBy(x) -> x))) + private def input[T](io: ScioIO[T], value: JobInputSource[T]): Builder = { require(!state.input.contains(io.testId), "Duplicate test input: " + io.testId) state = state.copy(input = state.input + (io.testId -> value))