Skip to content

Commit

Permalink
- Add rand_range with min and max + randn with mean and variance
Browse files Browse the repository at this point in the history
- Add rand_range with min and max
- Add randn with mean and variance

Formatting

Revert unrelated changes
  • Loading branch information
zeotuan committed Oct 9, 2024
1 parent 617e062 commit eef06ec
Show file tree
Hide file tree
Showing 2 changed files with 140 additions and 16 deletions.
101 changes: 93 additions & 8 deletions unsafe/src/main/scala/org/apache/spark/sql/daria/functions.scala
Original file line number Diff line number Diff line change
Expand Up @@ -2,25 +2,110 @@ package org.apache.spark.sql.daria

import org.apache.spark.sql.Column
import org.apache.spark.sql.catalyst.expressions.{Expression, RandGamma}
import org.apache.spark.sql.functions.{lit, log, rand, when}
import org.apache.spark.sql.functions.{lit, log, when}
import org.apache.spark.sql.{functions => F}
import org.apache.spark.util.Utils

object functions {
// Lifts a Catalyst Expression into a user-facing Column (thin private adapter).
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)
/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Gamma distribution with the specified shape and scale parameters.
 *
 * @param seed  seed for the underlying random expression (same seed => same column per partition layout)
 * @param shape shape parameter of the Gamma distribution
 * @param scale scale parameter of the Gamma distribution
 * @note The function is non-deterministic in general case.
 */
def rand_gamma(seed: Long, shape: Double, scale: Double): Column = {
  // Wrap the Catalyst RandGamma expression and give the output a stable name.
  val gammaExpr = RandGamma(seed, shape, scale)
  withExpr(gammaExpr).alias("gamma_random")
}

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Gamma distribution with the specified shape and scale parameters,
 * using a freshly drawn random seed.
 *
 * @note The function is non-deterministic in general case.
 */
def rand_gamma(shape: Double, scale: Double): Column = {
  // Delegate to the seeded overload with a new random seed per call.
  val freshSeed = Utils.random.nextLong
  rand_gamma(freshSeed, shape, scale)
}

/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Gamma distribution with default parameters (shape = 1.0, scale = 1.0).
 *
 * @return A column with i.i.d. samples from the default Gamma distribution.
 *
 * @note The function is non-deterministic in general case.
 */
def rand_gamma(): Column = {
  val defaultShape = 1.0
  val defaultScale = 1.0
  rand_gamma(defaultShape, defaultScale)
}

/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`.
 *
 * Samples via the inverse CDF of the Laplace distribution applied to a single
 * uniform [0, 1) draw: for u < 0.5, x = mu + beta * ln(2u); otherwise
 * x = mu - beta * ln(2(1 - u)).
 *
 * @param seed seed for the underlying uniform random expression
 * @param mu   location parameter
 * @param beta scale parameter
 * @note The function is non-deterministic in general case.
 */
def rand_laplace(seed: Long, mu: Double, beta: Double): Column = {
  val mu_ = lit(mu)
  val beta_ = lit(beta)
  // Fix: the diff left both the stale unqualified `rand(seed)` and the new
  // `F.rand(seed)` binding of `u`; only the qualified call is in scope after
  // `rand` was dropped from the functions import.
  val u = F.rand(seed)
  when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
    .otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
    .alias("laplace_random")
}

def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta)
def randLaplace(): Column = randLaplace(0.0, 1.0)
/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Laplace distribution with the specified location parameter `mu` and scale parameter `beta`,
 * using a freshly drawn random seed.
 *
 * @note The function is non-deterministic in general case.
 */
def rand_laplace(mu: Double, beta: Double): Column = {
  // Delegate to the seeded overload with a new random seed per call.
  val freshSeed = Utils.random.nextLong
  rand_laplace(freshSeed, mu, beta)
}

/**
 * Generate a column with independent and identically distributed (i.i.d.) samples
 * from the Laplace distribution with default parameters (mu = 0.0, beta = 1.0).
 *
 * @note The function is non-deterministic in general case.
 */
def rand_laplace(): Column = {
  val defaultMu = 0.0
  val defaultBeta = 1.0
  rand_laplace(defaultMu, defaultBeta)
}

/**
 * Generate a random column with independent and identically distributed (i.i.d.) samples
 * uniformly distributed in [`min`, `max`).
 *
 * @param seed seed for the underlying uniform random expression
 * @param min  inclusive lower bound
 * @param max  exclusive upper bound
 * @note The function is non-deterministic in general case.
 */
def rand_range(seed: Long, min: Int, max: Int): Column = {
  // Affine transform of a uniform [0, 1) draw onto [min, max).
  val lower = lit(min)
  val span = lit(max) - lit(min)
  lower + span * F.rand(seed)
}

/**
 * Generate a random column with independent and identically distributed (i.i.d.) samples
 * uniformly distributed in [`min`, `max`), using a freshly drawn random seed.
 *
 * @note The function is non-deterministic in general case.
 */
def rand_range(min: Int, max: Int): Column = {
  // Delegate to the seeded overload with a new random seed per call.
  val freshSeed = Utils.random.nextLong
  rand_range(freshSeed, min, max)
}

/**
 * Generate a column with independent and identically distributed (i.i.d.) samples from
 * the normal distribution with the given `mean` and `variance` (the *standard*
 * normal is the special case mean = 0.0, variance = 1.0).
 *
 * Implemented as a scale-and-shift of a standard normal draw: X = Z * sigma + mu.
 *
 * @param seed     seed for the underlying standard-normal random expression
 * @param mean     mean of the distribution
 * @param variance variance of the distribution; must be non-negative
 * @throws IllegalArgumentException if `variance` is negative
 * @note The function is non-deterministic in general case.
 */
def randn(seed: Long, mean: Double, variance: Double): Column = {
  // Guard: sqrt of a negative variance would silently produce a NaN column.
  require(variance >= 0, s"variance must be non-negative, got $variance")
  val stddev = math.sqrt(variance)
  F.randn(seed) * lit(stddev) + lit(mean)
}

/**
 * Generate a column with independent and identically distributed (i.i.d.) samples from
 * the normal distribution with the given `mean` and `variance`, using a freshly
 * drawn random seed. (Not the *standard* normal unless mean = 0.0, variance = 1.0.)
 *
 * @note The function is non-deterministic in general case.
 */
def randn(mean: Double, variance: Double): Column = {
randn(Utils.random.nextLong, mean, variance)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,20 +2,19 @@ package org.apache.spark.sql.daria

import com.github.mrpowers.spark.fast.tests.{ColumnComparer, DataFrameComparer}
import org.apache.spark.sql.daria.functions._
import org.apache.spark.sql.functions._
import org.apache.spark.sql.functions.stddev
import org.apache.spark.sql.{functions => F}
import utest._

object functionsTests extends TestSuite with DataFrameComparer with ColumnComparer with SparkSessionTestWrapper {

val tests = Tests {
'rand_gamma - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randGamma(2.0, 2.0))
val sourceDF = spark.range(100000).select(rand_gamma(2.0, 2.0))
val stats = sourceDF
.agg(
mean("gamma_random").as("mean"),
stddev("gamma_random").as("stddev")
F.mean("gamma_random").as("mean"),
F.stddev("gamma_random").as("stddev")
)
.collect()(0)

Expand All @@ -31,11 +30,11 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar

'rand_laplace - {
"has correct mean and standard deviation" - {
val sourceDF = spark.range(100000).select(randLaplace())
val sourceDF = spark.range(100000).select(rand_laplace())
val stats = sourceDF
.agg(
mean("laplace_random").as("mean"),
stddev("laplace_random").as("std_dev")
F.mean("laplace_random").as("mean"),
F.stddev("laplace_random").as("std_dev")
)
.collect()(0)

Expand All @@ -47,5 +46,45 @@ object functionsTests extends TestSuite with DataFrameComparer with ColumnCompar
assert(math.abs(laplaceStdDev - math.sqrt(2.0)) < 0.5)
}
}

'rand - {
  "has correct min and max" - {
    // Draw 100k uniform samples in [min, max) and check both empirical bounds.
    val min = 5
    val max = 10
    val sourceDF = spark.range(100000).select(rand_range(min, max).as("rand_min_max"))
    val stats = sourceDF
      .agg(
        F.min("rand_min_max").as("min"),
        // Bug fix: this aggregate used F.min, so the "max" column actually
        // held the minimum and the upper-bound assertion below was vacuous.
        F.max("rand_min_max").as("max")
      )
      .collect()(0)

    val uniformMin = stats.getAs[Double]("min")
    val uniformMax = stats.getAs[Double]("max")

    assert(uniformMin >= min)
    assert(uniformMax <= max)
  }
}

'randn - {
  "has correct mean and variance" - {
    // Draw 100k normal samples and compare empirical moments to the parameters.
    val expectedMean = 1
    val expectedVariance = 2
    val samples = spark.range(100000).select(randn(expectedMean, expectedVariance).as("rand_normal"))
    val row = samples
      .agg(
        F.mean("rand_normal").as("mean"),
        F.variance("rand_normal").as("variance")
      )
      .collect()(0)

    val observedMean = row.getAs[Double]("mean")
    val observedVariance = row.getAs[Double]("variance")

    // Tolerances are loose relative to the sampling error at n = 100000.
    assert(math.abs(observedMean - expectedMean) <= 0.1)
    assert(math.abs(observedVariance - expectedVariance) <= 0.1)
  }
}
}
}

0 comments on commit eef06ec

Please sign in to comment.