diff --git a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroIO.scala b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroIO.scala index c254fd68d7..8b8e64d7ba 100644 --- a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroIO.scala +++ b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroIO.scala @@ -22,7 +22,7 @@ import com.spotify.scio.ScioContext import com.spotify.scio.coders.{Coder, CoderMaterializer} import com.spotify.scio.io.{ScioIO, Tap, TapOf, TapT} import com.spotify.scio.parquet.read.{ParquetRead, ParquetReadConfiguration, ReadSupportFactory} -import com.spotify.scio.parquet.{BeamInputFile, GcsConnectorUtil, ParquetConfiguration} +import com.spotify.scio.parquet.{GcsConnectorUtil, ParquetConfiguration} import com.spotify.scio.testing.TestDataManager import com.spotify.scio.util.{FilenamePolicySupplier, Functions, ScioUtil} import com.spotify.scio.values.SCollection @@ -43,7 +43,6 @@ import org.apache.hadoop.mapreduce.Job import org.apache.parquet.avro.{ AvroDataSupplier, AvroParquetInputFormat, - AvroParquetReader, AvroReadSupport, AvroWriteSupport, GenericDataSupplier @@ -61,11 +60,10 @@ final case class ParquetAvroIO[T: ClassTag: Coder](path: String) extends ScioIO[ override type WriteP = ParquetAvroIO.WriteParam override val tapT: TapT.Aux[T, T] = TapOf[T] - private val cls = ScioUtil.classOf[T] - override protected def read(sc: ScioContext, params: ReadP): SCollection[T] = { val bCoder = CoderMaterializer.beam(sc, Coder[T]) sc.pipeline.getCoderRegistry.registerCoderForClass(ScioUtil.classOf[T], bCoder) + params.setupConfig() params.read(sc, path)(Coder[T]) } @@ -101,7 +99,7 @@ final case class ParquetAvroIO[T: ClassTag: Coder](path: String) extends ScioIO[ )(ScioUtil.strippedPath(path), suffix) val dynamicDestinations = DynamicFileDestinations .constant(fp, SerializableFunctions.identity[T]) - val job = Job.getInstance(ParquetConfiguration.ofNullable(conf)) + val job = Job.getInstance(conf) if (isLocalRunner) GcsConnectorUtil.setCredentials(job) val sink = new ParquetAvroFileBasedSink[T]( @@ -116,9 +114,9 @@ final case class ParquetAvroIO[T: ClassTag: Coder](path: String) extends ScioIO[ } override protected def write(data: SCollection[T], params: WriteP): Tap[T] = { - val isAssignable = classOf[SpecificRecord].isAssignableFrom(cls) - val writerSchema = if (isAssignable) ReflectData.get().getSchema(cls) else params.schema - val conf = ParquetConfiguration.ofNullable(params.conf) + val avroClass = ScioUtil.classOf[T] + val isSpecific: Boolean = classOf[SpecificRecord] isAssignableFrom avroClass + val writerSchema = if (isSpecific) ReflectData.get().getSchema(avroClass) else params.schema data.applyInternal( parquetOut( @@ -127,7 +125,7 @@ final case class ParquetAvroIO[T: ClassTag: Coder](path: String) extends ScioIO[ params.suffix, params.numShards, params.compression, - conf, + ParquetConfiguration.ofNullable(params.conf), params.filenamePolicySupplier, params.prefix, params.shardNameTemplate, @@ -144,8 +142,6 @@ final case class ParquetAvroIO[T: ClassTag: Coder](path: String) extends ScioIO[ } object ParquetAvroIO { - private lazy val log = LoggerFactory.getLogger(getClass) - object ReadParam { val DefaultProjection: Schema = null val DefaultPredicate: FilterPredicate = null @@ -168,75 +164,70 @@ object ParquetAvroIO { conf: Configuration = ReadParam.DefaultConfiguration, suffix: String = null ) { + lazy val confOrDefault = ParquetConfiguration.ofNullable(conf) val avroClass: Class[A] = ScioUtil.classOf[A] val isSpecific: Boolean = classOf[SpecificRecord] isAssignableFrom avroClass val readSchema: Schema = if (isSpecific) ReflectData.get().getSchema(avroClass) else projection def read(sc: ScioContext, path: String)(implicit coder: Coder[T]): SCollection[T] = { - val jobConf = ParquetConfiguration.ofNullable(conf) + if (ParquetReadConfiguration.getUseSplittableDoFn(confOrDefault, sc.options)) { + readSplittableDoFn(sc, path) + } else { + readLegacy(sc, path) + } + } + + def setupConfig(): Unit = { + AvroReadSupport.setAvroReadSchema(confOrDefault, readSchema) + if (projection != null) { + AvroReadSupport.setRequestedProjection(confOrDefault, projection) + } + + if (predicate != null) { + ParquetInputFormat.setFilterPredicate(confOrDefault, predicate) + } // Needed to make GenericRecord read by parquet-avro work with Beam's // org.apache.beam.sdk.extensions.avro.coders.AvroCoder if (!isSpecific) { - jobConf.setBoolean(AvroReadSupport.AVRO_COMPATIBILITY, false) - - if (jobConf.get(AvroReadSupport.AVRO_DATA_SUPPLIER) == null) { - jobConf.setClass( + confOrDefault.setBoolean(AvroReadSupport.AVRO_COMPATIBILITY, false) + if (confOrDefault.get(AvroReadSupport.AVRO_DATA_SUPPLIER) == null) { + confOrDefault.setClass( AvroReadSupport.AVRO_DATA_SUPPLIER, classOf[GenericDataSupplier], classOf[AvroDataSupplier] ) } } - - if (ParquetReadConfiguration.getUseSplittableDoFn(jobConf, sc.options)) { - readSplittableDoFn(sc, jobConf, path) - } else { - readLegacy(sc, jobConf, path) - } } - private def readSplittableDoFn(sc: ScioContext, conf: Configuration, path: String)(implicit + private def readSplittableDoFn(sc: ScioContext, path: String)(implicit coder: Coder[T] ): SCollection[T] = { - AvroReadSupport.setAvroReadSchema(conf, readSchema) - if (projection != null) { - AvroReadSupport.setRequestedProjection(conf, projection) - } - if (predicate != null) { - ParquetInputFormat.setFilterPredicate(conf, predicate) - } - + val filePattern = ScioUtil.filePattern(path, suffix) val bCoder = CoderMaterializer.beam(sc, coder) val cleanedProjectionFn = ClosureCleaner.clean(projectionFn) sc.applyTransform( ParquetRead.read[A, T]( ReadSupportFactory.avro, - new SerializableConfiguration(conf), - path, + new SerializableConfiguration(confOrDefault), + filePattern, Functions.serializableFn(cleanedProjectionFn) ) ).setCoder(bCoder) } - private def readLegacy(sc: ScioContext, conf: Configuration, path: String)(implicit + private def readLegacy(sc: ScioContext, path: String)(implicit coder: Coder[T] ): SCollection[T] = { - val job = Job.getInstance(conf) - GcsConnectorUtil.setInputPaths(sc, job, path) + val job = Job.getInstance(confOrDefault) + val filePattern = ScioUtil.filePattern(path, suffix) + GcsConnectorUtil.setInputPaths(sc, job, filePattern) job.setInputFormatClass(classOf[AvroParquetInputFormat[T]]) job.getConfiguration.setClass("key.class", classOf[Void], classOf[Void]) job.getConfiguration.setClass("value.class", avroClass, avroClass) - AvroParquetInputFormat.setAvroReadSchema(job, readSchema) - - if (projection != null) { - AvroParquetInputFormat.setRequestedProjection(job, projection) - } - if (predicate != null) { - ParquetInputFormat.setFilterPredicate(job.getConfiguration, predicate) - } val g = ClosureCleaner.clean(projectionFn) // defeat closure val aCls = avroClass @@ -285,30 +276,3 @@ object ParquetAvroIO { tempDirectory: String = WriteParam.DefaultTempDirectory ) } - -final case class ParquetAvroTap[A, T: ClassTag: Coder]( - path: String, - params: ParquetAvroIO.ReadParam[A, T] -) extends Tap[T] { - override def value: Iterator[T] = { - val filePattern = ScioUtil.filePattern(path, params.suffix) - val xs = FileSystems.`match`(filePattern).metadata().asScala.toList - xs.iterator.flatMap { metadata => - val reader = AvroParquetReader - .builder[A](BeamInputFile.of(metadata.resourceId())) - .withConf(ParquetConfiguration.ofNullable(params.conf)) - .build() - new Iterator[T] { - private var current: A = reader.read() - override def hasNext: Boolean = current != null - override def next(): T = { - val r = params.projectionFn(current) - current = reader.read() - r - } - } - } - } - override def open(sc: ScioContext): SCollection[T] = - sc.read(ParquetAvroIO[T](path))(params) -} diff --git a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroSink.scala b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroSink.scala new file mode 100644 index 0000000000..8686ce37ad --- /dev/null +++ b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroSink.scala @@ -0,0 +1,55 @@ +/* + * Copyright 2024 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.parquet.avro + +import com.spotify.scio.parquet.ParquetOutputFile +import org.apache.avro.Schema +import org.apache.beam.sdk.io.FileIO +import org.apache.beam.sdk.io.hadoop.SerializableConfiguration +import org.apache.parquet.avro.AvroParquetWriter +import org.apache.parquet.hadoop.metadata.CompressionCodecName +import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetWriter} + +import java.nio.channels.WritableByteChannel + +class ParquetAvroSink[T]( + schema: Schema, + val compression: CompressionCodecName, + val conf: SerializableConfiguration +) extends FileIO.Sink[T] { + private val schemaString = schema.toString + private var writer: ParquetWriter[T] = _ + + override def open(channel: WritableByteChannel): Unit = { + val schema = new Schema.Parser().parse(schemaString) + // https://github.com/apache/parquet-mr/tree/master/parquet-hadoop#class-parquetoutputformat + val rowGroupSize = + conf.get.getInt(ParquetOutputFormat.BLOCK_SIZE, ParquetWriter.DEFAULT_BLOCK_SIZE) + writer = AvroParquetWriter + .builder[T](new ParquetOutputFile(channel)) + .withSchema(schema) + .withCompressionCodec(compression) + .withConf(conf.get) + .withRowGroupSize(rowGroupSize) + .build + } + + override def write(element: T): Unit = writer.write(element) + + override def flush(): Unit = writer.close() +} diff --git a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroTap.scala b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroTap.scala new file mode 100644 index 0000000000..ca0c55ea70 --- /dev/null +++ b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/ParquetAvroTap.scala @@ -0,0 +1,59 @@ +/* + * Copyright 2024 Spotify AB. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +package com.spotify.scio.parquet.avro + +import com.spotify.scio.ScioContext +import com.spotify.scio.coders.Coder +import com.spotify.scio.io.Tap +import com.spotify.scio.parquet.{BeamInputFile, ParquetConfiguration} +import com.spotify.scio.util.ScioUtil +import com.spotify.scio.values.SCollection +import org.apache.beam.sdk.io._ +import org.apache.parquet.avro.AvroParquetReader + +import scala.jdk.CollectionConverters._ +import scala.reflect.ClassTag + +final case class ParquetAvroTap[A, T: ClassTag: Coder]( + path: String, + params: ParquetAvroIO.ReadParam[A, T] +) extends Tap[T] { + override def value: Iterator[T] = { + val filePattern = ScioUtil.filePattern(path, params.suffix) + params.setupConfig() + + val xs = FileSystems.`match`(filePattern).metadata().asScala.toList + xs.iterator.flatMap { metadata => + val reader = AvroParquetReader + .builder[A](BeamInputFile.of(metadata.resourceId())) + .withConf(params.confOrDefault) + .build() + new Iterator[T] { + private var current: A = reader.read() + override def hasNext: Boolean = current != null + override def next(): T = { + val r = params.projectionFn(current) + current = reader.read() + r + } + } + } + } + override def open(sc: ScioContext): SCollection[T] = + sc.read(ParquetAvroIO[T](path))(params) +} diff --git a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/package.scala b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/package.scala index ad382b2302..fb8a787b24 100644 --- a/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/package.scala +++ b/scio-parquet/src/main/scala/com/spotify/scio/parquet/avro/package.scala @@ -18,14 +18,6 @@ package com.spotify.scio.parquet import com.spotify.scio.parquet.avro.syntax.Syntax -import org.apache.avro.Schema -import org.apache.beam.sdk.io.FileIO -import org.apache.beam.sdk.io.hadoop.SerializableConfiguration -import org.apache.parquet.avro.AvroParquetWriter -import org.apache.parquet.hadoop.metadata.CompressionCodecName -import org.apache.parquet.hadoop.{ParquetOutputFormat, ParquetWriter} - -import java.nio.channels.WritableByteChannel /** * Main package for Parquet Avro APIs. Import all. @@ -41,28 +33,4 @@ package object avro extends Syntax { /** Alias for `me.lyh.parquet.avro.Predicate`. */ val Predicate = me.lyh.parquet.avro.Predicate - - class ParquetAvroSink[T]( - schema: Schema, - val compression: CompressionCodecName, - val conf: SerializableConfiguration - ) extends FileIO.Sink[T] { - private val schemaString = schema.toString - private var writer: ParquetWriter[T] = _ - override def open(channel: WritableByteChannel): Unit = { - val schema = new Schema.Parser().parse(schemaString) - // https://github.com/apache/parquet-mr/tree/master/parquet-hadoop#class-parquetoutputformat - val rowGroupSize = - conf.get.getInt(ParquetOutputFormat.BLOCK_SIZE, ParquetWriter.DEFAULT_BLOCK_SIZE) - writer = AvroParquetWriter - .builder[T](new ParquetOutputFile(channel)) - .withSchema(schema) - .withCompressionCodec(compression) - .withConf(conf.get) - .withRowGroupSize(rowGroupSize) - .build - } - override def write(element: T): Unit = writer.write(element) - override def flush(): Unit = writer.close() - } } diff --git a/scio-parquet/src/test/scala/com/spotify/scio/parquet/avro/ParquetAvroIOTest.scala b/scio-parquet/src/test/scala/com/spotify/scio/parquet/avro/ParquetAvroIOTest.scala index b051740aab..ad963506df 100644 --- a/scio-parquet/src/test/scala/com/spotify/scio/parquet/avro/ParquetAvroIOTest.scala +++ b/scio-parquet/src/test/scala/com/spotify/scio/parquet/avro/ParquetAvroIOTest.scala @@ -22,6 +22,7 @@ import com.spotify.scio._ import com.spotify.scio.avro._ import com.spotify.scio.io.{ClosedTap, FileNamePolicySpec, ScioIOTest, TapSpec, TextIO} import com.spotify.scio.parquet.ParquetConfiguration +import com.spotify.scio.parquet.read.ParquetReadConfiguration import com.spotify.scio.testing._ import com.spotify.scio.util.FilenamePolicySupplier import com.spotify.scio.values.{SCollection, WindowOptions} @@ -33,9 +34,11 @@ import org.apache.beam.sdk.Pipeline.PipelineExecutionException import org.apache.beam.sdk.options.PipelineOptionsFactory import org.apache.beam.sdk.transforms.windowing.{BoundedWindow, IntervalWindow, PaneInfo} import org.apache.commons.io.FileUtils -import org.apache.parquet.avro.{AvroDataSupplier, AvroReadSupport, AvroWriteSupport} +import org.apache.hadoop.conf.Configuration +import org.apache.parquet.avro._ import org.joda.time.{DateTime, DateTimeFieldType, Duration, Instant} import org.scalatest.BeforeAndAfterAll +import org.scalatest.prop.TableDrivenPropertyChecks.{forAll => forAllCases, Table} import java.lang import java.nio.file.Files @@ -79,208 +82,278 @@ class ParquetAvroIOTest extends ScioIOSpec with TapSpec with BeforeAndAfterAll { override protected def afterAll(): Unit = FileUtils.deleteDirectory(testDir) - "ParquetAvroIO" should "work with specific records" in { + private def createConfig(splittable: Boolean): Configuration = { + val c = ParquetConfiguration.empty() + c.set(ParquetReadConfiguration.UseSplittableDoFn, splittable.toString) + c + } + + private val readConfigs = + Table( + ("config", "description"), + (() => createConfig(false), "legacy read"), + (() => createConfig(true), "splittable"), + (() => ParquetAvroIO.ReadParam.DefaultConfiguration, "default") + ) + + it should "work with specific records" in { val xs = (1 to 100).map(AvroUtils.newSpecificRecord) - testTap(xs)(_.saveAsParquetAvroFile(_))(".parquet") - testJobTest(xs)(ParquetAvroIO(_))( - _.parquetAvroFile[TestRecord](_).map(identity) - )(_.saveAsParquetAvroFile(_)) + + forAllCases(readConfigs) { case (c, _) => + testTap(xs)(_.saveAsParquetAvroFile(_))(".parquet") + testJobTest(xs)(ParquetAvroIO(_))( + _.parquetAvroFile[TestRecord](_, conf = c()).map(identity) + )(_.saveAsParquetAvroFile(_)) + } } it should "read specific records with projection" in { - val sc = ScioContext() - val projection = Projection[TestRecord](_.getIntField) - val data = sc.parquetAvroFile[TestRecord]( - path = testDir.getAbsolutePath, - projection = projection, - suffix = ".parquet" - ) - data.map(_.getIntField.toInt) should containInAnyOrder(1 to 10) - data.map(identity) should forAll[TestRecord] { r => - r.getLongField == null && r.getFloatField == null && r.getDoubleField == null && - r.getBooleanField == null && r.getStringField == null && r.getArrayField - .size() == 0 + forAllCases(readConfigs) { case (c, x) => + val sc = ScioContext() + val projection = Projection[TestRecord](_.getIntField) + val data = sc.parquetAvroFile[TestRecord]( + path = testDir.getAbsolutePath, + projection = projection, + suffix = ".parquet", + conf = c() + ) + + data.map(_.getIntField.toInt).debug() should containInAnyOrder(1 to 10) + data.map(identity) should forAll[TestRecord] { r => + r.getLongField == null && r.getFloatField == null && r.getDoubleField == null && + r.getBooleanField == null && r.getStringField == null && r.getArrayField + .size() == 0 + } + sc.run() } - sc.run() } it should "read specific records with predicate" in { - val sc = ScioContext() - val predicate = Predicate[TestRecord](_.getIntField <= 5) - val data = sc.parquetAvroFile[TestRecord]( - path = testDir.getAbsolutePath, - predicate = predicate, - suffix = ".parquet" - ) - val expected = specificRecords.filter(_.getIntField <= 5) - data.map(identity) should containInAnyOrder(expected) - sc.run() + forAllCases(readConfigs) { case (c, _) => + val sc = ScioContext() + val predicate = Predicate[TestRecord](_.getIntField <= 5) + val data = sc.parquetAvroFile[TestRecord]( + path = testDir.getAbsolutePath, + predicate = predicate, + suffix = ".parquet", + conf = c() + ) + val expected = specificRecords.filter(_.getIntField <= 5) + data.map(identity) should containInAnyOrder(expected) + sc.run() + } } it should "read specific records with projection and predicate" in { - val sc = ScioContext() - val projection = Projection[TestRecord](_.getIntField) - val predicate = Predicate[TestRecord](_.getIntField <= 5) - val data = sc.parquetAvroFile[TestRecord]( - path = testDir.getAbsolutePath, - projection = projection, - predicate = predicate, - suffix = ".parquet" - ) - data.map(_.getIntField.toInt) should containInAnyOrder(1 to 5) - data.map(identity) should forAll[TestRecord] { r => - r.getLongField == null && - r.getFloatField == null && - r.getDoubleField == null && - r.getBooleanField == null && - r.getStringField == null && - r.getArrayField.size() == 0 + forAllCases(readConfigs) { case (c, _) => + val sc = ScioContext() + val projection = Projection[TestRecord](_.getIntField) + val predicate = Predicate[TestRecord](_.getIntField <= 5) + val data = sc.parquetAvroFile[TestRecord]( + path = testDir.getAbsolutePath, + projection = projection, + predicate = predicate, + suffix = ".parquet", + conf = c() + ) + data.map(_.getIntField.toInt) should containInAnyOrder(1 to 5) + data.map(identity) should forAll[TestRecord] { r => + r.getLongField == null && + r.getFloatField == null && + r.getDoubleField == null && + r.getBooleanField == null && + r.getStringField == null && + r.getArrayField.size() == 0 + } + sc.run() } - sc.run() } it should "write and read SpecificRecords with default logical types" in withTempDir { dir => - val records = (1 to 10).map(_ => - TestLogicalTypes - .newBuilder() - .setTimestamp(DateTime.now()) - .setDecimal(BigDecimal.decimal(1.0).setScale(2).bigDecimal) - .build() - ) - - val sc1 = ScioContext() - sc1 - .parallelize(records) - .saveAsParquetAvroFile(path = dir.getAbsolutePath) - sc1.run() - - val sc2 = ScioContext() - sc2 - .parquetAvroFile[TestLogicalTypes]( - path = dir.getAbsolutePath, - suffix = ".parquet" + forAllCases(readConfigs) { case (readConf, testCase) => + val testCaseDir = new File(dir, testCase) + val records = (1 to 10).map(_ => + TestLogicalTypes + .newBuilder() + .setTimestamp(DateTime.now()) + .setDecimal(BigDecimal.decimal(1.0).setScale(2).bigDecimal) + .build() ) - .map(identity) should containInAnyOrder(records) - sc2.run() - } + val sc1 = ScioContext() + sc1 + .parallelize(records) + .saveAsParquetAvroFile(path = testCaseDir.getAbsolutePath) + sc1.run() + + val sc2 = ScioContext() + sc2 + .parquetAvroFile[TestLogicalTypes]( + path = testCaseDir.getAbsolutePath, + conf = readConf(), + suffix = ".parquet" + ) + .map(identity) + .debug() should containInAnyOrder(records) - it should "write and read GenericRecords with default logical types" in withTempDir { dir => - val records: Seq[GenericRecord] = (1 to 10).map { _ => - val gr = new GenericRecordBuilder(TestLogicalTypes.SCHEMA$) - gr.set("timestamp", DateTime.now()) - gr.set( - "decimal", - BigDecimal.decimal(1.0).setScale(2).bigDecimal - ) - gr.build() + sc2.run() } + } - implicit val coder = { - GenericData.get().addLogicalTypeConversion(new TimeConversions.TimestampConversion) - GenericData.get().addLogicalTypeConversion(new Conversions.DecimalConversion) - avroGenericRecordCoder(TestLogicalTypes.SCHEMA$) - } + it should "write and read GenericRecords with default logical types" in withTempDir { dir => + forAllCases(readConfigs) { case (readConf, testCase) => + val testCaseDir = new File(dir, testCase) + val records: Seq[GenericRecord] = (1 to 10).map { _ => + val gr = new GenericRecordBuilder(TestLogicalTypes.SCHEMA$) + gr.set("timestamp", DateTime.now()) + gr.set( + "decimal", + BigDecimal.decimal(1.0).setScale(2).bigDecimal + ) + gr.build() + } - val sc1 = ScioContext() - sc1 - .parallelize(records) - .saveAsParquetAvroFile( - path = dir.getAbsolutePath, - schema = TestLogicalTypes.SCHEMA$ - ) - sc1.run() + implicit val coder = { + GenericData.get().addLogicalTypeConversion(new TimeConversions.TimestampConversion) + GenericData.get().addLogicalTypeConversion(new Conversions.DecimalConversion) + avroGenericRecordCoder(TestLogicalTypes.SCHEMA$) + } - val sc2 = ScioContext() - sc2 - .parquetAvroFile[GenericRecord]( - path = dir.getAbsolutePath, - projection = TestLogicalTypes.SCHEMA$, - suffix = ".parquet" - ) - .map(identity) should containInAnyOrder(records) + val sc1 = ScioContext() + sc1 + .parallelize(records) + .saveAsParquetAvroFile( + path = testCaseDir.getAbsolutePath, + schema = TestLogicalTypes.SCHEMA$ + ) + sc1.run() + + val sc2 = ScioContext() + sc2 + .parquetAvroFile[GenericRecord]( + path = testCaseDir.getAbsolutePath, + projection = TestLogicalTypes.SCHEMA$, + conf = readConf(), + suffix = ".parquet" + ) + .map(identity) should containInAnyOrder(records) - sc2.run() + sc2.run() + } } it should "write and read SpecificRecords with custom logical types" in withTempDir { dir => - val records = - (1 to 10).map(_ => - TestLogicalTypes - .newBuilder() - .setTimestamp(DateTime.now()) - .setDecimal(BigDecimal.decimal(1.0).setScale(2).bigDecimal) - .build() - ) - - val sc1 = ScioContext() - sc1 - .parallelize(records) - .saveAsParquetAvroFile( - path = dir.getAbsolutePath, - conf = ParquetConfiguration.of( - AvroWriteSupport.AVRO_DATA_SUPPLIER -> classOf[CustomLogicalTypeSupplier] + forAllCases(readConfigs) { case (readConf, testCase) => + val testCaseDir = new File(dir, testCase) + val records = + (1 to 10).map(_ => + TestLogicalTypes + .newBuilder() + .setTimestamp(DateTime.now()) + .setDecimal(BigDecimal.decimal(1.0).setScale(2).bigDecimal) + .build() ) - ) - sc1.run() - val sc2 = ScioContext() - sc2 - .parquetAvroFile[TestLogicalTypes]( - path = dir.getAbsolutePath, - conf = ParquetConfiguration.of( - AvroReadSupport.AVRO_DATA_SUPPLIER -> classOf[CustomLogicalTypeSupplier] - ), - suffix = ".parquet" - ) - .map(identity) should containInAnyOrder(records) + val sc1 = ScioContext() + sc1 + .parallelize(records) + .saveAsParquetAvroFile( + path = testCaseDir.getAbsolutePath + ) + sc1.run() + + val sc2 = ScioContext() + sc2 + .parquetAvroFile[TestLogicalTypes]( + path = testCaseDir.getAbsolutePath, + conf = readConf(), + suffix = ".parquet" + ) + .map(identity) should containInAnyOrder(records) - sc2.run() - () + sc2.run() + () + } } it should "read with incomplete projection" in withTempDir { dir => - val sc1 = ScioContext() - val nestedRecords = - (1 to 10).map(x => new Account(x, x.toString, x.toString, x.toDouble, AccountStatus.Active)) - sc1 - .parallelize(nestedRecords) - .saveAsParquetAvroFile(dir.getAbsolutePath) - sc1.run() - - val sc2 = ScioContext() - val projection = Projection[Account](_.getName) - val data = sc2.parquetAvroFile[Account]( - path = dir.getAbsolutePath, - projection = projection, - suffix = ".parquet" - ) - val expected = nestedRecords.map(_.getName.toString) - data.map(_.getName.toString) should containInAnyOrder(expected) - data.flatMap(a => Some(a.getName.toString)) should containInAnyOrder(expected) - sc2.run() + forAllCases(readConfigs) { case (readConf, testCase) => + val testCaseDir = new File(dir, testCase) + val sc1 = ScioContext() + val nestedRecords = + (1 to 10).map(x => new Account(x, x.toString, x.toString, x.toDouble, AccountStatus.Active)) + sc1 + .parallelize(nestedRecords) + .saveAsParquetAvroFile(testCaseDir.getAbsolutePath) + sc1.run() + + val sc2 = ScioContext() + val projection = Projection[Account](_.getName) + val data = sc2.parquetAvroFile[Account]( + path = testCaseDir.getAbsolutePath, + projection = projection, + conf = readConf(), + suffix = ".parquet" + ) + val expected = nestedRecords.map(_.getName.toString) + data.map(_.getName.toString) should containInAnyOrder(expected) + data.flatMap(a => Some(a.getName.toString)) should containInAnyOrder(expected) + sc2.run() + } } it should "read/write generic records" in withTempDir { dir => - val genericRecords = (1 to 100).map(AvroUtils.newGenericRecord) - val sc1 = ScioContext() - implicit val coder = avroGenericRecordCoder(AvroUtils.schema) - sc1 - .parallelize(genericRecords) - .saveAsParquetAvroFile(dir.getAbsolutePath, numShards = 1, schema = AvroUtils.schema) - sc1.run() + forAllCases(readConfigs) { case (readConf, testCase) => + val testCaseDir = new File(dir, testCase) + val genericRecords = (1 to 100).map(AvroUtils.newGenericRecord) + val sc1 = ScioContext() + implicit val coder = avroGenericRecordCoder(AvroUtils.schema) + sc1 + .parallelize(genericRecords) + .saveAsParquetAvroFile( + testCaseDir.getAbsolutePath, + numShards = 1, + schema = AvroUtils.schema + ) + sc1.run() - val files = dir.listFiles() - files.length shouldBe 1 + val files = testCaseDir.listFiles() + files.map(_.isDirectory).length shouldBe 1 + + val sc2 = ScioContext() + val data: SCollection[GenericRecord] = sc2.parquetAvroFile[GenericRecord]( + path = testCaseDir.getAbsolutePath, + projection = AvroUtils.schema, + conf = readConf(), + suffix = ".parquet" + ) + data should containInAnyOrder(genericRecords) + sc2.run() + } + } - val sc2 = ScioContext() - val data: SCollection[GenericRecord] = sc2.parquetAvroFile[GenericRecord]( - path = dir.getAbsolutePath, - projection = AvroUtils.schema, - suffix = ".parquet" + class TestRecordProjection private (str: String) {} + + "tap" should "use projection schema and GenericDataSupplier" in { + val schema = new Schema.Parser().parse( + """ + |{ + |"type":"record", + |"name":"TestRecordProjection", + |"namespace":"com.spotify.scio.parquet.avro.ParquetAvroIOTest$", + |"fields":[{"name":"int_field","type":["null", "int"]}]} + |""".stripMargin ) - data should containInAnyOrder(genericRecords) - sc2.run() + + implicit val coder = avroGenericRecordCoder(schema) + + ParquetAvroTap( + s"${testDir.toPath}", + ParquetAvroIO.ReadParam(identity[GenericRecord], schema, suffix = "*.parquet") + ).value.foreach { gr => + gr.get("int_field") should not be null + gr.get("string_field") should be(null) + } } it should "write windowed generic records to dynamic destinations" in withTempDir { dir => @@ -371,7 +444,9 @@ class ParquetAvroIOTest extends ScioIOSpec with TapSpec with BeforeAndAfterAll { AvroUtils.schema, null ) + val tap = ParquetAvroTap(files.head.getAbsolutePath, params) + tap.value.toList should contain theSameElementsAs genericRecords } diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java index 159503b3d4..dc9a6361f5 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/AvroSortedBucketIO.java @@ -247,21 +247,24 @@ public Read withPredicate(Predicate predicate) { return toBuilder().setPredicate(predicate).build(); } + @SuppressWarnings("unchecked") + @Override + FileOperations getFileOperations() { + return getRecordClass() == null + ? (AvroFileOperations) AvroFileOperations.of(getSchema()) + : (AvroFileOperations) + AvroFileOperations.of((Class) getRecordClass()); + } + @Override public SortedBucketSource.BucketedInput toBucketedInput( final SortedBucketSource.Keying keying) { - @SuppressWarnings("unchecked") - final AvroFileOperations fileOperations = - getRecordClass() == null - ? (AvroFileOperations) AvroFileOperations.of(getSchema()) - : (AvroFileOperations) - AvroFileOperations.of((Class) getRecordClass()); return SortedBucketSource.BucketedInput.of( keying, getTupleTag(), getInputDirectories(), getFilenameSuffix(), - fileOperations, + getFileOperations(), getPredicate()); } } diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java index c196446daf..1fc3e4c537 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/JsonSortedBucketIO.java @@ -156,6 +156,11 @@ public Read withPredicate(Predicate predicate) { return toBuilder().setPredicate(predicate).build(); } + @Override + FileOperations getFileOperations() { + return JsonFileOperations.of(getCompression()); + } + @Override public BucketedInput toBucketedInput(final SortedBucketSource.Keying keying) { return BucketedInput.of( @@ -163,7 +168,7 @@ public BucketedInput toBucketedInput(final SortedBucketSource.Keying k getTupleTag(), getInputDirectories(), getFilenameSuffix(), - JsonFileOperations.of(getCompression()), + getFileOperations(), getPredicate()); } } diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperations.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperations.java index 650e0aca64..a86e6bd18c 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperations.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperations.java @@ -35,12 +35,7 @@ import org.apache.beam.sdk.util.MimeTypes; import org.apache.hadoop.conf.Configuration; import org.apache.hadoop.util.ReflectionUtils; -import org.apache.parquet.avro.AvroDataSupplier; -import org.apache.parquet.avro.AvroParquetReader; -import org.apache.parquet.avro.AvroParquetWriter; -import org.apache.parquet.avro.AvroReadSupport; -import org.apache.parquet.avro.AvroWriteSupport; -import org.apache.parquet.avro.SpecificDataSupplier; +import org.apache.parquet.avro.*; import org.apache.parquet.filter2.compat.FilterCompat; import org.apache.parquet.filter2.predicate.FilterPredicate; import org.apache.parquet.hadoop.ParquetReader; @@ -109,7 +104,7 @@ public void populateDisplayData(DisplayData.Builder builder) { @Override protected Reader createReader() { - return new ParquetAvroReader<>(schemaSupplier, projectionSupplier, conf, predicate); + return new ParquetAvroReader<>(schemaSupplier, projectionSupplier, conf, predicate, recordClass); } @Override @@ -139,6 +134,7 @@ private static class ParquetAvroReader extends FileOperations.Reader recordClass; private transient ParquetReader reader; private transient ValueT current; @@ -146,11 +142,13 @@ private ParquetAvroReader( SerializableSchemaSupplier readSchemaSupplier, SerializableSchemaSupplier projectionSchemaSupplier, SerializableConfiguration conf, - FilterPredicate predicate) { + FilterPredicate predicate, + Class recordClass) { this.readSchemaSupplier = readSchemaSupplier; this.projectionSchemaSupplier = projectionSchemaSupplier; this.conf = conf; this.predicate = predicate; + this.recordClass = recordClass; } @Override @@ -163,6 +161,14 @@ public void prepareRead(ReadableByteChannel channel) throws IOException { AvroReadSupport.setRequestedProjection(configuration, projectionSchemaSupplier.get()); } + if (recordClass == null && configuration.get(AvroReadSupport.AVRO_DATA_SUPPLIER) == null) { + configuration.setClass( + AvroReadSupport.AVRO_DATA_SUPPLIER, + GenericDataSupplier.class, + AvroDataSupplier.class + ); + } + ParquetReader.Builder builder = AvroParquetReader.builder(new ParquetInputFile(channel)).withConf(configuration); if (predicate != null) { diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java index 14e752a9aa..d05340275d 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/ParquetAvroSortedBucketIO.java @@ -285,24 +285,26 @@ public Read withProjection(Schema projection) { @Override public BucketedInput toBucketedInput(final SortedBucketSource.Keying keying) { + return BucketedInput.of( + keying, + getTupleTag(), + getInputDirectories(), + getFilenameSuffix(), + getFileOperations(), + getPredicate()); + } + + @Override + FileOperations getFileOperations() { ParquetAvroFileOperations fileOperations = getRecordClass() == null ? (ParquetAvroFileOperations) ParquetAvroFileOperations.of(getSchema()) : ParquetAvroFileOperations.of(getRecordClass()); - fileOperations = - fileOperations + return fileOperations .withFilterPredicate(getFilterPredicate()) .withProjection(getProjection()) .withConfiguration(getConfiguration()); - - return BucketedInput.of( - keying, - getTupleTag(), - getInputDirectories(), - getFilenameSuffix(), - fileOperations, - getPredicate()); } } @@ -562,7 +564,9 @@ FileOperations getFileOperations() { ? (ParquetAvroFileOperations) ParquetAvroFileOperations.of(getSchema()) : ParquetAvroFileOperations.of(getRecordClass()); - return fileOperations.withConfiguration(getConfiguration()); + return fileOperations + .withConfiguration(getConfiguration()) + .withCompression(getCompression()); } @Override diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java index af04b9f58d..dd95956d9b 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/SortedBucketIO.java @@ -499,6 +499,8 @@ public abstract static class TransformOutput implements Serializable public abstract static class Read implements Serializable { public abstract TupleTag getTupleTag(); + abstract FileOperations getFileOperations(); + public abstract BucketedInput toBucketedInput(SortedBucketSource.Keying keying); } diff --git a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java index 3f3b5d8d64..8c9b66a698 100644 --- a/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java +++ b/scio-smb/src/main/java/org/apache/beam/sdk/extensions/smb/TensorFlowBucketIO.java @@ -169,6 +169,11 @@ public Read withPredicate(Predicate predicate) { return toBuilder().setPredicate(predicate).build(); } + @Override + FileOperations getFileOperations() { + return TensorFlowFileOperations.of(getCompression()); + } + @Override public BucketedInput toBucketedInput(final SortedBucketSource.Keying keying) { return BucketedInput.of( @@ -176,7 +181,7 @@ public BucketedInput toBucketedInput(final SortedBucketSource.Keying ke getTupleTag(), getInputDirectories(), getFilenameSuffix(), - TensorFlowFileOperations.of(getCompression()), + getFileOperations(), getPredicate()); } } diff --git a/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala b/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala index bc9b512db7..5a3f59068c 100644 --- a/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala +++ b/scio-smb/src/main/scala/org/apache/beam/sdk/extensions/smb/ParquetTypeSortedBucketIO.scala @@ -90,16 +90,18 @@ object ParquetTypeSortedBucketIO { override def toBucketedInput( keying: SortedBucketSource.Keying ): SortedBucketSource.BucketedInput[T] = { - val fileOperations = ParquetTypeFileOperations[T](filterPredicate, configuration) BucketedInput.of( keying, getTupleTag, inputDirectories.asJava, filenameSuffix, - fileOperations, + getFileOperations, predicate ) } + + override def getFileOperations: FileOperations[T] = + ParquetTypeFileOperations[T](filterPredicate, configuration) } case class Write[K1: ClassTag, K2: ClassTag, T: ClassTag: Coder: ParquetType]( diff --git a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperationsTest.java b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperationsTest.java index 439d33929c..644e3739d5 100644 --- a/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperationsTest.java +++ b/scio-smb/src/test/java/org/apache/beam/sdk/extensions/smb/ParquetAvroFileOperationsTest.java @@ -55,7 +55,8 @@ public class ParquetAvroFileOperationsTest { private static final Schema USER_SCHEMA = SchemaBuilder.record("User") - .namespace("org.apache.beam.sdk.extensions.smb.avro") + // intentionally set this namespace for testGenericRecord + .namespace("org.apache.beam.sdk.extensions.smb.ParquetAvroFileOperationsTest$") .fields() .name("name") .type() @@ -77,6 +78,12 @@ public class ParquetAvroFileOperationsTest { .build()) .collect(Collectors.toList()); + // Intentionally avoid no-arg ctor to verify this class is not attempted to instantiate + static class User { + User(String str) { + } + } + @Test public void testGenericRecord() throws Exception { final ResourceId file =