Make Sparkey testable #5128

Merged: 13 commits, Jan 11, 2024

Changes from all commits:
@@ -19,13 +19,10 @@ package com.spotify.scio.avro.types

import com.spotify.scio.avro.types.Schemas._
import com.spotify.scio.avro.types.Schemas.FieldMode._
-import org.apache.avro.{Schema, SchemaBuilder}
-import org.apache.beam.model.pipeline.v1.SchemaApi.SchemaOrBuilder
+import org.apache.avro.SchemaBuilder
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

-import scala.jdk.CollectionConverters._

class SchemaUtilTest extends AnyFlatSpec with Matchers {
"toPrettyString()" should "support primitive types" in {
SchemaUtil.toPrettyString1(parseSchema(s"${basicFields()}")) shouldBe
@@ -52,6 +52,7 @@ trait SideInput[T] extends Serializable {
/**
* Create a new [[SideInput]] by applying a function on the elements wrapped in this SideInput.
*/
+@deprecated(since = "0.14.0")
def map[B](f: T => B): SideInput[B] = new DelegatingSideInput[T, B](this, f)

private[scio] val view: PCollectionView[_]
@@ -127,6 +128,7 @@ private[values] class MultiMapSideInput[K, V](val view: PCollectionView[JMap[K,
JMapWrapper.ofMultiMap(context.sideInput(view))
}

+@deprecated(since = "0.14.0")
Contributor: Can we include a deprecation message advising users what methods to use instead?

Contributor (author): No, this is just a pure "don't do this".

private[values] class DelegatingSideInput[A, B](val si: SideInput[A], val mapper: A => B)
extends SideInput[B] {

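Editor's sketch of the pattern the deprecation above nudges toward: since the reviewer's question got "don't do this" rather than a replacement method, one way to avoid SideInput.map is to transform the SCollection before converting it to a side input. The context and data below are hypothetical.

import com.spotify.scio.ScioContext
import com.spotify.scio.values.{SCollection, SideInput}

def buildSideInputs(sc: ScioContext): Unit = {
  val words: SCollection[String] = sc.parallelize(Seq("a", "bb", "ccc"))

  // Deprecated style: transform the already-materialized side input.
  val lengthsViaMap: SideInput[Seq[Int]] = words.asListSideInput.map(_.map(_.length))

  // Alternative: transform the SCollection first, then convert it.
  val lengthsDirect: SideInput[Seq[Int]] = words.map(_.length).asListSideInput
}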
@@ -47,9 +47,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
*/
def largeHashJoin[W](
rhs: SCollection[(K, W)],
-numShards: Short = DefaultSideInputNumShards,
-compressionType: CompressionType = DefaultCompressionType,
-compressionBlockSize: Int = DefaultCompressionBlockSize
+numShards: Short = SparkeyIO.DefaultSideInputNumShards,
+compressionType: CompressionType = SparkeyIO.DefaultCompressionType,
+compressionBlockSize: Int = SparkeyIO.DefaultCompressionBlockSize
): SCollection[(K, (V, W))] = {
implicit val wCoder: Coder[W] = rhs.valueCoder
largeHashJoin(rhs.asLargeMultiMapSideInput(numShards, compressionType, compressionBlockSize))
@@ -97,9 +97,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
*/
def largeHashLeftOuterJoin[W](
rhs: SCollection[(K, W)],
-numShards: Short = DefaultSideInputNumShards,
-compressionType: CompressionType = DefaultCompressionType,
-compressionBlockSize: Int = DefaultCompressionBlockSize
+numShards: Short = SparkeyIO.DefaultSideInputNumShards,
+compressionType: CompressionType = SparkeyIO.DefaultCompressionType,
+compressionBlockSize: Int = SparkeyIO.DefaultCompressionBlockSize
): SCollection[(K, (V, Option[W]))] = {
implicit val wCoder: Coder[W] = rhs.valueCoder
largeHashLeftOuterJoin(
@@ -141,9 +141,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
*/
def largeHashFullOuterJoin[W](
rhs: SCollection[(K, W)],
-numShards: Short = DefaultSideInputNumShards,
-compressionType: CompressionType = DefaultCompressionType,
-compressionBlockSize: Int = DefaultCompressionBlockSize
+numShards: Short = SparkeyIO.DefaultSideInputNumShards,
+compressionType: CompressionType = SparkeyIO.DefaultCompressionType,
+compressionBlockSize: Int = SparkeyIO.DefaultCompressionBlockSize
): SCollection[(K, (Option[V], Option[W]))] = {
implicit val wCoder = rhs.valueCoder
largeHashFullOuterJoin(
@@ -206,9 +206,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
*/
def largeHashIntersectByKey(
rhs: SCollection[K],
-numShards: Short = DefaultSideInputNumShards,
-compressionType: CompressionType = DefaultCompressionType,
-compressionBlockSize: Int = DefaultCompressionBlockSize
+numShards: Short = SparkeyIO.DefaultSideInputNumShards,
+compressionType: CompressionType = SparkeyIO.DefaultCompressionType,
+compressionBlockSize: Int = SparkeyIO.DefaultCompressionBlockSize
): SCollection[(K, V)] =
largeHashIntersectByKey(
rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize)
@@ -238,9 +238,9 @@ class PairLargeHashSCollectionFunctions[K, V](private val self: SCollection[(K,
*/
def largeHashSubtractByKey(
rhs: SCollection[K],
-numShards: Short = DefaultSideInputNumShards,
-compressionType: CompressionType = DefaultCompressionType,
-compressionBlockSize: Int = DefaultCompressionBlockSize
+numShards: Short = SparkeyIO.DefaultSideInputNumShards,
+compressionType: CompressionType = SparkeyIO.DefaultCompressionType,
+compressionBlockSize: Int = SparkeyIO.DefaultCompressionBlockSize
): SCollection[(K, V)] =
largeHashSubtractByKey(
rhs.asLargeSetSideInput(numShards, compressionType, compressionBlockSize)
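For context, a minimal usage sketch of one of these joins with the defaults now qualified on SparkeyIO; the datasets and key type are hypothetical:

import com.spotify.scio.ScioContext
import com.spotify.scio.extra.sparkey._
import com.spotify.scio.values.SCollection

def joinExample(sc: ScioContext): SCollection[(String, (Int, String))] = {
  val lhs: SCollection[(String, Int)] = sc.parallelize(Seq("a" -> 1, "b" -> 2))
  val rhs: SCollection[(String, String)] = sc.parallelize(Seq("a" -> "x"))
  // Defaults now resolve to SparkeyIO.DefaultSideInputNumShards etc.
  lhs.largeHashJoin(rhs, numShards = SparkeyIO.DefaultSideInputNumShards)
}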
@@ -0,0 +1,205 @@
/*
* Copyright 2023 Spotify AB.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package com.spotify.scio.extra.sparkey

import java.lang.Math.floorMod
import java.util.UUID
import com.spotify.scio.coders.{BeamCoders, Coder}
import com.spotify.scio.extra.sparkey.instances._
import com.spotify.scio.io.{TapOf, TapT, TestIO}
import com.spotify.scio.util.RemoteFileUtil
import com.spotify.scio.values.SCollection
import com.spotify.sparkey.{CompressionType, SparkeyReader}
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.MoveOptions.StandardMoveOptions
import org.apache.beam.sdk.io.fs.{EmptyMatchTreatment, ResourceId}
import org.slf4j.LoggerFactory

import scala.collection.mutable
import scala.jdk.CollectionConverters._

/** Special version of [[com.spotify.scio.io.ScioIO]] for use with sparkey methods. */
case class SparkeyTestIO[T] private (path: String) extends TestIO[T] {
override val tapT: TapT.Aux[T, T] = TapOf[T]
}

object SparkeyIO {
@transient private lazy val logger = LoggerFactory.getLogger(this.getClass)

val DefaultNumShards: Short = 1
val DefaultSideInputNumShards: Short = 64
val DefaultCompressionType: CompressionType = CompressionType.NONE
val DefaultCompressionBlockSize: Int = 0

def apply(path: String): SparkeyTestIO[SparkeyReader] = SparkeyTestIO[SparkeyReader](path)
def output[K, V](path: String): SparkeyTestIO[(K, V)] = SparkeyTestIO[(K, V)](path)

private def writeToSparkey[K, V](
uri: SparkeyUri,
rfu: RemoteFileUtil,
maxMemoryUsage: Long,
compressionType: CompressionType,
compressionBlockSize: Int,
elements: Iterable[(K, V)],
w: SparkeyWritable[K, V]
): SparkeyUri = {
val writer =
new SparkeyWriter(uri, rfu, compressionType, compressionBlockSize, maxMemoryUsage)
val it = elements.iterator
while (it.hasNext) {
val kv = it.next()
w.put(writer, kv._1, kv._2)
}
writer.close()
uri
}

/** @param baseUri The final destination for sparkey files */
private[sparkey] def writeSparkey[K, V](
baseUri: SparkeyUri,
writable: SparkeyWritable[K, V],
data: SCollection[(K, V)],
maxMemoryUsage: Long,
numShards: Short,
compressionType: CompressionType,
compressionBlockSize: Int
): SCollection[SparkeyUri] = {
require(
!baseUri.isSharded,
s"path to which sparkey will be saved must not include a `*` wildcard."
)
require(numShards > 0, s"numShards must be greater than 0, found $numShards")
if (compressionType != CompressionType.NONE) {
require(
compressionBlockSize > 0,
s"Compression block size must be > 0 for $compressionType"
)
}
val isUnsharded = numShards == 1
val rfu = RemoteFileUtil.create(data.context.options)
val tempLocation = data.context.options.getTempLocation

// verify that we're not writing to a previously-used output dir
List(baseUri, SparkeyUri(s"${baseUri.path}/*")).foreach { uri =>
require(!uri.exists(rfu), s"Sparkey URI ${uri.path} already exists")
}
// root destination to which all _interim_ results are written,
// deleted upon successful completion of the write
val tempPath = s"$tempLocation/sparkey-temp-${UUID.randomUUID}"

val outputUri = if (isUnsharded) baseUri else SparkeyUri(s"${baseUri.path}/*")
logger.info(s"Saving as Sparkey with $numShards shards: ${baseUri.path}")

implicit val coder: Coder[(K, V)] = BeamCoders.getCoder(data)
implicit val keyCoder: Coder[K] = BeamCoders.getKeyCoder(data)
implicit val valueCoder: Coder[V] = BeamCoders.getValueCoder(data)

def resourcesForPattern(pattern: String): mutable.Buffer[ResourceId] =
FileSystems
.`match`(pattern, EmptyMatchTreatment.ALLOW)
.metadata()
.asScala
.map(_.resourceId())

data.transform { collection =>
// shard by key hash
val shards = collection
.groupBy { case (k, _) => floorMod(writable.shardHash(k), numShards.toInt).toShort }

// gather shards that actually have values
val shardsWithKeys = shards.keys.asSetSingletonSideInput
// fill in missing shards
val missingShards = shards.context
.parallelize((0 until numShards.toInt).map(_.toShort))
.withSideInputs(shardsWithKeys)
.flatMap { case (shard, ctx) =>
val shardExists = ctx(shardsWithKeys).contains(shard)
if (shardExists) None else Some(shard -> Iterable.empty[(K, V)])
}
.toSCollection

// write files to temporary locations
val tempShardUris = shards
.union(missingShards)
.map { case (shard, xs) =>
// use a temp uri so that if a bundle fails retries will not fail
val tempUri = SparkeyUri(s"$tempPath/${UUID.randomUUID}")
// perform the write to the temp uri
shard -> writeToSparkey(
tempUri.sparkeyUriForShard(shard, numShards),
rfu,
maxMemoryUsage,
compressionType,
compressionBlockSize,
xs,
writable
)
}

// TODO WriteFiles inserts a reshuffle here for unclear reasons

tempShardUris.reifyAsListInGlobalWindow
.map { seq =>
val items = seq.toList

// accumulate source files and destination files
val (srcPaths, dstPaths) = items
.foldLeft((List.empty[ResourceId], List.empty[ResourceId])) {
case ((srcs, dsts), (shard, uri)) =>
if (isUnsharded && shard != 0)
throw new IllegalArgumentException(s"numShards=1 but got shard=$shard")
// assumes paths always returns things in the same order 🙃
val dstUri =
if (isUnsharded) baseUri else baseUri.sparkeyUriForShard(shard, numShards)

val srcResources = srcs ++ uri.paths
val dstResources = dsts ++ dstUri.paths

(srcResources, dstResources)
}

// rename source files to dest files
logger.info(s"Copying ${items.size} files from temp to final GCS destination.")
// per FileBasedSink.java#783 ignore errors as files may have previously been deleted
FileSystems.rename(
srcPaths.asJava,
dstPaths.asJava,
StandardMoveOptions.IGNORE_MISSING_FILES,
StandardMoveOptions.SKIP_IF_DESTINATION_EXISTS
)

// cleanup orphan files per FileBasedSink.removeTemporaryFiles
val orphanTempFiles = resourcesForPattern(s"${tempPath}/*")
orphanTempFiles.foreach { r =>
logger.warn("Will also remove unknown temporary file {}.", r)
}
FileSystems.delete(orphanTempFiles.asJava, StandardMoveOptions.IGNORE_MISSING_FILES)
// clean up temp dir, can fail, but failure is to be ignored per FileBasedSink
val tempPathResource = resourcesForPattern(tempPath)
try {
FileSystems.delete(tempPathResource.asJava, StandardMoveOptions.IGNORE_MISSING_FILES)
} catch {
case _: Exception =>
logger.warn("Failed to remove temporary directory: [{}].", tempPath)
}

outputUri
}
}
}
}
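A sketch of how the new test IO might be exercised from a JobTest, which is what "Make Sparkey testable" is aiming at. The job, bucket, and the assumption that a saveAsSparkey output is matched by SparkeyIO.output at the same path are all illustrative, not taken from this PR:

import com.spotify.scio.ContextAndArgs
import com.spotify.scio.extra.sparkey._
import com.spotify.scio.testing._

// Hypothetical job writing a Sparkey file, only to illustrate the test wiring.
object SparkeyJob {
  def main(cmdlineArgs: Array[String]): Unit = {
    val (sc, args) = ContextAndArgs(cmdlineArgs)
    sc.parallelize(Seq("a" -> "1", "b" -> "2")).saveAsSparkey(args("output"))
    sc.run().waitUntilFinish()
  }
}

class SparkeyJobTest extends PipelineSpec {
  "SparkeyJob" should "emit the expected key/value pairs" in {
    JobTest[SparkeyJob.type]
      .args("--output=gs://hypothetical-bucket/sparkey")
      .output(SparkeyIO.output[String, String]("gs://hypothetical-bucket/sparkey")) { kvs =>
        kvs should containInAnyOrder(Seq("a" -> "1", "b" -> "2"))
      }
      .run()
  }
}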
@@ -25,15 +25,24 @@ import com.spotify.sparkey.extra.ThreadLocalSparkeyReader
import com.spotify.sparkey.SparkeyReader
import org.apache.beam.sdk.io.FileSystems
import org.apache.beam.sdk.io.fs.{EmptyMatchTreatment, MatchResult, ResourceId}
+import org.apache.beam.sdk.options.PipelineOptions

import java.nio.file.Path
+import java.util.UUID
import scala.collection.mutable
import scala.jdk.CollectionConverters._

case class InvalidNumShardsException(str: String) extends RuntimeException(str)

object SparkeyUri {
def extensions: Seq[String] = Seq(".spi", ".spl")

+def baseUri(optPath: Option[String], opts: PipelineOptions): SparkeyUri = {
+val tempLocation = opts.getTempLocation
+// the final destination for sparkey files. A temp dir if not permanently persisted.
+val basePath = optPath.getOrElse(s"$tempLocation/sparkey-${UUID.randomUUID}")
+SparkeyUri(basePath)
+}
}

/**
@@ -45,7 +54,7 @@ object SparkeyUri {
*/
case class SparkeyUri(path: String) {
private[sparkey] val isLocal = ScioUtil.isLocalUri(new URI(path))
-private val isSharded = path.endsWith("*")
+private[sparkey] val isSharded = path.endsWith("*")
private[sparkey] val basePath =
if (!isSharded) path else path.split("/").dropRight(1).mkString("/")

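A small sketch of how the new SparkeyUri.baseUri helper resolves its destination, assuming pipeline options with a temp location set; the bucket name is illustrative:

import com.spotify.scio.extra.sparkey.SparkeyUri
import org.apache.beam.sdk.options.PipelineOptionsFactory

object BaseUriExample {
  def main(args: Array[String]): Unit = {
    val opts = PipelineOptionsFactory.create()
    opts.setTempLocation("gs://my-bucket/temp") // hypothetical bucket

    // Explicit path: used verbatim as the final destination.
    println(SparkeyUri.baseUri(Some("gs://my-bucket/sparkey"), opts))
    // No path: a unique directory under the temp location,
    // e.g. SparkeyUri("gs://my-bucket/temp/sparkey-<random UUID>")
    println(SparkeyUri.baseUri(None, opts))
  }
}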