Skip to content

Commit

Permalink
Add test validate correct element colour diff
Browse files Browse the repository at this point in the history
  • Loading branch information
zeotuan committed Sep 7, 2024
1 parent dd35f31 commit d962bff
Show file tree
Hide file tree
Showing 9 changed files with 85 additions and 36 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ trait DataFrameComparer extends DatasetComparer {
ignoreColumnNames: Boolean = false,
orderedComparison: Boolean = true,
ignoreColumnOrder: Boolean = false,
truncate: Int = 500,
truncate: Int = 500
): Unit = {
assertSmallDatasetEquality(
actualDF,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,7 +34,8 @@ object DataframeUtil {
val withEquals = actualSeq
.zip(expectedSeq)
.map { case (actualRowField, expectedRowField) =>
(actualRowField, expectedRowField, actualRowField == expectedRowField) }
(actualRowField, expectedRowField, actualRowField == expectedRowField)
}
val allFieldsAreNotEqual = !withEquals.exists(_._3)
if (allFieldsAreNotEqual) {
List(
Expand All @@ -45,7 +46,7 @@ object DataframeUtil {

val coloredDiff = withEquals
.map {
case (actualRowField, expectedRowField, true) =>
case (actualRowField, expectedRowField, true) =>
(DarkGray(actualRowField.toString), DarkGray(expectedRowField.toString))
case (actualRowField, expectedRowField, false) =>
(Red(actualRowField.toString), Green(expectedRowField.toString))
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,8 @@ Expected DataFrame Row Count: '$expectedCount'
}

/**
* order ds1 column according to ds2 column order
* */
* order ds1 column according to ds2 column order
*/
def orderColumns[T](ds1: Dataset[T], ds2: Dataset[T]): Dataset[T] = {
ds1.select(ds2.columns.map(col).toIndexedSeq: _*).as[T](ds2.encoder)
}
Expand All @@ -53,7 +53,13 @@ Expected DataFrame Row Count: '$expectedCount'
assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals)
}

def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], orderedComparison: Boolean, truncate: Int, equals: (T, T) => Boolean): Unit = {
def assertSmallDatasetContentEquality[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
orderedComparison: Boolean,
truncate: Int,
equals: (T, T) => Boolean
): Unit = {
if (orderedComparison)
assertSmallDatasetContentEquality(actualDS, expectedDS, truncate, equals)
else
Expand Down Expand Up @@ -100,10 +106,12 @@ Expected DataFrame Row Count: '$expectedCount'
assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison)
}

def assertLargeDatasetContentEquality[T: ClassTag](actualDS: Dataset[T],
expectedDS: Dataset[T],
equals: (T, T) => Boolean,
orderedComparison: Boolean): Unit = {
def assertLargeDatasetContentEquality[T: ClassTag](
actualDS: Dataset[T],
expectedDS: Dataset[T],
equals: (T, T) => Boolean,
orderedComparison: Boolean
): Unit = {
if (orderedComparison) {
assertLargeDatasetContentEquality(actualDS, expectedDS, equals)
} else {
Expand All @@ -123,7 +131,7 @@ Expected DataFrame Row Count: '$expectedCount'
throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount))
}
val expectedIndexValue = RddHelpers.zipWithIndex(ds1RDD)
val resultIndexValue = RddHelpers.zipWithIndex(ds2RDD)
val resultIndexValue = RddHelpers.zipWithIndex(ds2RDD)
val unequalRDD = expectedIndexValue
.join(resultIndexValue)
.filter { case (_, (o1, o2)) =>
Expand Down Expand Up @@ -169,4 +177,4 @@ Expected DataFrame Row Count: '$expectedCount'

object DatasetComparer {
val maxUnequalRowsToShow = 10
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,13 @@ object SchemaComparer {
.mkString("\n")
}

def assertSchemaEqual[T](actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true) = {
def assertSchemaEqual[T](
actualDS: Dataset[T],
expectedDS: Dataset[T],
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true
) = {
require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.")
if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) {
throw DatasetSchemaMismatch(
Expand All @@ -39,11 +41,13 @@ object SchemaComparer {
}
}

def equals(s1: StructType,
s2: StructType,
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true): Boolean = {
def equals(
s1: StructType,
s2: StructType,
ignoreNullable: Boolean = false,
ignoreColumnNames: Boolean = false,
ignoreColumnOrder: Boolean = true
): Boolean = {
if (s1.length != s2.length) {
false
} else {
Expand All @@ -68,7 +72,7 @@ object SchemaComparer {
equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)
case (true, MapType(kdt1, vdt1, _), MapType(kdt2, vdt2, _)) =>
equals(kdt1, kdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder) &&
equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)
equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)
case _ => dt1 == dt2
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -298,13 +298,8 @@ object Str {
* An [[ufansi.Str]]'s `color`s array is filled with Long, each representing the ANSI state of one character encoded in its bits. Each [[Attr]]
* belongs to a [[Category]] that occupies a range of bits within each long:
*
* 61... 55 54 53 52 51 .... 31 30 29 28 27 26 25 ..... 6 5 4 3 2 1 0
* \|--------| |-----------------------| |-----------------------| | | |bold
* \| | | | |reversed
* \| | | |underlined
* \| | |foreground-color
* \| |background-color
* \|unused
* 61... 55 54 53 52 51 .... 31 30 29 28 27 26 25 ..... 6 5 4 3 2 1 0 \|--------| |-----------------------| |-----------------------| | | |bold \| |
* \| | |reversed \| | | |underlined \| | |foreground-color \| |background-color \|unused
*
* The `0000 0000 0000 0000` long corresponds to plain text with no decoration
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,4 @@ object FansiExtensions {
def mkStr(start: Str, sep: Str, end: Str): Str =
start ++ c.reduce(_ ++ sep ++ _) ++ end
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package com.github.mrpowers.spark.fast.tests
import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType}
import SparkSessionExt._
import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch
import com.github.mrpowers.spark.fast.tests.StringExt.StringOps
import org.scalatest.freespec.AnyFreeSpec

class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with SparkSessionTestWrapper {
Expand Down Expand Up @@ -39,6 +40,44 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar
assert(e.getMessage.indexOf("camila") >= 0)
}

"Correctly mark unequal elements" in {
val sourceDF = spark.createDF(
List(
("bob", 1, "uk"),
("camila", 5, "peru"),
("steve", 10, "aus")
),
List(
("name", StringType, true),
("age", IntegerType, true),
("country", StringType, true)
)
)

val expectedDF = spark.createDF(
List(
("bob", 1, "france"),
("camila", 5, "peru"),
("mark", 11, "usa")
),
List(
("name", StringType, true),
("age", IntegerType, true),
("country", StringType, true)
)
)

val e = intercept[DatasetContentMismatch] {
assertSmallDataFrameEquality(expectedDF, sourceDF)
}

val colourGroup = e.getMessage.extractColorGroup
val expectedColourGroup = colourGroup.get(Console.GREEN)
val actualColourGroup = colourGroup.get(Console.RED)
assert(expectedColourGroup.contains(Seq("uk", "[steve,10,aus]")))
assert(actualColourGroup.contains(Seq("france", "[mark,11,usa]")))
}

"works well for wide DataFrames" in {
val sourceDF = spark.createDF(
List(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes
}
}

"Correctly mark unequal column" in {
"Correctly mark unequal elements" in {
val sourceDS = Seq(
Person("juan", 5),
Person("bob", 1),
Expand All @@ -59,9 +59,9 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes
assertSmallDatasetEquality(sourceDS, expectedDS)
}

val colourGroup = e.getMessage.extractColorGroup
val colourGroup = e.getMessage.extractColorGroup
val expectedColourGroup = colourGroup.get(Console.GREEN)
val actualColourGroup = colourGroup.get(Console.RED)
val actualColourGroup = colourGroup.get(Console.RED)
assert(expectedColourGroup.contains(Seq("[frank,10]", "lucy")))
assert(actualColourGroup.contains(Seq("[bob,1]", "alice")))
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -156,7 +156,8 @@ class SchemaComparerTest extends AnyFreeSpec {
StructField("mood", ArrayType(StringType, containsNull = false), true),
StructField("something", StringType, false)
)
)),
)
),
true
)
)
Expand All @@ -174,7 +175,8 @@ class SchemaComparerTest extends AnyFreeSpec {
StructField("something", StringType, false),
StructField("mood", ArrayType(StringType, containsNull = false), true)
)
)),
)
),
true
)
)
Expand Down

0 comments on commit d962bff

Please sign in to comment.