diff --git a/.github/workflows/core-ci.yml b/.github/workflows/core-ci.yml
index 3ef8368..a57dbb3 100644
--- a/.github/workflows/core-ci.yml
+++ b/.github/workflows/core-ci.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
+        spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4", "3.4.3", "3.5.3"]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
diff --git a/.github/workflows/unsafe-ci.yml b/.github/workflows/unsafe-ci.yml
index 9b25c31..fee5eaa 100644
--- a/.github/workflows/unsafe-ci.yml
+++ b/.github/workflows/unsafe-ci.yml
@@ -11,7 +11,7 @@ jobs:
     strategy:
       fail-fast: false
       matrix:
-        spark: ["3.2.4", "3.3.4"]
+        spark: ["3.2.4", "3.3.4", "3.4.3", "3.5.3"]
     runs-on: ubuntu-latest
     steps:
       - uses: actions/checkout@v1
diff --git a/build.sbt b/build.sbt
index 2bf0b03..3dd0b9e 100644
--- a/build.sbt
+++ b/build.sbt
@@ -1,3 +1,5 @@
+import scala.language.postfixOps
+
 Compile / scalafmtOnCompile := true
 
 organization := "com.github.mrpowers"
@@ -48,6 +50,20 @@ lazy val unsafe = (project in file("unsafe"))
   .settings(
     commonSettings,
     name := "unsafe",
+    // Add the spark_* source roots under the main source directory whose name contains the Spark major.minor in use.
+    Compile / unmanagedSourceDirectories ++= {
+      sparkVersion match {
+        case versionRegex(major, minor, _) =>
+          (Compile / sourceDirectory).value ** s"*spark_*$major.$minor*" / "scala" get
+      }
+    },
+    // Same lookup under the test source directory, so the main sources are not added to the test sources as well.
+    Test / unmanagedSourceDirectories ++= {
+      sparkVersion match {
+        case versionRegex(major, minor, _) =>
+          (Test / sourceDirectory).value ** s"*spark_*$major.$minor*" / "scala" get
+      }
+    },
   )
 
 testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")
diff --git a/unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala b/unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
similarity index 86%
rename from unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala
rename to unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
index ce3aed5..4141478 100644
--- a/unsafe/src/main/scala/org/apache/spark/sql/catalyst/expressions/RandGamma.scala
+++ b/unsafe/src/main/spark_3.2_3.3/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
@@ -2,16 +2,13 @@ package org.apache.spark.sql.catalyst.expressions
 
 import org.apache.commons.math3.distribution.GammaDistribution
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
 import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
 import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
 import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
 import org.apache.spark.sql.types._
-import org.apache.spark.util.Utils
 import org.apache.spark.util.random.XORShiftRandomAdapted
 
-import scala.util.{Success, Try}
-
 case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
     extends TernaryExpression
     with ExpectsInputTypes
@@ -43,7 +40,7 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
     distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
   }
 
-  def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
+  def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
 
   def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)
 
@@ -87,10 +84,4 @@ case class RandGamma(child: Expression, shape: Expression, scale: Expression, hi
 object RandGamma {
   def apply(seed: Long, shape: Double, scale: Double): RandGamma =
     RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
-
-  def defaultSeedExpression: Expression =
-    Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
-      case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
-      case _ => Literal(Utils.random.nextLong(), LongType)
-    }
 }
diff --git a/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala b/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
new file mode 100644
index 0000000..ab2df0e
--- /dev/null
+++ b/unsafe/src/main/spark_3.4_3.5/scala/org.apache.spark/sql/catalyst/expressions/RandGamma.scala
@@ -0,0 +1,108 @@
+package org.apache.spark.sql.catalyst.expressions
+
+import org.apache.commons.math3.distribution.GammaDistribution
+import org.apache.spark.sql.catalyst.analysis.UnresolvedSeed
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
+import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
+import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
+import org.apache.spark.sql.types._
+import org.apache.spark.util.random.XORShiftRandomAdapted
+
+/**
+ * Nondeterministic expression whose value is a sample drawn from a `GammaDistribution`.
+ *
+ * The distribution is created in `initializeInternal` and seeded with `seed + partitionIndex`.
+ *
+ * @param child    seed expression; evaluated as an `Int` or a `Long`
+ * @param shape    shape argument passed to `GammaDistribution`
+ * @param scale    scale argument passed to `GammaDistribution`
+ * @param hideSeed when true, `sql` renders the call without any arguments
+ */
+case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
+    extends TernaryExpression
+    with ExpectsInputTypes
+    with Nondeterministic
+    with ExpressionWithRandomSeed {
+
+  def seedExpression: Expression = child
+
+  @transient protected lazy val seed: Long = seedExpression match {
+    case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
+    case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
+  }
+
+  // Each numeric type is unboxed as itself and only then widened to Double: unboxing a boxed
+  // Float directly as Double throws a ClassCastException.
+  @transient protected lazy val shapeVal: Double = shape.dataType match {
+    case IntegerType => shape.eval().asInstanceOf[Int]
+    case LongType => shape.eval().asInstanceOf[Long]
+    case FloatType => shape.eval().asInstanceOf[Float]
+    case DoubleType => shape.eval().asInstanceOf[Double]
+  }
+
+  // Unboxed per type for the same reason as shapeVal.
+  @transient protected lazy val scaleVal: Double = scale.dataType match {
+    case IntegerType => scale.eval().asInstanceOf[Int]
+    case LongType => scale.eval().asInstanceOf[Long]
+    case FloatType => scale.eval().asInstanceOf[Float]
+    case DoubleType => scale.eval().asInstanceOf[Double]
+  }
+
+  @transient private var distribution: GammaDistribution = _
+
+  // Seeds the generator for this partition with `seed + partitionIndex`.
+  override protected def initializeInternal(partitionIndex: Int): Unit = {
+    distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
+  }
+
+  /** Uses `UnresolvedSeed`, shape 1.0 and scale 1.0, with `hideSeed = true` so `sql` renders `rand_gamma()`. */
+  def this() = this(UnresolvedSeed, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)
+
+  def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)
+
+  def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)
+
+  protected def evalInternal(input: InternalRow): Double = distribution.sample()
+
+  // The generated Java keeps the distribution as mutable state, builds it in the partition
+  // initialization statement and calls sample() to produce the value.
+  def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
+    val distributionClassName = classOf[GammaDistribution].getName
+    val rngClassName = classOf[XORShiftRandomAdapted].getName
+    val disTerm = ctx.addMutableState(distributionClassName, "distribution")
+    ctx.addPartitionInitializationStatement(
+      s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
+    )
+    ev.copy(code = code"""
+      final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
+  }
+
+  def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)
+
+  override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)
+
+  override def prettyName: String = "rand_gamma"
+
+  override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"
+
+  override def stateful: Boolean = true
+
+  def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)
+
+  def dataType: DataType = DoubleType
+
+  def first: Expression = child
+
+  def second: Expression = shape
+
+  def third: Expression = scale
+
+  protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
+    copy(child = newFirst, shape = newSecond, scale = newThird)
+}
+
+object RandGamma {
+  def apply(seed: Long, shape: Double, scale: Double): RandGamma =
+    RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))
+}