Skip to content

Commit

Permalink
Add color diff for small dataframe comparison
Browse files Browse the repository at this point in the history
  • Loading branch information
zem00n committed Aug 30, 2024
1 parent 6e53ec5 commit d553270
Show file tree
Hide file tree
Showing 2 changed files with 110 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.SeqLikesExtensions.SeqExtensions
import org.apache.spark.sql.{DataFrame, Row}

trait DataFrameComparer extends DatasetComparer {

/**
Expand All @@ -14,17 +14,24 @@ trait DataFrameComparer extends DatasetComparer {
ignoreColumnNames: Boolean = false,
orderedComparison: Boolean = true,
ignoreColumnOrder: Boolean = false,
truncate: Int = 500
truncate: Int = 500,
): Unit = {
assertSmallDatasetEquality(
actualDF,
expectedDF,
ignoreNullable,
ignoreColumnNames,
orderedComparison,
ignoreColumnOrder,
truncate
)
SchemaComparer.assertSchemaEqual(actualDF, expectedDF, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)
val actual = if (ignoreColumnOrder) orderColumns(actualDF, expectedDF) else actualDF
if (orderedComparison)
assertSmallDataFrameEquality(actual, expectedDF, truncate)
else
assertSmallDataFrameEquality(defaultSortDataset(actual), defaultSortDataset(expectedDF), truncate)
}

/**
 * Collects both DataFrames to the driver and checks that their rows match, failing with a
 * side-by-side diff when they do not.
 *
 * Rows are compared with `approximateSameElements`, using `Row.equals` as the element comparison.
 *
 * @param actualDF   DataFrame produced by the code under test
 * @param expectedDF DataFrame holding the expected rows
 * @param truncate   forwarded unchanged to `DataframeUtil.showDataframeDiff`
 * @throws DatasetContentMismatch when the collected rows are not considered the same
 */
def assertSmallDataFrameEquality(actualDF: DataFrame, expectedDF: DataFrame, truncate: Int): Unit = {
  val actualRows   = actualDF.collect()
  val expectedRows = expectedDF.collect()
  val rowsMatch    = actualRows.toSeq.approximateSameElements(expectedRows, (left: Row, right: Row) => left.equals(right))
  if (!rowsMatch) {
    val header = ("Actual Content", "Expected Content")
    throw DatasetContentMismatch("Diffs\n" ++ DataframeUtil.showDataframeDiff(header, actualRows, expectedRows, truncate))
  }
}

/**
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
package com.github.mrpowers.spark.fast.tests

import com.github.mrpowers.spark.fast.tests.ufansi.Color.{DarkGray, Green, Red}
import org.apache.commons.lang3.StringUtils
import org.apache.spark.sql.Row

/**
 * Renders a colour-coded, side-by-side text table contrasting two collections of rows.
 */
object DataframeUtil {

  /**
   * Builds a two-column table: actual rows on the left, expected rows on the right.
   *
   * Rows are paired by position. A pair of equal rows is printed in dark gray. For a pair of
   * differing rows each cell is coloured individually: matching cells in dark gray, mismatching
   * cells in red on the actual side and green on the expected side. When one side has more rows
   * than the other, the surplus rows are still printed, next to an empty cell.
   *
   * Differing rows are compared cell by cell up to the shorter row's length.
   * NOTE(review): this assumes paired rows have the same number of fields — confirm callers
   * validate the schema first.
   *
   * @param header   titles of the actual and expected columns, in that order
   * @param actual   rows shown in the left column
   * @param expected rows shown in the right column
   * @param truncate only its sign is used: a positive value right-aligns every cell, any other
   *                 value left-aligns; no text is shortened by this method
   * @return the rendered table, including the escape codes produced by `ufansi`
   */
  def showDataframeDiff(
      header: (String, String),
      actual: Array[Row],
      expected: Array[Row],
      truncate: Int = 20
  ): String = {
    // Every column is at least this wide, even when all of its cells are shorter.
    val minColWidth = 3
    val rightAlign  = truncate > 0

    // Wrap rows in Option so that padding with None keeps the surplus rows of the longer side
    // (a plain zip would silently drop them).
    val actualRows   = actual.toSeq.map(Option(_))
    val expectedRows = expected.toSeq.map(Option(_))

    val diff = actualRows.zipAll(expectedRows, None, None).map {
      // `a == e` rather than `equals(a, e)`: the latter is auto-tupled into
      // `this.equals((a, e))` and is therefore always false.
      case (Some(a), Some(e)) if a == e =>
        List(DarkGray(a.toString()), DarkGray(e.toString()))

      case (Some(a), Some(e)) =>
        val cells = a.toSeq.zip(e.toSeq).map { case (a1, e1) =>
          // String.valueOf renders a null cell as "null" instead of throwing.
          val left  = String.valueOf(a1)
          val right = String.valueOf(e1)
          if (a1 == e1) (DarkGray(left), DarkGray(right)) else (Red(left), Green(right))
        }
        List(cells.map(_._1), cells.map(_._2)).map { side =>
          // reduceOption keeps rows without any field from throwing.
          side
            .reduceOption(_ ++ DarkGray(",") ++ _)
            .fold(DarkGray("[]"))(joined => DarkGray("[") ++ joined ++ DarkGray("]"))
        }

      case (Some(a), None) =>
        List(Red(a.toString()), Green(""))

      case (None, Some(e)) =>
        List(Red(""), Green(e.toString()))

      // Only reachable when both inputs hold a null row at the same position.
      case (None, None) =>
        List(DarkGray(""), DarkGray(""))
    }

    val headerCells = List(header._1, header._2)

    // Width of a column = widest visible cell in it, never below minColWidth.
    val cellWidths = headerCells.map(_.length) +: diff.map(_.map(_.length))
    val colWidths = cellWidths.foldLeft(List.fill(headerCells.size)(minColWidth)) { (widest, row) =>
      widest.zip(row).map { case (soFar, len) => math.max(soFar, len) }
    }

    val sep = colWidths.map("-" * _).mkString("+", "+", "+\n")

    val headerLine = headerCells
      .zip(colWidths)
      .map { case (cell, width) =>
        if (rightAlign) StringUtils.leftPad(cell, width) else StringUtils.rightPad(cell, width)
      }
      .mkString("|", "|", "|\n")

    val body = diff.map { row =>
      row
        .zip(colWidths)
        .map { case (cell, width) =>
          // Pad by the visible length; cell.toString also carries the colour escape codes.
          val pads = " " * math.max(0, width - cell.length)
          if (rightAlign) pads + cell.toString else cell.toString + pads
        }
        .mkString("|", "|", "|\n")
    }.mkString

    sep + headerLine + sep + body + sep
  }

}

0 comments on commit d553270

Please sign in to comment.