diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala index ab2da0e..4840ee0 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/ProductUtil.scala @@ -29,7 +29,7 @@ object ProductUtil { val className = runTimeClass.getSimpleName val border = if (runTimeClass == classOf[Row]) ("[", "]") else ("(", ")") val prodToString: Seq[Any] => String = s => s.mkString(s"$className${border._1}", ",", border._2) - val emptyProd = s"$className${border._1}${border._2}" + val emptyProd = "MISSING" val sb = new StringBuilder diff --git a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index f67fa73..89f6783 100644 --- a/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/core/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -11,8 +11,7 @@ object SchemaComparer { ("Actual Schema", "Expected Schema"), actualDS.schema.fields, expectedDS.schema.fields, - truncate = 200, - defaultVal = StructField("SPARK_FAST_TEST_MISSING_FIELD", NullType) + truncate = 200 ) } diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index a8b84f1..09daa83 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -1,6 +1,6 @@ package com.github.mrpowers.spark.fast.tests -import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType} +import org.apache.spark.sql.types.{DoubleType, IntegerType, LongType, StringType} import SparkSessionExt._ import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch import com.github.mrpowers.spark.fast.tests.StringExt.StringOps @@ -310,6 +310,41 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar ) assertLargeDataFrameEquality(sourceDF, expectedDF, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDataFrameEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertApproximateDataFrameEquality" - { diff --git a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index f645661..fb5e9e4 100644 --- a/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/core/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -207,17 +207,13 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes ) ) - val e = intercept[DatasetSchemaMismatch] { + intercept[DatasetSchemaMismatch] { assertLargeDatasetEquality(sourceDF, expectedDF) } - println(e) - val e2 = intercept[DatasetSchemaMismatch] { + + intercept[DatasetSchemaMismatch] { assertSmallDatasetEquality(sourceDF, expectedDF) } - println(e2) - - sourceDF.schema.printTreeString() - expectedDF.schema.printTreeString() } "throws an error if the DataFrames content is different" in { @@ -446,6 +442,41 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes assertLargeDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertLargeDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertLargeDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "assertSmallDatasetEquality" - { @@ -605,9 +636,43 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes Person("alice", 5) ).toDS.select("age", "name").as(ds1.encoder) - assertSmallDatasetEquality(ds1, ds2, ignoreColumnOrder = true) assertSmallDatasetEquality(ds2, ds1, ignoreColumnOrder = true) } + + "correctly mark unequal schema field" in { + val sourceDF = spark.createDF( + List( + (1, 2.0), + (5, 3.0) + ), + List( + ("number", IntegerType, true), + ("float", DoubleType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + (1, "word", 1L), + (5, "word", 2L) + ), + List( + ("number", IntegerType, true), + ("word", StringType, true), + ("long", LongType, true) + ) + ) + + val e = intercept[DatasetSchemaMismatch] { + assertSmallDatasetEquality(sourceDF, expectedDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("word", "StringType", "StructField(long,LongType,true,{})"))) + assert(actualColourGroup.contains(Seq("float", "DoubleType", "MISSING"))) + } } "defaultSortDataset" - {