Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

run test on pull request. #160

Merged
merged 13 commits into from
Oct 5, 2024
Merged
Show file tree
Hide file tree
Changes from 9 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 22 additions & 0 deletions .github/workflows/core-ci.yml
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
name: core-ci

on:
push:
branches:
- main
pull_request:

jobs:
build:
strategy:
fail-fast: false
matrix:
spark: ["3.0.1", "3.1.3", "3.2.4", "3.3.4"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: olafurpg/setup-scala@v10
- name: Test
run: sbt -Dspark.testVersion=${{ matrix.spark }} +"project core" test
- name: Code Quality
run: sbt "project core" scalafmtCheckAll
10 changes: 5 additions & 5 deletions .github/workflows/ci.yml → .github/workflows/unsafe-ci.yml
Original file line number Diff line number Diff line change
@@ -1,22 +1,22 @@
name: ci
name: unsafe-ci

on:
push:
branches:
- main
pull_request:

jobs:
build:
strategy:
fail-fast: false
matrix:
scala: ["2.12.12"]
spark: ["3.0.1"]
spark: ["3.2.4", "3.3.4"]
runs-on: ubuntu-latest
steps:
- uses: actions/checkout@v1
- uses: olafurpg/setup-scala@v10
- name: Test
run: sbt -Dspark.testVersion=${{ matrix.spark }} ++${{ matrix.scala }} test
run: sbt -Dspark.testVersion=${{ matrix.spark }} +"project unsafe" test
- name: Code Quality
run: sbt scalafmtCheckAll
run: sbt "project unsafe" scalafmtCheckAll
2 changes: 1 addition & 1 deletion .scalafmt.conf
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
version = 2.6.3

lineEndings = preserve
align = more
maxColumn = 150
docstrings = JavaDoc
50 changes: 42 additions & 8 deletions build.sbt
Original file line number Diff line number Diff line change
Expand Up @@ -4,27 +4,61 @@ organization := "com.github.mrpowers"
name := "spark-daria"

version := "1.2.3"

crossScalaVersions := Seq("2.12.15", "2.13.8")
scalaVersion := "2.12.15"

val sparkVersion = "3.2.1"
val versionRegex = """^(.*)\.(.*)\.(.*)$""".r

val scala2_13 = "2.13.14"
val scala2_12 = "2.12.20"

val sparkVersion = System.getProperty("spark.testVersion", "3.3.4")
crossScalaVersions := {
sparkVersion match {
case versionRegex("3", m, _) if m.toInt >= 2 => Seq(scala2_12, scala2_13)
case versionRegex("3", _, _) => Seq(scala2_12)
}
}

scalaVersion := crossScalaVersions.value.head

lazy val commonSettings = Seq(
javaOptions ++= {
Seq("-Xms512M", "-Xmx2048M", "-Duser.timezone=GMT") ++ (if (System.getProperty("java.version").startsWith("1.8.0"))
Seq("-XX:+CMSClassUnloadingEnabled")
else Seq.empty)
},
libraryDependencies ++= Seq(
"org.apache.spark" %% "spark-sql" % sparkVersion % "provided",
"org.apache.spark" %% "spark-mllib" % sparkVersion % "provided",
"com.lihaoyi" %% "utest" % "0.7.11" % "test",
"com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
),
)

lazy val core = (project in file("core"))
.settings(
commonSettings,
name := "core",
)

lazy val unsafe = (project in file("unsafe"))
.settings(
commonSettings,
name := "unsafe",
)

libraryDependencies += "org.apache.spark" %% "spark-sql" % sparkVersion % "provided"
libraryDependencies += "org.apache.spark" %% "spark-mllib" % sparkVersion % "provided"
libraryDependencies += "com.github.mrpowers" %% "spark-fast-tests" % "1.1.0" % "test"
libraryDependencies += "com.lihaoyi" %% "utest" % "0.7.11" % "test"
libraryDependencies += "com.lihaoyi" %% "os-lib" % "0.8.0" % "test"
testFrameworks += new TestFramework("com.github.mrpowers.spark.daria.CustomFramework")

credentials += Credentials(Path.userHome / ".sbt" / "sonatype_credentials")

Test / fork := true

javaOptions ++= Seq("-Xms512M", "-Xmx2048M", "-XX:+CMSClassUnloadingEnabled", "-Duser.timezone=GMT")

licenses := Seq("MIT" -> url("http://opensource.org/licenses/MIT"))

homepage := Some(url("https://github.com/MrPowers/spark-daria"))

developers ++= List(
Developer("MrPowers", "Matthew Powers", "@MrPowers", url("https://github.com/MrPowers"))
)
Expand Down
File renamed without changes.
File renamed without changes.
File renamed without changes.
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@ import org.apache.spark.sql.SparkSession
trait SparkSessionTestWrapper {

lazy val spark: SparkSession = {
SparkSession
val session = SparkSession
.builder()
.master("local")
.appName("spark session")
Expand All @@ -14,6 +14,8 @@ trait SparkSessionTestWrapper {
"1"
)
.getOrCreate()
session.sparkContext.setLogLevel("ERROR")
session
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -1047,22 +1047,6 @@ object TransformationsTest extends TestSuite with DataFrameComparer with ColumnC
}

'withParquetCompatibleColumnNames - {
"blows up if the column name is invalid for Parquet" - {
val df = spark
.createDF(
List(
("pablo")
),
List(
("Column That {Will} Break\t;", StringType, true)
)
)
val path = new java.io.File("./tmp/blowup/example").getCanonicalPath
val e = intercept[org.apache.spark.sql.AnalysisException] {
df.write.parquet(path)
}
}

"converts column names to be Parquet compatible" - {
val actualDF = spark
.createDF(
Expand Down

This file was deleted.

Original file line number Diff line number Diff line change
@@ -0,0 +1,95 @@
package org.apache.spark.sql.catalyst.expressions

import org.apache.commons.math3.distribution.GammaDistribution
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.expressions.RandGamma.defaultSeedExpression
import org.apache.spark.sql.catalyst.expressions.codegen.FalseLiteral
import org.apache.spark.sql.catalyst.expressions.codegen.Block.BlockHelper
import org.apache.spark.sql.catalyst.expressions.codegen.{CodeGenerator, CodegenContext, ExprCode}
import org.apache.spark.sql.types._
import org.apache.spark.util.Utils
import org.apache.spark.util.random.XORShiftRandomAdapted

import scala.util.{Success, Try}

case class RandGamma(child: Expression, shape: Expression, scale: Expression, hideSeed: Boolean = false)
extends TernaryExpression
with ExpectsInputTypes
with Stateful
with ExpressionWithRandomSeed {

def seedExpression: Expression = child

@transient protected lazy val seed: Long = seedExpression match {
case e if e.dataType == IntegerType => e.eval().asInstanceOf[Int]
case e if e.dataType == LongType => e.eval().asInstanceOf[Long]
}

@transient protected lazy val shapeVal: Double = shape.dataType match {
case IntegerType => shape.eval().asInstanceOf[Int]
case LongType => shape.eval().asInstanceOf[Long]
case FloatType | DoubleType => shape.eval().asInstanceOf[Double]
}

@transient protected lazy val scaleVal: Double = scale.dataType match {
case IntegerType => scale.eval().asInstanceOf[Int]
case LongType => scale.eval().asInstanceOf[Long]
case FloatType | DoubleType => scale.eval().asInstanceOf[Double]
}

@transient private var distribution: GammaDistribution = _

protected def initializeInternal(partitionIndex: Int): Unit = {
distribution = new GammaDistribution(new XORShiftRandomAdapted(seed + partitionIndex), shapeVal, scaleVal)
}

def this() = this(defaultSeedExpression, Literal(1.0, DoubleType), Literal(1.0, DoubleType), true)

def this(child: Expression, shape: Expression, scale: Expression) = this(child, shape, scale, false)

def withNewSeed(seed: Long): RandGamma = RandGamma(Literal(seed, LongType), shape, scale, hideSeed)

protected def evalInternal(input: InternalRow): Double = distribution.sample()

def doGenCode(ctx: CodegenContext, ev: ExprCode): ExprCode = {
val distributionClassName = classOf[GammaDistribution].getName
val rngClassName = classOf[XORShiftRandomAdapted].getName
val disTerm = ctx.addMutableState(distributionClassName, "distribution")
ctx.addPartitionInitializationStatement(
s"$disTerm = new $distributionClassName(new $rngClassName(${seed}L + partitionIndex), $shapeVal, $scaleVal);"
)
ev.copy(code = code"""
final ${CodeGenerator.javaType(dataType)} ${ev.value} = $disTerm.sample();""", isNull = FalseLiteral)
}

def freshCopy(): RandGamma = RandGamma(child, shape, scale, hideSeed)

override def flatArguments: Iterator[Any] = Iterator(child, shape, scale)

override def prettyName: String = "rand_gamma"

override def sql: String = s"rand_gamma(${if (hideSeed) "" else s"${child.sql}, ${shape.sql}, ${scale.sql}"})"

def inputTypes: Seq[AbstractDataType] = Seq(LongType, DoubleType, DoubleType)

def dataType: DataType = DoubleType

def first: Expression = child

def second: Expression = shape

def third: Expression = scale

protected def withNewChildrenInternal(newFirst: Expression, newSecond: Expression, newThird: Expression): Expression =
copy(child = newFirst, shape = newSecond, scale = newThird)
}

object RandGamma {
def apply(seed: Long, shape: Double, scale: Double): RandGamma =
RandGamma(Literal(seed, LongType), Literal(shape, DoubleType), Literal(scale, DoubleType))

def defaultSeedExpression: Expression = Try(Class.forName("org.apache.spark.sql.catalyst.analysis.UnresolvedSeed")) match {
case Success(clazz) => clazz.getConstructor().newInstance().asInstanceOf[Expression]
case _ => Literal(Utils.random.nextLong(), LongType)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,18 +9,18 @@ object functions {
private def withExpr(expr: Expression): Column = Column(expr)

def randGamma(seed: Long, shape: Double, scale: Double): Column = withExpr(RandGamma(seed, shape, scale)).alias("gamma_random")
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)
def randGamma(shape: Double, scale: Double): Column = randGamma(Utils.random.nextLong, shape, scale)
def randGamma(): Column = randGamma(1.0, 1.0)

def randLaplace(seed: Long, mu: Double, beta: Double): Column = {
val mu_ = lit(mu)
val mu_ = lit(mu)
val beta_ = lit(beta)
val u = rand(seed)
val u = rand(seed)
when(u < 0.5, mu_ + beta_ * log(lit(2) * u))
.otherwise(mu_ - beta_ * log(lit(2) * (lit(1) - u)))
.alias("laplace_random")
}

def randLaplace(mu: Double, beta: Double): Column = randLaplace(Utils.random.nextLong, mu, beta)
def randLaplace(): Column = randLaplace(0.0, 1.0)
def randLaplace(): Column = randLaplace(0.0, 1.0)
}
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
nextSeed ^= (nextSeed >>> 35)
nextSeed ^= (nextSeed << 4)
seed = nextSeed
(nextSeed & ((1L << bits) -1)).asInstanceOf[Int]
(nextSeed & ((1L << bits) - 1)).asInstanceOf[Int]
}

override def setSeed(s: Long): Unit = {
Expand All @@ -29,4 +29,3 @@ class XORShiftRandomAdapted(init: Long) extends java.util.Random(init: Long) wit
this.seed = XORShiftRandom.hashSeed(RandomGeneratorFactory.convertToLong(seed))
}
}

Loading
Loading