diff --git a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala index d4322e5..ba8c027 100644 --- a/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala +++ b/unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala @@ -2,25 +2,110 @@ package org.apache.spark.sql.daria import org.apache.spark.sql.Column import org.apache.spark.sql.catalyst.expressions.{Expression, RandGamma} -import org.apache.spark.sql.functions.{lit, log, rand, when} +import org.apache.spark.sql.functions.{lit, log, when} +import org.apache.spark.sql.{functions => F} import org.apache.spark.util.Utils object functions { private def withExpr(expr: Expression): Column = Column(expr) - def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random") - def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale) - def randGamma(): Column = randGamma(1.0, 1.0) + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Gamma distribution with the specified shape and scale parameters. + * + * @note The function is non-deterministic in general case. + */ + def rand_gamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random") - def randLaplace(seed: Long, mu: Double, beta: Double): Column = { + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Gamma distribution with the specified shape and scale parameters. + * + * @note The function is non-deterministic in general case. + */ + def rand_gamma(shape: Double, scale: Double): Column = rand_gamma(Utils.random.nextLong, shape, scale) + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Gamma distribution with default parameters (shape = 1.0, scale = 1.0). 
+ * + * @return A column with i.i.d. samples from the default Gamma distribution. + * + * @note The function is non-deterministic in general case. + */ + def rand_gamma(): Column = rand_gamma(1.0, 1.0) + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`. + * + * @note The function is non-deterministic in general case. + */ + def rand_laplace(seed: Long, mu: Double, beta: Double): Column = { val mu_ = lit(mu) val beta_ = lit(beta) - val u = rand(seed) + val u = F.rand(seed) when(u < 0.5, mu_ + beta_ * log(lit(2) * u)) .otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u))) .alias("laplace_random") } - def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta) - def randLaplace(): Column = randLaplace(0.0, 1.0) + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`. + * + * @note The function is non-deterministic in general case. + */ + def rand_laplace(mu: Double, beta: Double): Column = rand_laplace(Utils.random.nextLong, mu, beta) + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples + * from the Laplace distribution with default parameters (mu = 0.0, beta = 1.0). + * + * @note The function is non-deterministic in general case. + */ + def rand_laplace(): Column = rand_laplace(0.0, 1.0) + + /** + * Generate a random column with independent and identically distributed (i.i.d.) samples + * uniformly distributed in [`min`, `max`). + * + * @note The function is non-deterministic in general case. 
+ */ + def rand_range(seed: Long, min: Int, max: Int): Column = { + val min_ = lit(min) + val max_ = lit(max) + min_ + (max_ - min_) * F.rand(seed) + } + + /** + * Generate a random column with independent and identically distributed (i.i.d.) samples + * uniformly distributed in [`min`, `max`). + * + * @note The function is non-deterministic in general case. + */ + def rand_range(min: Int, max: Int): Column = { + rand_range(Utils.random.nextLong, min, max) + } + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples from + * the normal distribution with given `mean` and `variance`. + * + * @note The function is non-deterministic in general case. + */ + def randn(seed: Long, mean: Double, variance: Double): Column = { + val stddev = math.sqrt(variance) + F.randn(seed) * lit(stddev) + lit(mean) + } + + /** + * Generate a column with independent and identically distributed (i.i.d.) samples from + * the normal distribution with given `mean` and `variance`. + * + * @note The function is non-deterministic in general case. 
+ */ + def randn(mean: Double, variance: Double): Column = { + randn(Utils.random.nextLong, mean, variance) + } } diff --git a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala index 5147c13..c026387 100644 --- a/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala +++ b/unsafe/src/test/scala/org/apache/spark/sql/daria/functionsTests.scala @@ -2,8 +2,7 @@ package org.apache.spark.sql.daria import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer} import org.apache.spark.sql.daria.functions._ -import org.apache.spark.sql.functions._ -import org.apache.spark.sql.functions.stddev +import org.apache.spark.sql.{functions => F} import utest._ object functionsTests extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper { @@ -11,11 +10,11 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar val tests = Tests { 'rand_gamma - { "has correct mean and standard deviation" - { - val sourceDF = spark.range(100000).select(randGamma(2.0, 2.0)) + val sourceDF = spark.range(100000).select(rand_gamma(2.0, 2.0)) val stats = sourceDF .agg( - mean("gamma_random").as("mean"), - stddev("gamma_random").as("stddev") + F.mean("gamma_random").as("mean"), + F.stddev("gamma_random").as("stddev") ) .collect()(0) @@ -31,11 +30,11 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar 'rand_laplace - { "has correct mean and standard deviation" - { - val sourceDF = spark.range(100000).select(randLaplace()) + val sourceDF = spark.range(100000).select(rand_laplace()) val stats = sourceDF .agg( - mean("laplace_random").as("mean"), - stddev("laplace_random").as("std_dev") + F.mean("laplace_random").as("mean"), + F.stddev("laplace_random").as("std_dev") ) .collect()(0) @@ -47,5 +46,45 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar 
assert(math.abs(laplaceStdDev - math.sqrt(2.0)) < 0.5) } } + + 'rand - { + "has correct min and max" - { + val min = 5 + val max = 10 + val sourceDF = spark.range(100000).select(rand_range(min, max).as("rand_min_max")) + val stats = sourceDF + .agg( + F.min("rand_min_max").as("min"), + F.max("rand_min_max").as("max") + ) + .collect()(0) + + val uniformMin = stats.getAs[Double]("min") + val uniformMax = stats.getAs[Double]("max") + + assert(uniformMin >= min) + assert(uniformMax <= max) + } + } + + 'randn - { + "has correct mean and variance" - { + val mean = 1 + val variance = 2 + val sourceDF = spark.range(100000).select(randn(mean, variance).as("rand_normal")) + val stats = sourceDF + .agg( + F.mean("rand_normal").as("mean"), + F.variance("rand_normal").as("variance") + ) + .collect()(0) + + val normalMean = stats.getAs[Double]("mean") + val normalVariance = stats.getAs[Double]("variance") + + assert(math.abs(normalMean - mean) <= 0.1) + assert(math.abs(normalVariance - variance) <= 0.1) + } + } } }