TestIO proposal for SMB
RustedBones committed Aug 28, 2023
1 parent 2a9c041 commit 434d558
Showing 12 changed files with 125 additions and 47 deletions.
2 changes: 1 addition & 1 deletion build.sbt
@@ -1322,7 +1322,7 @@ lazy val `scio-smb`: Project = project
   .in(file("scio-smb"))
   .dependsOn(
     `scio-core`,
-    `scio-test` % "test;it",
+    `scio-test` % "test->test;it",
     `scio-avro` % IntegrationTest
   )
   .configs(IntegrationTest)
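The dependency change above is what makes the new testing hooks reachable from scio-smb's tests: in sbt, `test;it` maps scio-test's default (compile) configuration onto this project's test and it configurations, while `test->test` additionally exposes scio-test's own test-scoped classes. A minimal sketch of the pattern, with hypothetical module names:

    // build.sbt sketch (hypothetical modules): make util's test classes
    // visible to app's tests; "it" alone still maps to util's compile scope.
    lazy val util = project
    lazy val app = project
      .configs(IntegrationTest)
      .settings(Defaults.itSettings)
      .dependsOn(util % "test->test;it")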
4 changes: 4 additions & 0 deletions scio-core/src/main/scala/com/spotify/scio/io/ScioIO.scala
@@ -159,6 +159,10 @@ trait TestIO[T] extends ScioIO[T] {
   throw new UnsupportedOperationException(s"$this is for testing purpose only")
 }
 
+trait KeyedIO[K, T] extends ScioIO[(K, T)] {
+  def keyBy: T => K
+}
+
 /**
  * Special version of [[ScioIO]] for use with [[ScioContext.customInput]] and
  * [[SCollection.saveAsCustomOutput]].
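The new KeyedIO contract is the heart of the proposal: a test harness can derive the key for each mocked element itself instead of asking callers for pre-keyed pairs. A sketch of that derivation, assuming only the trait above (JobTest.keyedInput further down does exactly this):

    import com.spotify.scio.io.KeyedIO

    // Sketch: turn plain test values into the (key, value) pairs a keyed IO
    // produces, using the IO's own keyBy function.
    def toKeyedPairs[K, T](io: KeyedIO[K, T], values: Iterable[T]): Iterable[(K, T)] =
      values.map(x => io.keyBy(x) -> x)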
@@ -192,11 +192,6 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> tran
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
-    @Nullable
-    abstract ImmutableList<String> getInputDirectories();
-
-    abstract String getFilenameSuffix();
-
     @Nullable
     abstract Schema getSchema();
 
@@ -108,10 +108,6 @@ public static <K1, K2> TransformOutput<K1, K2> transformOutput(
    */
   @AutoValue
   public abstract static class Read extends SortedBucketIO.Read<TableRow> {
-    @Nullable
-    abstract ImmutableList<String> getInputDirectories();
-
-    abstract String getFilenameSuffix();
 
     abstract Compression getCompression();
 
@@ -200,11 +200,6 @@ public static <K1, K2, T extends SpecificRecord> TransformOutput<K1, K2, T> tran
   /** Reads from Avro sorted-bucket files, to be used with {@link SortedBucketIO.CoGbk}. */
   @AutoValue
   public abstract static class Read<T extends IndexedRecord> extends SortedBucketIO.Read<T> {
-    @Nullable
-    abstract ImmutableList<String> getInputDirectories();
-
-    abstract String getFilenameSuffix();
-
     @Nullable
     abstract Schema getSchema();
 
@@ -493,6 +493,13 @@ public abstract static class TransformOutput<K1, K2, V> implements Serializable
 
   /** Represents a single sorted-bucket source written using {@link SortedBucketSink}. */
   public abstract static class Read<V> implements Serializable {
+
+    @Nullable
+    public abstract ImmutableList<String> getInputDirectories();
+
+
+    abstract String getFilenameSuffix();
+
     public abstract TupleTag<V> getTupleTag();
 
     protected abstract BucketedInput<V> toBucketedInput(SortedBucketSource.Keying keying);
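Hoisting getInputDirectories and getFilenameSuffix into the abstract Read means every SMB source now advertises its input location uniformly; the Scala side below derives a test id from those directories. A hedged sketch of that derivation:

    import org.apache.beam.sdk.extensions.smb.SortedBucketIO
    import scala.jdk.CollectionConverters._

    // Sketch: comma-join a read's input directories to obtain the id that a
    // matching SMBIO must carry in a JobTest. getInputDirectories is @Nullable,
    // so a real implementation would guard against null.
    def smbTestInputId(read: SortedBucketIO.Read[_]): String =
      read.getInputDirectories.asScala.mkString(",")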
@@ -125,10 +125,6 @@ public static <K1, K2> TransformOutput<K1, K2> transformOutput(
    */
   @AutoValue
   public abstract static class Read extends SortedBucketIO.Read<Example> {
-    @Nullable
-    abstract ImmutableList<String> getInputDirectories();
-
-    abstract String getFilenameSuffix();
 
     abstract Compression getCompression();
 
8 changes: 8 additions & 0 deletions scio-smb/src/main/scala/com/spotify/scio/smb/SMBIO.scala
@@ -0,0 +1,8 @@
+package com.spotify.scio.smb
+
+import com.spotify.scio.io.{KeyedIO, TapOf, TapT, TestIO}
+
+final case class SMBIO[K, T](id: String, keyBy: T => K) extends KeyedIO[K, T] with TestIO[(K, T)] {
+  override val tapT: TapT.Aux[(K, T), (K, T)] = TapOf[(K, T)]
+  override def testId: String = s"SMBIO($id)"
+}
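A quick illustration of how the id above resolves under test; the directory name "users" is a placeholder:

    import com.spotify.scio.smb.SMBIO

    // Sketch: the test id embeds the read's input directories, so a JobTest
    // input registered under "users" matches a job whose SMB read is .from("users").
    val io = SMBIO[String, String]("users", identity)
    assert(io.testId == "SMBIO(users)")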
@@ -21,6 +21,8 @@ import com.spotify.scio.ScioContext
 import com.spotify.scio.annotations.experimental
 import com.spotify.scio.coders.Coder
 import com.spotify.scio.io.{ClosedTap, EmptyTap}
+import com.spotify.scio.smb.SMBIO
+import com.spotify.scio.testing.TestDataManager
 import com.spotify.scio.values._
 import org.apache.beam.sdk.extensions.smb.SortedBucketIO.{AbsCoGbkTransform, Transformable}
 import org.apache.beam.sdk.extensions.smb.SortedBucketTransform.{BucketItem, MergedBucket}
@@ -68,35 +70,44 @@ final class SortedBucketScioContext(@transient private val self: ScioContext) ex
       lhs: SortedBucketIO.Read[L],
       rhs: SortedBucketIO.Read[R],
       targetParallelism: TargetParallelism = TargetParallelism.auto()
-  ): SCollection[(K, (L, R))] = {
-    val t = SortedBucketIO.read(keyClass).of(lhs, rhs).withTargetParallelism(targetParallelism)
-    val (tupleTagA, tupleTagB) = (lhs.getTupleTag, rhs.getTupleTag)
-    val tfName = self.tfName
-
-    self
-      .wrap(self.pipeline.apply(s"SMB CoGroupByKey@$tfName", t))
-      .withName(tfName)
-      .applyTransform(ParDo.of(new DoFn[KV[K, CoGbkResult], (K, (L, R))] {
-        @ProcessElement
-        private[smb] def processElement(
-          @Element element: KV[K, CoGbkResult],
-          out: OutputReceiver[(K, (L, R))]
-        ): Unit = {
-          val cgbkResult = element.getValue
-          val (resA, resB) = (cgbkResult.getAll(tupleTagA), cgbkResult.getAll(tupleTagB))
-          val itB = resB.iterator()
-          val key = element.getKey
-
-          while (itB.hasNext) {
-            val b = itB.next()
-            val ai = resA.iterator()
-            while (ai.hasNext) {
-              val a = ai.next()
-              out.output((key, (a, b)))
-            }
-          }
-        }
-      }))
-  }
+  ): SCollection[(K, (L, R))] = self.requireNotClosed {
+    if (self.isTest) {
+      val testInput = TestDataManager.getInput(self.testId.get)
+      val idLhs = lhs.getInputDirectories.asScala.mkString(",")
+      val testLhs = testInput[(K, L)](SMBIO(idLhs, null)).toSCollection(self)
+      val idRhs = rhs.getInputDirectories.asScala.mkString(",")
+      val testRhs = testInput[(K, R)](SMBIO(idRhs, null)).toSCollection(self)
+      testLhs.join(testRhs)
+    } else {
+      val t = SortedBucketIO.read(keyClass).of(lhs, rhs).withTargetParallelism(targetParallelism)
+      val (tupleTagA, tupleTagB) = (lhs.getTupleTag, rhs.getTupleTag)
+      val tfName = self.tfName
+
+      self
+        .wrap(self.pipeline.apply(s"SMB CoGroupByKey@$tfName", t))
+        .withName(tfName)
+        .applyTransform(ParDo.of(new DoFn[KV[K, CoGbkResult], (K, (L, R))] {
+          @ProcessElement
+          private[smb] def processElement(
+            @Element element: KV[K, CoGbkResult],
+            out: OutputReceiver[(K, (L, R))]
+          ): Unit = {
+            val cgbkResult = element.getValue
+            val (resA, resB) = (cgbkResult.getAll(tupleTagA), cgbkResult.getAll(tupleTagB))
+            val itB = resB.iterator()
+            val key = element.getKey
+
+            while (itB.hasNext) {
+              val b = itB.next()
+              val ai = resA.iterator()
+              while (ai.hasNext) {
+                val a = ai.next()
+                out.output((key, (a, b)))
+              }
+            }
+          }
+        }))
+    }
+  }
 
   /** Secondary keyed variant. */
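Note what the test branch above buys: under JobTest the SMB machinery is bypassed entirely, and the sort-merge join reduces to an ordinary inner join over the mocked keyed inputs. A minimal sketch of the resulting semantics:

    import com.spotify.scio.ScioContext

    // Sketch: with mocked inputs, sortMergeJoin behaves like SCollection.join,
    // an inner join on the key; the unmatched key 2 is dropped.
    val sc = ScioContext.forTest()
    val users = sc.parallelize(Seq(1 -> "ann", 2 -> "bob"))
    val accounts = sc.parallelize(Seq(1 -> 12.5))
    val joined = users.join(accounts) // contains (1, ("ann", 12.5))
    sc.run()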
@@ -25,6 +25,7 @@ import org.apache.beam.sdk.extensions.smb.SortedBucketSource.BucketedInput
 import org.apache.beam.sdk.io.FileSystems
 import org.apache.beam.sdk.io.fs.ResourceId
 import org.apache.beam.sdk.values.TupleTag
+import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList
 import org.apache.hadoop.conf.Configuration
 import org.apache.parquet.filter2.predicate.FilterPredicate
 import org.apache.parquet.hadoop.metadata.CompressionCodecName
@@ -80,6 +81,9 @@ object ParquetTypeSortedBucketIO {
     def withConfiguration(configuration: Configuration): Read[T] =
       this.copy(configuration = configuration)
 
+    override def getInputDirectories: ImmutableList[String] =
+      ImmutableList.copyOf(inputDirectories.asJava)
+    override def getFilenameSuffix: String = filenameSuffix
     override def getTupleTag: TupleTag[T] = tupleTag
 
     override protected def toBucketedInput(
@@ -0,0 +1,56 @@
+package com.spotify.scio.smb
+
+import com.spotify.scio.ContextAndArgs
+import com.spotify.scio.avro.{Account, Address, User}
+import com.spotify.scio.io.{CustomIO, TextIO}
+import com.spotify.scio.testing.PipelineSpec
+import org.apache.beam.sdk.extensions.smb.{AvroSortedBucketIO, TargetParallelism}
+import org.apache.beam.sdk.values.TupleTag
+
+import java.util.Collections
+
+object SmbJob {
+
+  def main(cmdlineArgs: Array[String]) = {
+    val (sc, args) = ContextAndArgs(cmdlineArgs)
+
+    sc.sortMergeJoin(
+      classOf[Integer],
+      AvroSortedBucketIO
+        .read(new TupleTag[User]("lhs"), classOf[User])
+        .from(args("users")),
+      AvroSortedBucketIO
+        .read(new TupleTag[Account]("rhs"), classOf[Account])
+        .from(args("accounts")),
+      TargetParallelism.max()
+    ).values
+      .map { case (u, a) => s"${u.getLastName}=${a.getAmount}" }
+      .saveAsTextFile(args("output"))
+
+    sc.run().waitUntilDone()
+  }
+
+}
+
+class SortMergeBucketTest extends PipelineSpec {
+
+  "SMB" should "be able to mock input and output" in {
+    val account: Account = new Account(1, "type", "name", 12.5, null)
+    val address = new Address("street1", "street2", "city", "state", "01234", "Sweden")
+    val user =
+      new User(1, "lastname", "firstname", "[email protected]", Collections.emptyList(), address)
+
+    JobTest[SmbJob.type]
+      .args(
+        "--users=users",
+        "--accounts=accounts",
+        "--output=output"
+      )
+      .keyedInput(SMBIO[Integer, User]("users", _.getId), Seq(user))
+      // input is also possible but error prone as key must be given manually
+      .input(SMBIO[Integer, Account]("accounts", _.getId), Seq(account.getId -> account))
+      .output(TextIO("output"))(_ should containInAnyOrder(Seq("lastname=12.5")))
+      .run()
+  }
+
+}
@@ -19,7 +19,7 @@ package com.spotify.scio.testing
 
 import java.lang.reflect.InvocationTargetException
 import com.spotify.scio.ScioResult
-import com.spotify.scio.io.ScioIO
+import com.spotify.scio.io.{KeyedIO, ScioIO}
 import com.spotify.scio.util.ScioUtil
 import com.spotify.scio.values.SCollection
 import com.spotify.scio.coders.Coder
@@ -115,13 +115,19 @@ object JobTest {
   def input[T: Coder](io: ScioIO[T], value: Iterable[T]): Builder =
     input(io, IterableInputSource(value))
 
+  def keyedInput[K: Coder, T: Coder](io: KeyedIO[K, T], value: Iterable[T]): Builder =
+    input(io, IterableInputSource(value.map(x => io.keyBy(x) -> x)))
+
   /**
    * Feed an input in the form of a `PTransform[PBegin, PCollection[T]]` to the pipeline being
    * tested. Note that `PTransform` inputs may not be supported for all `TestIO[T]` types.
    */
   def inputStream[T: Coder](io: ScioIO[T], stream: TestStream[T]): Builder =
     input(io, TestStreamInputSource(stream))
 
+  def keyedInputStream[K: Coder, T: Coder](io: KeyedIO[K, T], stream: Iterable[T]): Builder =
+    input(io, IterableInputSource(stream.map(x => io.keyBy(x) -> x)))
+
   private def input[T](io: ScioIO[T], value: JobInputSource[T]): Builder = {
     require(!state.input.contains(io.testId), "Duplicate test input: " + io.testId)
     state = state.copy(input = state.input + (io.testId -> value))
