Skip to content

Commit

Permalink
Use avro builder API (#5119)
Browse files Browse the repository at this point in the history
  • Loading branch information
RustedBones authored Dec 14, 2023
1 parent 45c1bd3 commit acfb416
Show file tree
Hide file tree
Showing 17 changed files with 255 additions and 311 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ import com.spotify.scio.bigquery.BigQueryTypedTable.Format
import com.spotify.scio.bigquery.client.BigQuery
import com.spotify.scio.testing._
import magnolify.scalacheck.auto._
import org.apache.avro.{LogicalTypes, Schema}
import org.apache.avro.{LogicalTypes, Schema, SchemaBuilder}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}
import org.apache.beam.sdk.options.PipelineOptionsFactory
import org.joda.time.{Instant, LocalDate, LocalDateTime, LocalTime}
Expand Down Expand Up @@ -166,28 +166,17 @@ class TypedBigQueryIT extends PipelineSpec with BeforeAndAfterAll {

it should "write GenericRecord records with logical types" in {
val sc = ScioContext(options)
import scala.jdk.CollectionConverters._
val schema: Schema = Schema.createRecord(
"Record",
"",
"com.spotify.scio.bigquery",
false,
List(
new Schema.Field(
"date",
LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT)),
"",
0
),
new Schema.Field(
"time",
LogicalTypes.timeMicros().addToSchema(Schema.create(Schema.Type.LONG)),
"",
0L
),
new Schema.Field("datetime", Schema.create(Schema.Type.STRING), "", "")
).asJava
)
// format: off
val schema: Schema = SchemaBuilder
.record("Record")
.namespace("com.spotify.scio.bigquery")
.fields()
.name("date").`type`(LogicalTypes.date().addToSchema(Schema.create(Schema.Type.INT))).withDefault(0)
.name("time").`type`(LogicalTypes.timeMicros().addToSchema(Schema.create(Schema.Type.LONG))).withDefault(0L)
.name("datetime").`type`().stringType().stringDefault("")
.endRecord()
// format: on

implicit val coder = avroGenericRecordCoder(schema)
val ltRecords: Seq[GenericRecord] =
Seq(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ import com.spotify.scio.util.MultiJoin
import com.spotify.scio.values.SCollection
import org.apache.avro.Schema
import org.apache.avro.Schema.Field
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}
import org.apache.beam.sdk.extensions.smb.{AvroSortedBucketIO, SortedBucketIO, TargetParallelism}
import org.apache.beam.sdk.values.TupleTag
import org.apache.commons.io.FileUtils
Expand Down Expand Up @@ -213,10 +213,10 @@ class SortMergeBucketParityIT extends AnyFlatSpec with Matchers {

val outputPaths = (0 until numSources).map { n =>
val data = (0 to Random.nextInt(100)).map { i =>
val gr: GenericRecord = new GenericData.Record(schema)
gr.put("key", i)
gr.put("value", s"v$i")
gr
new GenericRecordBuilder(schema)
.set("key", i)
.set("value", s"v$i")
.build()
}

val outputPath = new File(tempFolder, s"source$n")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ package com.spotify.scio.avro.dynamic.syntax

import com.google.protobuf.Message
import com.spotify.scio.avro.AvroIO
import com.spotify.scio.coders.{AvroBytesUtil, Coder, CoderMaterializer}
import com.spotify.scio.coders.{AvroBytesUtil, Coder}
import com.spotify.scio.io.{ClosedTap, EmptyTap}
import com.spotify.scio.io.dynamic.syntax.DynamicSCollectionOps
import com.spotify.scio.protobuf.util.ProtobufUtil
Expand Down Expand Up @@ -142,22 +142,19 @@ final class DynamicProtobufSCollectionOps[T <: Message](private val self: SColle
tempDirectory: String = AvroIO.WriteParam.DefaultTempDirectory,
prefix: String = AvroIO.WriteParam.DefaultPrefix
)(destinationFn: T => String)(implicit ct: ClassTag[T]): ClosedTap[Nothing] = {
val protoCoder = Coder.protoMessageCoder[T]
val elemCoder = CoderMaterializer.beam(self.context, protoCoder)
val avroSchema = AvroBytesUtil.schema
val nm = new JHashMap[String, AnyRef]()
nm.putAll((metadata ++ ProtobufUtil.schemaMetadataOf(ct)).asJava)

if (self.context.isTest) {
throw new NotImplementedError(
"Protobuf file with dynamic destinations cannot be used in a test context"
)
} else {
implicit val protoCoder: Coder[T] = Coder.protoMessageCoder[T]
val nm = new JHashMap[String, AnyRef]()
nm.putAll((metadata ++ ProtobufUtil.schemaMetadataOf(ct)).asJava)

val sink = BAvroIO
.sinkViaGenericRecords(
avroSchema,
(element: T, _: Schema) => AvroBytesUtil.encode(elemCoder, element)
)
.sink(AvroBytesUtil.schema)
.asInstanceOf[BAvroIO.Sink[T]]
.withDatumWriterFactory(AvroBytesUtil.datumWriterFactory)
.withCodec(codec)
.withMetadata(nm)
val write = writeDynamic(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,41 +17,47 @@

package com.spotify.scio.coders

import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.{Schema => ASchema}
import org.apache.avro.generic.{GenericDatumWriter, GenericRecord, GenericRecordBuilder}
import org.apache.avro.io.{DatumWriter, Encoder}
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.beam.sdk.coders.{Coder => BCoder}
import org.apache.beam.sdk.extensions.avro.io.AvroSink.DatumWriterFactory
import org.apache.beam.sdk.util.CoderUtils

import java.nio.ByteBuffer
import scala.jdk.CollectionConverters._

private[scio] object AvroBytesUtil {

  /**
   * Avro schema used to wrap arbitrary coder-encoded payloads: a record with a
   * single required "bytes" field holding the serialized element.
   */
  val schema: Schema = SchemaBuilder
    .record("AvroBytesRecord")
    .fields()
    .requiredBytes("bytes")
    .endRecord()

  // Resolved once so encode/decode can address the field by position
  // instead of repeating a by-name lookup for every record.
  private val byteField = schema.getField("bytes")

  /**
   * Creates a Beam [[DatumWriterFactory]] that serializes elements of `T` with
   * the implicit scio [[Coder]] and writes them as `AvroBytesRecord`s.
   */
  def datumWriterFactory[T: Coder]: DatumWriterFactory[T] = {
    // Materialize the scio coder once; the factory may be invoked per bundle.
    val bCoder = CoderMaterializer.beamWithDefault(Coder[T])
    (schema: Schema) =>
      new DatumWriter[T] {
        private val underlying = new GenericDatumWriter[GenericRecord](schema)

        override def setSchema(schema: Schema): Unit = underlying.setSchema(schema)

        override def write(datum: T, out: Encoder): Unit =
          underlying.write(AvroBytesUtil.encode(bCoder, datum), out)
      }
  }

  /** Encodes `obj` with `coder` and wraps the resulting bytes in an `AvroBytesRecord`. */
  def encode[T](coder: BCoder[T], obj: T): GenericRecord = {
    val bytes = CoderUtils.encodeToByteArray(coder, obj)
    new GenericRecordBuilder(schema)
      .set(byteField, ByteBuffer.wrap(bytes))
      .build()
  }

  /** Extracts the wrapped bytes from `record` and decodes them with `coder`. */
  def decode[T](coder: BCoder[T], record: GenericRecord): T = {
    val bb = record.get(byteField.pos()).asInstanceOf[ByteBuffer]
    // Copy only the valid window of the backing array; the buffer may be a slice.
    val bytes = java.util.Arrays.copyOfRange(bb.array(), bb.position(), bb.limit())
    CoderUtils.decodeFromByteArray(coder, bytes)
  }
}
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,8 @@ package com.spotify.scio.avro.types

import com.spotify.scio.avro.types.Schemas._
import com.spotify.scio.avro.types.Schemas.FieldMode._
import org.apache.avro.Schema
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.beam.model.pipeline.v1.SchemaApi.SchemaOrBuilder
import org.scalatest.flatspec.AnyFlatSpec
import org.scalatest.matchers.should.Matchers

Expand Down Expand Up @@ -149,16 +150,11 @@ class SchemaUtilTest extends AnyFlatSpec with Matchers {
val expectedFields = SchemaUtil.scalaReservedWords
.map(e => s"`$e`")
.mkString(start = "", sep = ": Long, ", end = ": Long")
val schema =
Schema.createRecord(
"Row",
null,
null,
false,
SchemaUtil.scalaReservedWords.map { name =>
new Schema.Field(name, Schema.create(Schema.Type.LONG), null, null.asInstanceOf[Any])
}.asJava
)

val schema = SchemaUtil.scalaReservedWords
.foldLeft(SchemaBuilder.record("Row").fields())(_.requiredLong(_))
.endRecord()

SchemaUtil.toPrettyString1(schema) shouldBe s"case class Row($expectedFields)"
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,8 @@ import com.spotify.scio._
import com.spotify.scio.avro._
import com.spotify.scio.avro.types.AvroType
import com.spotify.scio.io.ClosedTap
import org.apache.avro.Schema
import org.apache.avro.generic.{GenericData, GenericRecord}

import scala.jdk.CollectionConverters._
import org.apache.avro.{Schema, SchemaBuilder}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}

object AvroExample {
@AvroType.fromSchema("""{
Expand Down Expand Up @@ -107,12 +105,12 @@ object AvroExample {
implicit def genericCoder = avroGenericRecordCoder(schema)
sc.parallelize(1 to 100)
.map[GenericRecord] { i =>
val r = new GenericData.Record(schema)
r.put("id", i)
r.put("amount", i.toDouble)
r.put("name", "account" + i)
r.put("type", "checking")
r
new GenericRecordBuilder(schema)
.set("id", i)
.set("amount", i.toDouble)
.set("name", "account" + i)
.set("type", "checking")
.build()
}
.saveAsAvroFile(args("output"), schema = schema)
}
Expand All @@ -133,24 +131,12 @@ object AvroExample {
.map(_.toString)
.saveAsTextFile(args("output"))

val schema: Schema = {
def f(name: String, tpe: Schema.Type) =
new Schema.Field(
name,
Schema.createUnion(List(Schema.create(Schema.Type.NULL), Schema.create(tpe)).asJava),
null: String,
null: AnyRef
)

val s = Schema.createRecord("GenericAccountRecord", null, null, false)
s.setFields(
List(
f("id", Schema.Type.INT),
f("amount", Schema.Type.DOUBLE),
f("name", Schema.Type.STRING),
f("type", Schema.Type.STRING)
).asJava
)
s
}
val schema: Schema = SchemaBuilder
.record("GenericAccountRecord")
.fields()
.requiredInt("id")
.requiredDouble("amount")
.requiredString("name")
.requiredString("type")
.endRecord()
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@ import com.spotify.scio.coders.Coder
import com.spotify.scio.values.SCollection
import org.apache.avro.Schema
import org.apache.avro.file.CodecFactory
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}
import org.apache.beam.sdk.extensions.smb.BucketMetadata.HashType
import org.apache.beam.sdk.extensions.smb.{AvroSortedBucketIO, TargetParallelism}
import org.apache.beam.sdk.values.TupleTag
Expand All @@ -58,13 +58,10 @@ object SortMergeBucketExample {
|""".stripMargin
)

/**
 * Builds a user [[GenericRecord]] conforming to `UserDataSchema` with the
 * given `userId` and `age` fields set. Uses [[GenericRecordBuilder]], which
 * validates field names against the schema at build time.
 */
def user(id: String, age: Int): GenericRecord = new GenericRecordBuilder(UserDataSchema)
  .set("userId", id)
  .set("age", age)
  .build()
}

object SortMergeBucketWriteExample {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,18 +20,18 @@ package com.spotify.scio.examples.extra
import com.spotify.scio.avro.AvroIO
import com.spotify.scio.io._
import com.spotify.scio.testing._
import org.apache.avro.generic.{GenericData, GenericRecord}
import org.apache.avro.generic.{GenericRecord, GenericRecordBuilder}

class MagnolifyAvroExampleTest extends PipelineSpec {
import MagnolifyAvroExample._

val textIn: Seq[String] = Seq("a b c d e", "a b a b")
val wordCount: Seq[(String, Long)] = Seq(("a", 3L), ("b", 3L), ("c", 1L), ("d", 1L), ("e", 1L))
val records: Seq[GenericRecord] = wordCount.map { kv =>
val r = new GenericData.Record(wordCountType.schema)
r.put("word", kv._1)
r.put("count", kv._2)
r
new GenericRecordBuilder(wordCountType.schema)
.set("word", kv._1)
.set("count", kv._2)
.build()
}
val textOut: Seq[String] = wordCount.map(kv => kv._1 + ": " + kv._2)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,10 +19,9 @@ package com.spotify.scio.extra.bigquery

import java.math.{BigDecimal => JBigDecimal}
import java.nio.ByteBuffer

import com.google.protobuf.ByteString
import com.spotify.scio.bigquery.TableRow
import org.apache.avro.generic.GenericData
import org.apache.avro.generic.GenericRecordBuilder
import org.apache.avro.generic.GenericData.EnumSymbol
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.io.BaseEncoding
import org.joda.time.{DateTime, LocalDate, LocalTime}
Expand Down Expand Up @@ -72,29 +71,24 @@ class ToTableRowTest extends AnyFlatSpec with Matchers {
it should "convert a GenericRecord to TableRow" in {
val enumSchema = AvroExample.SCHEMA$.getField("enumField").schema()

val nestedAvro = new GenericData.Record(NestedAvro.SCHEMA$)
nestedAvro.put("nestedField", "nestedValue")
val nestedAvro = new GenericRecordBuilder(NestedAvro.SCHEMA$)
.set("nestedField", "nestedValue")
.build()

val genericRecord = new GenericData.Record(AvroExample.SCHEMA$)
genericRecord.put("booleanField", true)
genericRecord.put("stringField", "someString")
genericRecord.put("doubleField", 1.0)
genericRecord.put("longField", 1L)
genericRecord.put("intField", 1)
genericRecord.put("floatField", 1f)
genericRecord.put(
"bytesField",
ByteBuffer.wrap(ByteString.copyFromUtf8("%20cフーバー").toByteArray)
)
genericRecord.put("arrayField", List(nestedAvro).asJava)
genericRecord.put("unionField", "someUnion")
genericRecord.put(
"mapField",
Map("mapKey" -> 1.0d).asJava
.asInstanceOf[java.util.Map[java.lang.CharSequence, java.lang.Double]]
)
genericRecord.put("enumField", new EnumSymbol(enumSchema, Kind.FOO.toString))
genericRecord.put("fixedField", new fixedType("%20cフーバー".getBytes()))
val genericRecord = new GenericRecordBuilder(AvroExample.SCHEMA$)
.set("booleanField", true)
.set("stringField", "someString")
.set("doubleField", 1.0)
.set("longField", 1L)
.set("intField", 1)
.set("floatField", 1f)
.set("bytesField", ByteBuffer.wrap(ByteString.copyFromUtf8("%20cフーバー").toByteArray))
.set("arrayField", List(nestedAvro).asJava)
.set("unionField", "someUnion")
.set("mapField", Map[CharSequence, java.lang.Double]("mapKey" -> 1.0d).asJava)
.set("enumField", new EnumSymbol(enumSchema, Kind.FOO.toString))
.set("fixedField", new fixedType("%20cフーバー".getBytes()))
.build()

AvroConverters.toTableRow(genericRecord) shouldEqual expectedOutput
}
Expand Down
Loading

0 comments on commit acfb416

Please sign in to comment.