diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala similarity index 65% rename from core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala rename to core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala index ffbe668..6cfde87 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala @@ -1,48 +1,35 @@ package com.github.mrpowers.spark.fast.tests import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red} -import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps import org.apache.commons.lang3.StringUtils import org.apache.spark.sql.Row +import com.github.mrpowers.spark.fast.tests.ufansi.FansiExtensions.StrOps +object DataframeUtil { -import scala.reflect.ClassTag - -object ProductUtil { - private[mrpowers] def productOrRowToSeq(product: Any): Seq[Any] = { - product match { - case null => Seq.empty - case r: Row => r.toSeq - case p: Product => p.productIterator.toSeq - case _ => throw new IllegalArgumentException("Only Row and Product types are supported") - } - } - private[mrpowers] def showProductDiff[T: ClassTag]( + private[mrpowers] def showDataframeDiff( header: (String, String), - actual: Seq[T], - expected: Seq[T], + actual: Seq[Row], + expected: Seq[Row], truncate: Int = 20, - minColWidth: Int = 3, - defaultVal: T = null.asInstanceOf[T], - border: (String, String) = ("[", "]") + minColWidth: Int = 3 ): String = { - val className = implicitly[ClassTag[T]].runtimeClass.getSimpleName - val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) - val emptyProd = s"$className()" val sb = new StringBuilder - val fullJoin = actual.zipAll(expected, defaultVal, defaultVal) - + val fullJoin = actual.zipAll(expected, Row(), Row()) val diff = fullJoin.map { case (actualRow, expectedRow) => - if (actualRow == expectedRow) { + if (equals(actualRow, expectedRow)) { List(DarkGray(actualRow.toString), DarkGray(expectedRow.toString)) } else { - val actualSeq = productOrRowToSeq(actualRow) - val expectedSeq = productOrRowToSeq(expectedRow) + val actualSeq = actualRow.toSeq + val expectedSeq = expectedRow.toSeq if (actualSeq.isEmpty) - List(Red(emptyProd), Green(prodToString(expectedSeq))) + List( + Red("[]"), + Green(expectedSeq.mkString("[", ",", "]")) + ) else if (expectedSeq.isEmpty) - List(Red(prodToString(actualSeq)), Green(emptyProd)) + List(Red(actualSeq.mkString("[", ",", "]")), Green("[]")) else { val withEquals = actualSeq .zip(expectedSeq) @@ -51,8 +38,12 @@ object ProductUtil { } val allFieldsAreNotEqual = !withEquals.exists(_._3) if (allFieldsAreNotEqual) { - List(Red(prodToString(actualSeq)), Green(prodToString(expectedSeq))) + List( + Red(actualSeq.mkString("[", ",", "]")), + Green(expectedSeq.mkString("[", ",", "]")) + ) } else { + val coloredDiff = withEquals .map { case (actualRowField, expectedRowField, true) => @@ -60,9 +51,9 @@ object ProductUtil { case (actualRowField, expectedRowField, false) => (Red(actualRowField.toString), Green(expectedRowField.toString)) } - val start = DarkGray(s"$className${border._1}") + val start = DarkGray("[") val sep = DarkGray(",") - val end = DarkGray(border._2) + val end = DarkGray("]") List( coloredDiff.map(_._1).mkStr(start, sep, end), coloredDiff.map(_._2).mkStr(start, sep, end) @@ -78,12 +69,11 @@ object ProductUtil { val colWidths = Array.fill(numCols)(minColWidth) // Compute the width of each column - headerSeq.zipWithIndex.foreach({ case (cell, i) => + for ((cell, i) <- headerSeq.zipWithIndex) { colWidths(i) = math.max(colWidths(i), cell.length) - }) - - diff.foreach { row => - row.zipWithIndex.foreach { case (cell, i) => + } + for (row <- diff) { + for ((cell, i) <- row.zipWithIndex) { colWidths(i) = math.max(colWidths(i), cell.length) } } @@ -127,4 +117,5 @@ object ProductUtil { sb.toString } + } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala index e71a115..70b30dc 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala @@ -38,7 +38,7 @@ Expected DataFrame Row Count: '$expectedCount' /** * Raises an error unless `actualDS` and `expectedDS` are equal */ - def assertSmallDatasetEquality[T: ClassTag]( + def assertSmallDatasetEquality[T]( actualDS: Dataset[T], expectedDS: Dataset[T], ignoreNullable: Boolean = false, @@ -53,7 +53,7 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals) } - def assertSmallDatasetContentEquality[T: ClassTag]( + def assertSmallDatasetContentEquality[T]( actualDS: Dataset[T], expectedDS: Dataset[T], orderedComparison: Boolean, @@ -66,12 +66,12 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(defaultSortDataset(actualDS), defaultSortDataset(expectedDS), truncate, equals) } - def assertSmallDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { + def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], truncate: Int, equals: (T, T) => Boolean): Unit = { val a = actualDS.collect().toSeq val e = expectedDS.collect().toSeq if (!a.approximateSameElements(e, equals)) { val arr = ("Actual Content", "Expected Content") - val msg = "Diffs\n" ++ ProductUtil.showProductDiff[T](arr, a, e, truncate) + val msg = "Diffs\n" ++ DataframeUtil.showDataframeDiff(arr, a.asRows, e.asRows, truncate) throw DatasetContentMismatch(msg) } } diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index 9316d76..ce1edfe 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -1,20 +1,29 @@ package com.github.mrpowers.spark.fast.tests -import com.github.mrpowers.spark.fast.tests.ProductUtil.showProductDiff import org.apache.spark.sql.Dataset -import org.apache.spark.sql.types.{ArrayType, DataType, MapType, NullType, StructField, StructType} +import org.apache.spark.sql.types.{ArrayType, DataType, MapType, StructType} object SchemaComparer { + case class DatasetSchemaMismatch(smth: String) extends Exception(smth) private def betterSchemaMismatchMessage[T](actualDS: Dataset[T], expectedDS: Dataset[T]): String = { - showProductDiff( - ("Actual Schema", "Expected Schema"), - actualDS.schema.fields, - expectedDS.schema.fields, - truncate = 200, - defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType), - border = ("(", ")") - ) + "\nActual Schema Field | Expected Schema Field\n" + actualDS.schema + .zipAll( + expectedDS.schema, + "", + "" + ) + .map { + case (sf1, sf2) if sf1 == sf2 => + ufansi.Color.Blue(s"$sf1 | $sf2") + case ("", sf2) => + ufansi.Color.Red(s"MISSING | $sf2") + case (sf1, "") => + ufansi.Color.Red(s"$sf1 | MISSING") + case (sf1, sf2) => + ufansi.Color.Red(s"$sf1 | $sf2") + } + .mkString("\n") } def assertSchemaEqual[T]( @@ -27,7 +36,7 @@ object SchemaComparer { require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.") if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) { throw DatasetSchemaMismatch( - "Diffs\n" + betterSchemaMismatchMessage(actualDS, expectedDS) + betterSchemaMismatchMessage(actualDS, expectedDS) ) } } @@ -67,4 +76,5 @@ object SchemaComparer { case _ => dt1 == dt2 } } + } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index 284911d..0ab6b27 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -154,70 +154,31 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes } "throws an error if the DataFrames have different schemas" in { - val nestedSchema = StructType( - Seq( - StructField( - "attributes", - StructType( - Seq( - StructField("PostCode", IntegerType, nullable = true) - ) - ), - nullable = true - ) - ) - ) - - val nestedSchema2 = StructType( - Seq( - StructField( - "attributes", - StructType( - Seq( - StructField("PostCode", StringType, nullable = true) - ) - ), - nullable = true - ) - ) - ) - val sourceDF = spark.createDF( List( - (1, 2.0, null), - (5, 3.0, null) + (1), + (5) ), - List( - ("number", IntegerType, true), - ("float", DoubleType, true), - ("nestedField", nestedSchema, true) - ) + List(("number", IntegerType, true)) ) val expectedDF = spark.createDF( List( - (1, "word", null, 1L), - (5, "word", null, 2L) + (1, "word"), + (5, "word") ), List( ("number", IntegerType, true), - ("word", StringType, true), - ("nestedField", nestedSchema2, true), - ("long", LongType, true) + ("word", StringType, true) ) ) val e = intercept[DatasetSchemaMismatch] { assertLargeDatasetEquality(sourceDF, expectedDF) } - println(e) val e2 = intercept[DatasetSchemaMismatch] { assertSmallDatasetEquality(sourceDF, expectedDF) } - println(e2) - - sourceDF.schema.printTreeString() - expectedDF.schema.printTreeString() } "throws an error if the DataFrames content is different" in {