Skip to content

Commit

Permalink
Add Table support for StructField Diff
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Sep 28, 2024
1 parent e49b44f commit 9db07d8
Show file tree
Hide file tree
Showing 4 changed files with 96 additions and 58 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount'
/**
* Raises an error unless `actualDS` and `expectedDS` are equal
*/
def assertSmallDatasetEquality[T](
def assertSmallDatasetEquality[T: ClassTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
Expand All @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}

def assertSmallDatasetContentEquality[T](
def assertSmallDatasetContentEquality[T: ClassTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
orderedComparison: Boolean,
Expand All @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals)
}

def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = {
val a = actualDS.collect().toSeq
val e = expectedDS.collect().toSeq
if (!a.approximateSameElements(e, equals)) {
val arr = ("Actual Content", "Expected Content")
val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate)
val msg = "Diffs\n" ++ ProductUtil.showProductDiff[T](arr, a, e, truncate)
throw DatasetContentMismatch(msg)
}
}
Expand Down
Original file line number Diff line number Diff line change
@@ -1,35 +1,48 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row
import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps
object DataframeUtil {

private[mrpowers] def showDataframeDiff(
import scala.reflect.ClassTag

object ProductUtil {
/** Flattens a value into its constituent fields for cell-by-cell diffing.
  *
  * Supported inputs:
  *   - `null`            → an empty sequence (treated as "no fields")
  *   - Spark [[Row]]     → its column values via `Row.toSeq`
  *   - any case class / tuple ([[Product]]) → its fields via `productIterator`
  *
  * @param product the value to decompose; must be null, a Row, or a Product
  * @throws IllegalArgumentException for any other runtime type
  */
private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] =
  product match {
    // null must be checked first: typed patterns below never match null.
    case null          => Seq.empty
    case row: Row      => row.toSeq
    case prod: Product => prod.productIterator.toSeq
    case _             => throw new IllegalArgumentException("Only Row and Product types are supported")
  }
private[mrpowers] def showProductDiff[T: ClassTag](
header: (String, String),
actual: Seq[Row],
expected: Seq[Row],
actual: Seq[T],
expected: Seq[T],
truncate: Int = 20,
minColWidth: Int = 3
minColWidth: Int = 3,
defaultVal: T = null.asInstanceOf[T],
border: (String, String) = ("[", "]")
): String = {
val className = implicitly[ClassTag[T]].runtimeClass.getSimpleName
val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2)
val emptyProd = s"$className()"

val sb = new StringBuilder

val fullJoin = actual.zipAll(expected, Row(), Row())
val fullJoin = actual.zipAll(expected, defaultVal, defaultVal)

val diff = fullJoin.map { case (actualRow, expectedRow) =>
if (equals(actualRow, expectedRow)) {
if (actualRow == expectedRow) {
List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString))
} else {
val actualSeq = actualRow.toSeq
val expectedSeq = expectedRow.toSeq
val actualSeq = productOrRowToSeq(actualRow)
val expectedSeq = productOrRowToSeq(expectedRow)
if (actualSeq.isEmpty)
List(
Red("[]"),
Green(expectedSeq.mkString("[", ",", "]"))
)
List(Red(emptyProd), Green(prodToString(expectedSeq)))
else if (expectedSeq.isEmpty)
List(Red(actualSeq.mkString("[", ",", "]")), Green("[]"))
List(Red(prodToString(actualSeq)), Green(emptyProd))
else {
val withEquals = actualSeq
.zip(expectedSeq)
Expand All @@ -38,22 +51,18 @@ object DataframeUtil {
}
val allFieldsAreNotEqual = !withEquals.exists(_._3)
if (allFieldsAreNotEqual) {
List(
Red(actualSeq.mkString("[", ",", "]")),
Green(expectedSeq.mkString("[", ",", "]"))
)
List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq)))
} else {

val coloredDiff = withEquals
.map {
case (actualRowField, expectedRowField, true) =>
(DarkGray(actualRowField.toString), DarkGray(expectedRowField.toString))
case (actualRowField, expectedRowField, false) =>
(Red(actualRowField.toString), Green(expectedRowField.toString))
}
val start = DarkGray("[")
val start = DarkGray(s"$className${border._1}")
val sep = DarkGray(",")
val end = DarkGray("]")
val end = DarkGray(border._2)
List(
coloredDiff.map(_._1).mkStr(start, sep, end),
coloredDiff.map(_._2).mkStr(start, sep, end)
Expand All @@ -69,11 +78,12 @@ object DataframeUtil {
val colWidths = Array.fill(numCols)(minColWidth)

// Compute the width of each column
for ((cell, i) <- headerSeq.zipWithIndex) {
headerSeq.zipWithIndex.foreach({ case (cell, i) =>
colWidths(i) = math.max(colWidths(i), cell.length)
}
for (row <- diff) {
for ((cell, i) <- row.zipWithIndex) {
})

diff.foreach { row =>
row.zipWithIndex.foreach { case (cell, i) =>
colWidths(i) = math.max(colWidths(i), cell.length)
}
}
Expand Down Expand Up @@ -117,5 +127,4 @@ object DataframeUtil {

sb.toString
}

}
Original file line number Diff line number Diff line change
@@ -1,29 +1,20 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff
import org.apache.spark.sql.Dataset
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType}
import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType}

object SchemaComparer {

case class DatasetSchemaMismatch(smth: String) extends Exception(smth)
private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = {
"\nActual Schema Field | Expected Schema Field\n" + actualDS.schema
.zipAll(
expectedDS.schema,
"",
""
)
.map {
case (sf1, sf2) if sf1 == sf2 =>
ufansi.Color.Blue(s"$sf1 | $sf2")
case ("", sf2) =>
ufansi.Color.Red(s"MISSING | $sf2")
case (sf1, "") =>
ufansi.Color.Red(s"$sf1 | MISSING")
case (sf1, sf2) =>
ufansi.Color.Red(s"$sf1 | $sf2")
}
.mkString("\n")
showProductDiff(
("Actual Schema", "Expected Schema"),
actualDS.schema.fields,
expectedDS.schema.fields,
truncate = 200,
defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType),
border = ("(", ")")
)
}

def assertSchemaEqual[T](
Expand All @@ -36,7 +27,7 @@ object SchemaComparer {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) {
throw DatasetSchemaMismatch(
betterSchemaMismatchMessage(actualDS, expectedDS)
"Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS)
)
}
}
Expand Down Expand Up @@ -76,5 +67,4 @@ object SchemaComparer {
case _ => dt1 == dt2
}
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -154,31 +154,70 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes
}

"throws an error if the DataFrames have different schemas" in {
val nestedSchema = StructType(
Seq(
StructField(
"attributes",
StructType(
Seq(
StructField("PostCode", IntegerType, nullable = true)
)
),
nullable = true
)
)
)

val nestedSchema2 = StructType(
Seq(
StructField(
"attributes",
StructType(
Seq(
StructField("PostCode", StringType, nullable = true)
)
),
nullable = true
)
)
)

val sourceDF = spark.createDF(
List(
(1),
(5)
(1, 2.0, null),
(5, 3.0, null)
),
List(("number", IntegerType, true))
List(
("number", IntegerType, true),
("float", DoubleType, true),
("nestedField", nestedSchema, true)
)
)

val expectedDF = spark.createDF(
List(
(1, "word"),
(5, "word")
(1, "word", null, 1L),
(5, "word", null, 2L)
),
List(
("number", IntegerType, true),
("word", StringType, true)
("word", StringType, true),
("nestedField", nestedSchema2, true),
("long", LongType, true)
)
)

val e = intercept[DatasetSchemaMismatch] {
assertLargeDatasetEquality(sourceDF, expectedDF)
}
println(e)
val e2 = intercept[DatasetSchemaMismatch] {
assertSmallDatasetEquality(sourceDF, expectedDF)
}
println(e2)

sourceDF.schema.printTreeString()
expectedDF.schema.printTreeString()
}

"throws an error if the DataFrames content is different" in {
Expand Down

0 comments on commit 9db07d8

Please sign in to comment.