Skip to content

Commit

Permalink
Merge pull request #133 from zeotuan/AssertApproximateSmallDf
Browse files Browse the repository at this point in the history
Add AssertApproximateSmallDataFrameEquality
  • Loading branch information
SemyonSinchenko authored Aug 25, 2024
2 parents a48bf28 + f27d405 commit c933a8a
Show file tree
Hide file tree
Showing 7 changed files with 572 additions and 69 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.{DataFrame, Row}

trait DataFrameComparer extends DatasetComparer {

Expand Down Expand Up @@ -48,4 +48,49 @@ trait DataFrameComparer extends DatasetComparer {
)
}

/**
 * Raises an error unless `actualDF` and `expectedDF` are approximately equal.
 *
 * Schemas are compared first (subject to the `ignore*` flags); rows are then
 * compared pairwise with [[RowComparer.areRowsEqual]], treating numeric values
 * within `precision` of each other as equal. Both DataFrames are collected to
 * the driver, so this is only suitable for small data.
 *
 * @param precision         absolute tolerance used for row comparison
 * @param ignoreNullable    ignore nullability differences in the schemas
 * @param ignoreColumnNames ignore column-name differences in the schemas
 * @param orderedComparison when false, rows are compared regardless of order
 * @param ignoreColumnOrder when true, columns are reordered before comparing
 */
def assertApproximateSmallDataFrameEquality(
    actualDF: DataFrame,
    expectedDF: DataFrame,
    precision: Double,
    ignoreNullable: Boolean = false,
    ignoreColumnNames: Boolean = false,
    orderedComparison: Boolean = true,
    ignoreColumnOrder: Boolean = false
): Unit = {
  // Row-level approximate equality with the requested tolerance.
  val approxEqual: (Row, Row) => Boolean =
    (r1, r2) => RowComparer.areRowsEqual(r1, r2, precision)
  assertSmallDatasetEquality[Row](
    actualDF,
    expectedDF,
    ignoreNullable,
    ignoreColumnNames,
    orderedComparison,
    ignoreColumnOrder,
    equals = approxEqual
  )
}

/**
 * Raises an error unless `actualDF` and `expectedDF` are approximately equal.
 *
 * Like `assertApproximateSmallDataFrameEquality`, but delegates to the
 * large-dataset comparison, which avoids collecting everything to the driver.
 * Rows are compared with [[RowComparer.areRowsEqual]] using `precision` as the
 * absolute tolerance.
 *
 * @param precision         absolute tolerance used for row comparison
 * @param ignoreNullable    ignore nullability differences in the schemas
 * @param ignoreColumnNames ignore column-name differences in the schemas
 * @param orderedComparison when false, rows are compared regardless of order
 * @param ignoreColumnOrder when true, columns are reordered before comparing
 */
def assertApproximateLargeDataFrameEquality(
    actualDF: DataFrame,
    expectedDF: DataFrame,
    precision: Double,
    ignoreNullable: Boolean = false,
    ignoreColumnNames: Boolean = false,
    orderedComparison: Boolean = true,
    ignoreColumnOrder: Boolean = false
): Unit = {
  // Name every flag: the original passed positional arguments after the named
  // `equals` argument, which is fragile (and only legal when the named arg
  // happens to sit at its declaration position). Naming them all is robust
  // against any reordering of the target's parameter list.
  assertLargeDatasetEquality[Row](
    actualDF,
    expectedDF,
    equals = RowComparer.areRowsEqual(_, _, precision),
    ignoreNullable = ignoreNullable,
    ignoreColumnNames = ignoreColumnNames,
    orderedComparison = orderedComparison,
    ignoreColumnOrder = ignoreColumnOrder
  )
}
}
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.DatasetComparer.maxUnequalRowsToShow
import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.{DataFrame, Dataset, Row}
import org.apache.spark.sql.functions._
Expand Down Expand Up @@ -50,24 +51,25 @@ Expected DataFrame Row Count: '$expectedCount'
ignoreColumnNames: Boolean = false,
orderedComparison: Boolean = true,
ignoreColumnOrder: Boolean = false,
truncate: Int = 500
truncate: Int = 500,
equals: (T, T) => Boolean = (o1: T, o2: T) => o1.equals(o2)
): Unit = {
SchemaComparer.assertSchemaEqual(actualDS, expectedDS, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)
val actual = if (ignoreColumnOrder) orderColumns(actualDS, expectedDS) else actualDS
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate)
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}

/**
 * Compares Dataset contents, optionally ignoring row order.
 *
 * When `orderedComparison` is true, rows must match pairwise in their current
 * order; otherwise both Datasets are first sorted with `defaultSortDataset`
 * so ordering differences are ignored. Element pairs are compared with the
 * supplied `equals` predicate.
 *
 * (This span previously contained pre- and post-merge diff lines interleaved;
 * it is resolved here to the post-merge overload that threads `equals`.)
 */
def assertSmallDatasetContentEquality[T](
    actualDS: Dataset[T],
    expectedDS: Dataset[T],
    orderedComparison: Boolean,
    truncate: Int,
    equals: (T, T) => Boolean
): Unit = {
  if (orderedComparison)
    assertSmallDatasetContentEquality(actualDS, expectedDS, truncate, equals)
  else
    assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals)
}

/**
 * Collects both Datasets to the driver and throws `DatasetContentMismatch`
 * unless every pair of corresponding elements satisfies `equals` (and both
 * sides have the same number of elements). `truncate` is forwarded to
 * `betterContentMismatchMessage` to control mismatch-message formatting.
 *
 * Only suitable for small Datasets — both sides are fully materialized.
 */
def assertSmallDatasetContentEquality[T](
    actualDS: Dataset[T],
    expectedDS: Dataset[T],
    truncate: Int,
    equals: (T, T) => Boolean
): Unit = {
  val actualRows   = actualDS.collect()
  val expectedRows = expectedDS.collect()
  // approximateSameElements short-circuits on the first non-matching pair.
  if (!actualRows.toSeq.approximateSameElements(expectedRows, equals)) {
    throw DatasetContentMismatch(betterContentMismatchMessage(actualRows, expectedRows, truncate))
  }
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,75 +1,48 @@
package com.github.mrpowers.spark.fast.tests

import org.apache.commons.math3.util.Precision
import org.apache.spark.sql.Row

import java.sql.Timestamp
import scala.math.abs

object RowComparer {

  /**
   * Approximate equality for Spark [[Row]]s.
   *
   * With `tol == 0` this is exact `Row` equality. Otherwise each field is
   * compared by type:
   *   - `Array[Byte]`: element-wise equality
   *   - `Float` / `Double`: `Precision.equalsIncludingNaN` within `tol`
   *     (NaN == NaN is considered equal)
   *   - `BigDecimal` / other `Number`: absolute difference strictly less
   *     than `tol`
   *   - `Timestamp` / `Instant`: millisecond difference within `tol`
   *   - nested `Row`: recursive comparison with the same tolerance
   *   - anything else: `==`
   *
   * NOTE(fix): the temporal cases previously used `> tol`, i.e. rows were
   * considered equal only when their timestamps differed by MORE than the
   * tolerance; the comparison is now `<= tol`. `isInstanceOf` guards were
   * also added so a type mismatch falls through to `==` (yielding `false`)
   * instead of throwing `ClassCastException`.
   *
   * @param tol absolute tolerance for numeric/temporal fields; 0 means exact
   */
  def areRowsEqual(r1: Row, r2: Row, tol: Double = 0): Boolean = {
    if (tol == 0) {
      r1 == r2
    } else if (r1.length != r2.length) {
      false
    } else {
      // Every field must match: both null, or both non-null and "valid" below.
      (0 until r1.length).forall { i =>
        if (r1.isNullAt(i) != r2.isNullAt(i)) {
          false
        } else if (r1.isNullAt(i)) {
          true // both null
        } else {
          val o1 = r1.get(i)
          val o2 = r2.get(i)
          o1 match {
            case b1: Array[Byte] =>
              o2.isInstanceOf[Array[Byte]] && java.util.Arrays.equals(b1, o2.asInstanceOf[Array[Byte]])
            case f1: Float if o2.isInstanceOf[Float] =>
              Precision.equalsIncludingNaN(f1, o2.asInstanceOf[Float], tol)
            case d1: Double if o2.isInstanceOf[Double] =>
              Precision.equalsIncludingNaN(d1, o2.asInstanceOf[Double], tol)
            case bd1: java.math.BigDecimal if o2.isInstanceOf[java.math.BigDecimal] =>
              bd1.subtract(o2.asInstanceOf[java.math.BigDecimal]).abs().compareTo(new java.math.BigDecimal(tol)) == -1
            case n1: Number if o2.isInstanceOf[Number] =>
              // Compare arbitrary numeric types exactly via their decimal form.
              val bd1 = new java.math.BigDecimal(n1.toString)
              val bd2 = new java.math.BigDecimal(o2.toString)
              bd1.subtract(bd2).abs().compareTo(new java.math.BigDecimal(tol)) == -1
            case t1: java.sql.Timestamp if o2.isInstanceOf[java.sql.Timestamp] =>
              abs(t1.getTime - o2.asInstanceOf[java.sql.Timestamp].getTime) <= tol
            case t1: java.time.Instant if o2.isInstanceOf[java.time.Instant] =>
              abs(t1.toEpochMilli - o2.asInstanceOf[java.time.Instant].toEpochMilli) <= tol
            case rr1: Row if o2.isInstanceOf[Row] =>
              areRowsEqual(rr1, o2.asInstanceOf[Row], tol)
            case _ =>
              o1 == o2
          }
        }
      }
    }
  }

}
Original file line number Diff line number Diff line change
@@ -0,0 +1,52 @@
package com.github.mrpowers.spark.fast.tests

import scala.util.Try

object SeqLikesExtensions {

  /** Adds approximate-comparison helpers to any `Seq`. */
  implicit class SeqExtensions[T](val seq1: Seq[T]) extends AnyVal {

    /**
     * True when `seq1` and `seq2` have the same length and every pair of
     * corresponding elements satisfies `equals`.
     *
     * Implemented with the standard library's `corresponds`, which already
     * short-circuits on the first mismatching pair and returns false on a
     * length mismatch. This replaces the previous hand-rolled loop, whose
     * reflective `knownSize` probe fell back to an O(n) `length` call
     * (defeating its own "cheaply computed" purpose) and which used nonlocal
     * `return`s.
     *
     * @param seq2   the sequence to compare against
     * @param equals element-level equality predicate (e.g. approximate
     *               numeric equality)
     */
    def approximateSameElements(seq2: Seq[T], equals: (T, T) => Boolean): Boolean =
      seq1.corresponds(seq2)(equals)
  }

}
Loading

0 comments on commit c933a8a

Please sign in to comment.