diff --git a/.github/workflows/migration.yml b/.github/workflows/migration.yml index 0b7decf899..9e404d47fc 100644 --- a/.github/workflows/migration.yml +++ b/.github/workflows/migration.yml @@ -23,7 +23,7 @@ jobs: # - scio-extra/compile # - scio-extra/test - scio-core/compile - # - scio-core/test + - scio-core/test # - scio-examples/compile # - scio-examples/test # - scio-redis/compile @@ -54,5 +54,7 @@ jobs: uses: coursier/cache-action@v5 - name: java 8 setup uses: olafurpg/setup-scala@v10 - - name: Compile - run: sbt "++3.0.0-RC2;${{ matrix.task }}" +# - name: Scala 2 +# run: sbt "++2.13.5;${{ matrix.task }}" + - name: Scala 3 + run: sbt "++3.0.0;${{ matrix.task }}" diff --git a/build.sbt b/build.sbt index ea81fb5a7e..2cb6b8c86a 100644 --- a/build.sbt +++ b/build.sbt @@ -25,7 +25,7 @@ import de.heikoseeberger.sbtheader.CommentCreator ThisBuild / turbo := true -val scala3Version = "3.0.0-RC2" +val scala3Version = "3.0.0" val algebirdVersion = "0.13.7" val algebraVersion = "2.2.2" val annoy4sVersion = "0.10.0" @@ -97,7 +97,7 @@ val protobufVersion = "3.15.8" val scalacheckVersion = "1.15.4" val scalaMacrosVersion = "2.1.1" val scalatestplusVersion = "3.1.0.0-RC2" -val scalatestVersion = "3.2.8" +val scalatestVersion = "3.2.9" val shapelessVersion = "2.3.4" val slf4jVersion = "1.7.30" val sparkeyVersion = "3.2.1" @@ -162,12 +162,12 @@ val commonSettings = Def Seq(Tests.Argument(TestFrameworks.ScalaTest, "-l", "org.scalatest.tags.Slow")) } }, - coverageExcludedPackages := (Seq( - "com\\.spotify\\.scio\\.examples\\..*", - "com\\.spotify\\.scio\\.repl\\..*", - "com\\.spotify\\.scio\\.util\\.MultiJoin" - ) ++ (2 to 10).map(x => s"com\\.spotify\\.scio\\.sql\\.Query${x}")).mkString(";"), - coverageHighlighting := true, + //coverageExcludedPackages := (Seq( + // "com\\.spotify\\.scio\\.examples\\..*", + // "com\\.spotify\\.scio\\.repl\\..*", + // "com\\.spotify\\.scio\\.util\\.MultiJoin" + //) ++ (2 to 10).map(x => s"com\\.spotify\\.scio\\.sql\\.Query${x}")).mkString(";"), + //coverageHighlighting := true, licenses := Seq("Apache 2" -> url("http://www.apache.org/licenses/LICENSE-2.0.txt")), homepage := Some(url("https://github.com/spotify/scio")), scmInfo := Some( @@ -584,7 +584,8 @@ lazy val `scio-macros`: Project = project libraryDependencies ++= Seq( "com.esotericsoftware" % "kryo-shaded" % kryoVersion, "org.apache.beam" % "beam-sdks-java-extensions-sql" % beamVersion, - "org.apache.avro" % "avro" % avroVersion + "org.apache.avro" % "avro" % avroVersion, + "org.scalatest" %% "scalatest" % scalatestVersion % Test ), // Scala 2 dependencies libraryDependencies ++= { diff --git a/project/build.properties b/project/build.properties index f0be67b9f7..67d27a1dfe 100644 --- a/project/build.properties +++ b/project/build.properties @@ -1 +1 @@ -sbt.version=1.5.1 +sbt.version=1.5.3 diff --git a/project/plugins.sbt b/project/plugins.sbt index 5ffb63776d..e86d18f395 100644 --- a/project/plugins.sbt +++ b/project/plugins.sbt @@ -3,7 +3,7 @@ addSbtPlugin("com.eed3si9n" % "sbt-assembly" % "0.14.10") addSbtPlugin("com.eed3si9n" % "sbt-unidoc" % "0.4.3") addSbtPlugin("com.github.sbt" % "sbt-protobuf" % "0.7.0") addSbtPlugin("com.typesafe.sbt" % "sbt-ghpages" % "0.6.3") -addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.7.2") +//addSbtPlugin("org.scoverage" % "sbt-scoverage" % "1.7.2") addSbtPlugin("pl.project13.scala" % "sbt-jmh" % "0.4.0") addSbtPlugin("com.typesafe" % "sbt-mima-plugin" % "0.8.1") addSbtPlugin("com.eed3si9n" % "sbt-buildinfo" % "0.10.0") diff --git a/scio-avro/src/main/scala/com/spotify/scio/avro/AvroIO.scala b/scio-avro/src/main/scala/com/spotify/scio/avro/AvroIO.scala index 617b9da018..3459d98a1b 100644 --- a/scio-avro/src/main/scala/com/spotify/scio/avro/AvroIO.scala +++ b/scio-avro/src/main/scala/com/spotify/scio/avro/AvroIO.scala @@ -250,7 +250,7 @@ object AvroIO { private[avro] val DefaultMetadata: Map[String, AnyRef] = Map.empty } - final case class WriteParam private ( + final case class WriteParam private[avro] ( numShards: Int = WriteParam.DefaultNumShards, private val _suffix: String = WriteParam.DefaultSuffix, codec: CodecFactory = WriteParam.DefaultCodec, diff --git a/scio-avro/src/test/scala/com/spotify/scio/avro/types/ConverterProviderSpec.scala b/scio-avro/src/test/scala/com/spotify/scio/avro/types/ConverterProviderSpec.scala index 7295d15a22..7e00879e9b 100644 --- a/scio-avro/src/test/scala/com/spotify/scio/avro/types/ConverterProviderSpec.scala +++ b/scio-avro/src/test/scala/com/spotify/scio/avro/types/ConverterProviderSpec.scala @@ -20,7 +20,6 @@ package com.spotify.scio.avro.types import cats.Eq import cats.instances.all._ import com.google.protobuf.ByteString -import magnolify.cats.semiauto.EqDerivation import magnolify.scalacheck.auto._ import org.scalacheck._ import org.scalatest.propspec.AnyPropSpec @@ -41,22 +40,22 @@ class ConverterProviderSpec extends AnyPropSpec with ScalaCheckDrivenPropertyChe implicit val eqByteString: Eq[ByteString] = Eq.instance[ByteString](_ == _) property("round trip basic primitive types") { - forAll { r1: BasicFields => + forAll { (r1: BasicFields) => val r2 = AvroType.fromGenericRecord[BasicFields](AvroType.toGenericRecord[BasicFields](r1)) - EqDerivation[BasicFields].eqv(r1, r2) shouldBe true + EqGen.of[BasicFields].eqv(r1, r2) shouldBe true } } property("round trip optional primitive types") { - forAll { r1: OptionalFields => + forAll { (r1: OptionalFields) => val r2 = AvroType.fromGenericRecord[OptionalFields](AvroType.toGenericRecord[OptionalFields](r1)) - EqDerivation[OptionalFields].eqv(r1, r2) shouldBe true + EqGen.of[OptionalFields].eqv(r1, r2) shouldBe true } } property("skip null optional primitive types") { - forAll { o: OptionalFields => + forAll { (o: OptionalFields) => val r = AvroType.toGenericRecord[OptionalFields](o) // GenericRecord object should only contain a key if the corresponding Option[T] is defined o.boolF.isDefined shouldBe (r.get("boolF") != null) @@ -70,37 +69,37 @@ class ConverterProviderSpec extends AnyPropSpec with ScalaCheckDrivenPropertyChe } property("round trip primitive type arrays") { - forAll { r1: ArrayFields => + forAll { (r1: ArrayFields) => val r2 = AvroType.fromGenericRecord[ArrayFields](AvroType.toGenericRecord[ArrayFields](r1)) - EqDerivation[ArrayFields].eqv(r1, r2) shouldBe true + EqGen.of[ArrayFields].eqv(r1, r2) shouldBe true } } property("round trip primitive type maps") { - forAll { r1: MapFields => + forAll { (r1: MapFields) => val r2 = AvroType.fromGenericRecord[MapFields](AvroType.toGenericRecord[MapFields](r1)) - EqDerivation[MapFields].eqv(r1, r2) shouldBe true + EqGen.of[MapFields].eqv(r1, r2) shouldBe true } } property("round trip required nested types") { - forAll { r1: NestedFields => + forAll { (r1: NestedFields) => val r2 = AvroType.fromGenericRecord[NestedFields](AvroType.toGenericRecord[NestedFields](r1)) - EqDerivation[NestedFields].eqv(r1, r2) shouldBe true + EqGen.of[NestedFields].eqv(r1, r2) shouldBe true } } property("round trip optional nested types") { - forAll { r1: OptionalNestedFields => + forAll { (r1: OptionalNestedFields) => val r2 = AvroType.fromGenericRecord[OptionalNestedFields]( AvroType.toGenericRecord[OptionalNestedFields](r1) ) - EqDerivation[OptionalNestedFields].eqv(r1, r2) shouldBe true + EqGen.of[OptionalNestedFields].eqv(r1, r2) shouldBe true } } property("skip null optional nested types") { - forAll { o: OptionalNestedFields => + forAll { (o: OptionalNestedFields) => val r = AvroType.toGenericRecord[OptionalNestedFields](o) // TableRow object should only contain a key if the corresponding Option[T] is defined o.basic.isDefined shouldBe (r.get("basic") != null) @@ -111,28 +110,28 @@ class ConverterProviderSpec extends AnyPropSpec with ScalaCheckDrivenPropertyChe } property("round trip nested type arrays") { - forAll { r1: ArrayNestedFields => + forAll { (r1: ArrayNestedFields) => val r2 = AvroType.fromGenericRecord[ArrayNestedFields]( AvroType.toGenericRecord[ArrayNestedFields](r1) ) - EqDerivation[ArrayNestedFields].eqv(r1, r2) shouldBe true + EqGen.of[ArrayNestedFields].eqv(r1, r2) shouldBe true } } // FIXME: can't derive Eq for this // property("round trip nested type maps") { -// forAll { r1: MapNestedFields => +// forAll { (r1: MapNestedFields) => // val r2 = // AvroType.fromGenericRecord[MapNestedFields](AvroType.toGenericRecord[MapNestedFields](r1)) -// EqDerivation[MapNestedFields].eqv(r1, r2) shouldBe true +// EqGen.of[MapNestedFields].eqv(r1, r2) shouldBe true // } // } property("round trip byte array types") { - forAll { r1: ByteArrayFields => + forAll { (r1: ByteArrayFields) => val r2 = AvroType.fromGenericRecord[ByteArrayFields](AvroType.toGenericRecord[ByteArrayFields](r1)) - EqDerivation[ByteArrayFields].eqv(r1, r2) shouldBe true + EqGen.of[ByteArrayFields].eqv(r1, r2) shouldBe true } } } diff --git a/scio-core/src/main/scala-3/com/spotify/scio/coders/instances/kryo/JTraversableSerializer.scala b/scio-core/src/main/scala-3/com/spotify/scio/coders/instances/kryo/JTraversableSerializer.scala index e7fc5c1ef3..d89bbc7281 100644 --- a/scio-core/src/main/scala-3/com/spotify/scio/coders/instances/kryo/JTraversableSerializer.scala +++ b/scio-core/src/main/scala-3/com/spotify/scio/coders/instances/kryo/JTraversableSerializer.scala @@ -73,7 +73,7 @@ abstract private[coders] class JWrapperCBF[T] extends Factory[T, Iterable[T]] { override def fromSpecific(it: IterableOnce[T]): Iterable[T] = { val b = new JIterableWrapperBuilder - it.foreach(b += _) + it.iterator.foreach(b += _) b.result() } diff --git a/scio-core/src/main/scala-3/com/spotify/scio/schemas/To.scala b/scio-core/src/main/scala-3/com/spotify/scio/schemas/To.scala index 1809cef330..db5e57b263 100644 --- a/scio-core/src/main/scala-3/com/spotify/scio/schemas/To.scala +++ b/scio-core/src/main/scala-3/com/spotify/scio/schemas/To.scala @@ -18,15 +18,165 @@ package com.spotify.scio.schemas import org.apache.beam.sdk.schemas.{SchemaCoder, Schema => BSchema} +import BSchema.{ FieldType => BFieldType } import scala.compiletime._ import scala.deriving._ import scala.quoted._ +import scala.reflect.ClassTag +import scala.collection.mutable +import com.spotify.scio.IsJavaBean.checkGetterAndSetters object ToMacro { - def safeImpl[I, O](si: Expr[Schema[I]])(implicit q: Quotes): Expr[To[I, O]] = { - ??? + + def safeImpl[I: scala.quoted.Type, O: scala.quoted.Type]( + iSchema: Expr[Schema[I]], + oSchema: Expr[Schema[O]] + )(using Quotes): Expr[To[I, O]] = { + import scala.quoted.quotes.reflect.{report, TypeRepr} + + (interpret[I] , interpret[O]) match { + case (None, None) => report.throwError( + s""" + |Could not interpret input schema: + | ${iSchema.show} + |Could not interpret output schema: + | ${oSchema.show} + |""".stripMargin + ) + case (None, _) => report.throwError("Could not interpret input schema: " + iSchema.show) + case (_, None) => report.throwError("Could not interpret output schema: " + oSchema.show) + case (Some(sIn), Some(sOut)) => + val schemaOut: BSchema = SchemaMaterializer.fieldType(sOut).getRowSchema() + val schemaIn: BSchema = SchemaMaterializer.fieldType(sIn).getRowSchema() + val classTagOpt = Expr.summon[ClassTag[O]] + if (classTagOpt.isEmpty) { + report.throwError(s"Could not summon Expr[ClassTag[${TypeRepr.of[O].show}]]") + } + val classTag = classTagOpt.get + To.checkCompatibility(schemaIn, schemaOut)('{ To.unchecked[I, O](using $iSchema, $oSchema, $classTag) }) + .fold(message => report.throwError(message), identity) + } } + + private def sequence[T](ls: List[Option[T]]): Option[List[T]] = + if ls.exists(_.isEmpty) then None + else Some(ls.collect { case Some(x) => x }) + + private def interpret[T: scala.quoted.Type](using Quotes): Option[Schema[T]] = + Type.of[T] match { + case '[java.lang.Byte] => Some(Schema.jByteSchema.asInstanceOf[Schema[T]]) + case '[Array[java.lang.Byte]] => Some(Schema.jBytesSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Short] => Some(Schema.jShortSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Integer] => Some(Schema.jIntegerSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Long] => Some(Schema.jLongSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Float] => Some(Schema.jFloatSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Double] => Some(Schema.jDoubleSchema.asInstanceOf[Schema[T]]) + case '[java.math.BigDecimal] => Some(Schema.jBigDecimalSchema.asInstanceOf[Schema[T]]) + case '[java.lang.Boolean] => Some(Schema.jBooleanSchema.asInstanceOf[Schema[T]]) + case '[java.util.List[u]] => + for (itemSchema) <- interpret[u] + yield Schema.jListSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[java.util.ArrayList[u]] => + for (itemSchema) <- interpret[u] + yield Schema.jArrayListSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[java.util.Map[k, v]] => + for { + keySchema <- interpret[k] + valueSchema <- interpret[v] + } yield Schema.jMapSchema(using keySchema, valueSchema).asInstanceOf[Schema[T]] + // TODO javaBeanSchema + // TODO javaEnumSchema + case '[java.time.LocalDate] => Some(Schema.jLocalDate.asInstanceOf[Schema[T]]) + + case '[String] => Some(Schema.stringSchema.asInstanceOf[Schema[T]]) + case '[Byte] => Some(Schema.byteSchema.asInstanceOf[Schema[T]]) + case '[Array[Byte]] => Some(Schema.bytesSchema.asInstanceOf[Schema[T]]) + case '[Short] => Some(Schema.sortSchema.asInstanceOf[Schema[T]]) + case '[Int] => Some(Schema.intSchema.asInstanceOf[Schema[T]]) + case '[Long] => Some(Schema.longSchema.asInstanceOf[Schema[T]]) + case '[Float] => Some(Schema.floatSchema.asInstanceOf[Schema[T]]) + case '[Double] => Some(Schema.doubleSchema.asInstanceOf[Schema[T]]) + case '[BigDecimal] => Some(Schema.bigDecimalSchema.asInstanceOf[Schema[T]]) + case '[Boolean] => Some(Schema.booleanSchema.asInstanceOf[Schema[T]]) + case '[Option[u]] => + for (itemSchema <- interpret[u]) + yield Schema.optionSchema(using itemSchema).asInstanceOf[Schema[T]] + // TODO Array[T] + case '[List[u]] => + for (itemSchema <- interpret[u]) + yield Schema.listSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[mutable.ArrayBuffer[u]] => + for (itemSchema <- interpret[u]) + yield Schema.arrayBufferSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[mutable.Buffer[u]] => + for (itemSchema <- interpret[u]) + yield Schema.bufferSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[mutable.Set[u]] => + for (itemSchema <- interpret[u]) + yield Schema.mutableSetSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[Set[u]] => + for (itemSchema <- interpret[u]) + yield Schema.setSchema(using itemSchema).asInstanceOf[Schema[T]] + // TODO SortedSet[T] + case '[mutable.ListBuffer[u]] => + for (itemSchema <- interpret[u]) + yield Schema.listBufferSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[Vector[u]] => + for (itemSchema <- interpret[u]) + yield Schema.vectorSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[mutable.Map[k, v]] => + for { + keySchema <- interpret[k] + valueSchema <- interpret[v] + } yield Schema.mutableMapSchema(using keySchema, valueSchema).asInstanceOf[Schema[T]] + case '[Map[k, v]] => + for { + keySchema <- interpret[k] + valueSchema <- interpret[v] + } yield Schema.mapSchema(using keySchema, valueSchema).asInstanceOf[Schema[T]] + case '[Seq[u]] => + for (itemSchema <- interpret[u]) + yield Schema.seqSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[TraversableOnce[u]] => + for (itemSchema <- interpret[u]) + yield Schema.traversableOnceSchema(using itemSchema).asInstanceOf[Schema[T]] + case '[Iterable[u]] => + for (itemSchema <- interpret[u]) + yield Schema.iterableSchema(using itemSchema).asInstanceOf[Schema[T]] + case _ => + import quotes.reflect._ + val tp = TypeRepr.of[T] + val tpSymbol: Symbol = tp.typeSymbol + + // if case class iterate and recurse, else sorry + if tp <:< TypeRepr.of[Product] && tpSymbol.caseFields.nonEmpty then { + val schemasOpt: List[Option[(String, Schema[Any])]] = tpSymbol.caseFields.map { (f: Symbol) => + assert(f.isValDef) + val fieldName = f.name + val fieldType: TypeRepr = tp.memberType(f) + fieldType.asType match { + // mattern match to create a bind <3 + case '[u] => interpret[u].asInstanceOf[Option[Schema[Any]]].map(s => (fieldName, s)) + } + } + sequence(schemasOpt).map(schemas => Record(schemas.toArray, null, null)) + } else if tpSymbol.flags.is(Flags.JavaDefined) && scala.util.Try(checkGetterAndSetters[T]).isSuccess then { + val schemasOpt = tpSymbol.declaredMethods.collect { + case s if s.name.toString.startsWith("get") && s.isDefDef=> + val fieldName: String = s.name.toString.drop(3) + val fieldType: TypeRepr = tp.memberType(s) + fieldType match { + case MethodType(_, _, returnTpt) => + returnTpt.asType match { + case '[u] => interpret[u].asInstanceOf[Option[Schema[Any]]].map(s => (fieldName, s)) + } + } + } + // RawRecord is used for JavaBeans, not Record + sequence(schemasOpt).map(schemas => Record(schemas.toArray, null, null)) + } else None + } } trait ToMacro { @@ -36,7 +186,6 @@ trait ToMacro { * at compile time. * @see To#unsafe */ - // TODO: scala3 - inline def safe[I, O](inline si: Schema[I], inline so: Schema[O]): To[I, O] = - ??? + inline def safe[I, O](using inline iSchema: Schema[I], inline oSchema: Schema[O]): To[I, O] = + ${ ToMacro.safeImpl('iSchema, 'oSchema) } } diff --git a/scio-core/src/main/scala/com/spotify/scio/hash/MutableScalableBloomFilter.scala b/scio-core/src/main/scala/com/spotify/scio/hash/MutableScalableBloomFilter.scala index 38bcc67617..622f78447f 100644 --- a/scio-core/src/main/scala/com/spotify/scio/hash/MutableScalableBloomFilter.scala +++ b/scio-core/src/main/scala/com/spotify/scio/hash/MutableScalableBloomFilter.scala @@ -225,7 +225,7 @@ case class MutableScalableBloomFilter[T]( } def ++=(items: TraversableOnce[T]): MutableScalableBloomFilter[T] = { - items.foreach(i => this += i) // no bulk insert for guava BFs + items.iterator.foreach(i => this += i) // no bulk insert for guava BFs this } } diff --git a/scio-core/src/main/scala/com/spotify/scio/schemas/instances/ScalaInstances.scala b/scio-core/src/main/scala/com/spotify/scio/schemas/instances/ScalaInstances.scala index b86bc18e28..1be46954f6 100644 --- a/scio-core/src/main/scala/com/spotify/scio/schemas/instances/ScalaInstances.scala +++ b/scio-core/src/main/scala/com/spotify/scio/schemas/instances/ScalaInstances.scala @@ -69,7 +69,7 @@ trait ScalaInstances { ArrayType(s, _.asJava, _.asScala.toList) implicit def traversableOnceSchema[T](implicit s: Schema[T]): Schema[TraversableOnce[T]] = - ArrayType(s, _.toList.asJava, _.asScala.toList) + ArrayType(s, _.iterator.to(List).asJava, _.asScala.toList) implicit def iterableSchema[T](implicit s: Schema[T]): Schema[Iterable[T]] = ArrayType(s, _.toList.asJava, _.asScala.toList) diff --git a/scio-core/src/main/scala/com/spotify/scio/transforms/ParallelismDoFns.scala b/scio-core/src/main/scala/com/spotify/scio/transforms/ParallelismDoFns.scala index eb212ed164..13c0507845 100644 --- a/scio-core/src/main/scala/com/spotify/scio/transforms/ParallelismDoFns.scala +++ b/scio-core/src/main/scala/com/spotify/scio/transforms/ParallelismDoFns.scala @@ -50,7 +50,7 @@ class ParallelFlatMapFn[T, U](parallelism: Int)(f: T => TraversableOnce[U]) extends ParallelLimitedFn[T, U](parallelism: Int) { val g: T => TraversableOnce[U] = ClosureCleaner.clean(f) // defeat closure def parallelProcessElement(c: DoFn[T, U]#ProcessContext): Unit = { - val i = g(c.element()).toIterator + val i = g(c.element()).iterator while (i.hasNext) c.output(i.next()) } } diff --git a/scio-core/src/test/scala/com/spotify/scio/IsJavaTest.scala b/scio-core/src/test/scala/com/spotify/scio/IsJavaTest.scala new file mode 100644 index 0000000000..8b0f53ac24 --- /dev/null +++ b/scio-core/src/test/scala/com/spotify/scio/IsJavaTest.scala @@ -0,0 +1,15 @@ +package com.spotify.scio + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +class IsJavaTest extends AnyFlatSpec with Matchers { + "IsJavaBean" should "succeed for a java bean" in { + println(IsJavaBean[JavaBeanA]) + } + + it should "not compile for a case class" in { + case class Foo(s: String, i: Int) + "IsJavaBean[Foo]" shouldNot compile + } +} diff --git a/scio-core/src/test/scala/com/spotify/scio/JavaBeanA.java b/scio-core/src/test/scala/com/spotify/scio/JavaBeanA.java new file mode 100644 index 0000000000..4976dc920b --- /dev/null +++ b/scio-core/src/test/scala/com/spotify/scio/JavaBeanA.java @@ -0,0 +1,28 @@ +package com.spotify.scio; + +class JavaBeanA implements java.io.Serializable { + private String firstName = null; + private String lastName = null; + private int age = 0; + + public JavaBeanA() { + } + public String getFirstName(){ + return firstName; + } + public String getLastName(){ + return lastName; + } + public int getAge(){ + return age; + } + public void setFirstName(String firstName){ + this.firstName = firstName; + } + public void setLastName(String lastName){ + this.lastName = lastName; + } + public void setAge(int age){ + this.age = age; + } +} \ No newline at end of file diff --git a/scio-core/src/test/scala/com/spotify/scio/JavaBeanB.java b/scio-core/src/test/scala/com/spotify/scio/JavaBeanB.java new file mode 100644 index 0000000000..32da1cae96 --- /dev/null +++ b/scio-core/src/test/scala/com/spotify/scio/JavaBeanB.java @@ -0,0 +1,28 @@ +package com.spotify.scio; + +class JavaBeanB implements java.io.Serializable { + private String name = null; + private String uuid = null; + private int money = 0; + + public JavaBeanB() { + } + public String getName(){ + return name; + } + public String getUuid(){ + return uuid; + } + public int getMoney(){ + return money; + } + public void setName(String name){ + this.name = name; + } + public void setUuid(String uuid){ + this.uuid = uuid; + } + public void setMoney(int money){ + this.money = money; + } +} \ No newline at end of file diff --git a/scio-core/src/test/scala/com/spotify/scio/JavaBeanC.java b/scio-core/src/test/scala/com/spotify/scio/JavaBeanC.java new file mode 100644 index 0000000000..3c9949f4d6 --- /dev/null +++ b/scio-core/src/test/scala/com/spotify/scio/JavaBeanC.java @@ -0,0 +1,28 @@ +package com.spotify.scio; + +class JavaBeanC implements java.io.Serializable { + private String firstName = null; + private String lastName = null; + private int age = 0; + + public JavaBeanC() { + } + public String getFirstName(){ + return firstName; + } + public String getLastName(){ + return lastName; + } + public int getAge(){ + return age; + } + public void setFirstName(String firstName){ + this.firstName = firstName; + } + public void setLastName(String lastName){ + this.lastName = lastName; + } + public void setAge(int age){ + this.age = age; + } +} \ No newline at end of file diff --git a/scio-core/src/test/scala/com/spotify/scio/ToSafeSuite.scala b/scio-core/src/test/scala/com/spotify/scio/ToSafeSuite.scala new file mode 100644 index 0000000000..10d1d13a55 --- /dev/null +++ b/scio-core/src/test/scala/com/spotify/scio/ToSafeSuite.scala @@ -0,0 +1,64 @@ +package com.spotify.scio + +import com.spotify.scio.schemas.To + +import org.scalatest.flatspec.AnyFlatSpec +import org.scalatest.matchers.should.Matchers + +case class JavaListInt(l: java.util.List[java.lang.Integer]) +case class JavaListString(l: java.util.List[java.lang.String]) +case class ListInt(l: List[Int]) +case class JavaSource(b: java.lang.Boolean) +case class Source(b: Boolean) +case class Dest(b: Boolean) +case class Mistake(b: Int) +case class Mistake2(c: Boolean) + +case class Sources(name: String, links: List[Array[Byte]]) +case class Destinations(name: String, links: List[Array[Byte]]) +case class DestinationsWrong(name: String, links: List[Array[Int]]) + +class ToSafeTest extends AnyFlatSpec with Matchers { + "To.safe" should "generate a conversion on compatible flat case class schemas" in { + To.safe[Source, Dest] + } + + it should "generate a conversion between java.lang.Boolean and Boolean" in { + To.safe[JavaSource, Source] + To.safe[Source, JavaSource] + } + + it should "generate a conversion between java.util.List[java.lang.Integer] and List[Int]" in { + To.safe[JavaListInt, ListInt] + To.safe[ListInt, JavaListInt] + } + + it should "fail on incompatible Java types" in { + "To.safe[JavaListString, JavaListInt]" shouldNot compile + "To.safe[JavaListString, ListInt]" shouldNot compile + } + + it should "fail on incompatible flat case class schemas" in { + "To.safe[Source, Mistake2]" shouldNot compile + "To.safe[Source, Mistake]" shouldNot compile + } + + it should "generate a conversion on compatible nested case class schemas" in { + To.safe[Sources, Destinations] + } + + it should "fail on incompatible nested case class schemas" in { + "To.safe[Sources, DestinationsWrong]" shouldNot compile + } + + it should "work with java beans" in { + "To.safe[JavaBeanA, JavaBeanB]" shouldNot compile + "To.safe[JavaBeanB, JavaBeanA]" shouldNot compile + + "To.safe[JavaBeanB, JavaBeanC]" shouldNot compile + "To.safe[JavaBeanC, JavaBeanB]" shouldNot compile + + To.safe[JavaBeanA, JavaBeanC] + To.safe[JavaBeanC, JavaBeanA] + } +} diff --git a/scio-macros/src/main/scala-3/com/spotify/scio/IsJava.scala b/scio-macros/src/main/scala-3/com/spotify/scio/IsJava.scala index 6b475f7627..bc86b1efa4 100644 --- a/scio-macros/src/main/scala-3/com/spotify/scio/IsJava.scala +++ b/scio-macros/src/main/scala-3/com/spotify/scio/IsJava.scala @@ -17,6 +17,7 @@ package com.spotify.scio +import scala.util.{Try => STry} import scala.compiletime._ import scala.deriving._ import scala.quoted._ @@ -26,30 +27,30 @@ sealed trait IsJavaBean[T] object IsJavaBean { - private def checkGetterAndSetters(using q: Quotes)(sym: q.reflect.Symbol): Unit = { - import q.reflect._ - val methods: List[Symbol] = sym.declaredMethods + private[scio] def checkGetterAndSetters[T: scala.quoted.Type](using Quotes): Unit = { + import quotes.reflect._ + val methods: List[Symbol] = TypeRepr.of[T].typeSymbol.declaredMethods - val getters = + val getters: List[(String, Symbol)] = methods.collect { - case s if s.name.toString.startsWith("get") => - (s.name.toString.drop(3), s.tree.asInstanceOf[DefDef]) + case s if s.name.toString.startsWith("get") && s.isDefDef => + (s.name.toString.drop(3), s) } - val setters = + val setters: Map[String, Symbol] = methods.collect { - case s if s.name.toString.startsWith("set") => - (s.name.toString.drop(3), s.tree.asInstanceOf[DefDef]) + case s if s.name.toString.startsWith("set") && s.isDefDef => + (s.name.toString.drop(3), s) }.toMap - if(getters.isEmpty) { + if (getters.isEmpty) then { val mess = - s"""Class ${sym.name} has not getter""" + s"""Class ${TypeRepr.of[T].typeSymbol.name} has not getter""" report.throwError(mess) } - getters.foreach { case (name, info) => - val setter: DefDef = + getters.foreach { case (name, getter) => + val setter: Symbol = setters // Map[String, DefDef] .get(name) .getOrElse { @@ -59,30 +60,32 @@ object IsJavaBean { report.throwError(mess) } - val resType: TypeRepr = info.returnTpt.tpe - setter.paramss.head match { - case TypeParamClause(params: List[TypeDef]) => report.throwError(s"JavaBean setter for field $name has type parameters") - case TermParamClause(head :: _) => - val tpe = head.tpt.tpe - if (resType != tpe) { + val getterType: TypeRepr = TypeRepr.of[T].memberType(getter) + val setterType: TypeRepr = TypeRepr.of[T].memberType(setter) + (getterType, setterType) match { + // MethodType(paramNames, paramTypes, returnType) + case (MethodType(_, Nil, getReturnType), MethodType(_, setReturnType :: Nil, _)) => + if getReturnType != setReturnType then { val mess = s"""JavaBean contained setter for field $name that had a mismatching type. - | found: $tpe - | expected: $resType""".stripMargin + | found: $setReturnType + | expected: $getReturnType""".stripMargin report.throwError(mess) } } } } + private def isJavaBeanImpl[T](using Quotes, Type[T]): Expr[IsJavaBean[T]] = { import quotes.reflect._ - val sym = TypeRepr.of[T].typeSymbol - if sym.flags.is(Flags.JavaDefined) then checkGetterAndSetters(sym) - '{new IsJavaBean[T]{}} + if TypeRepr.of[T].typeSymbol.flags.is(Flags.JavaDefined) && STry(checkGetterAndSetters[T]).isSuccess then + '{new IsJavaBean[T]{}} + else + report.throwError("Not a Java Bean") } - inline given isJavaBean[T]: IsJavaBean[T] = { + transparent inline given [T]: IsJavaBean[T] = { ${ isJavaBeanImpl[T] } }