From d962bffdbc31dd7584767b2ebfeee430c75df700 Mon Sep 17 00:00:00 2001 From: Tuan Pham Date: Sat, 7 Sep 2024 17:25:48 +1000 Subject: [PATCH] Add test validate correct element colour diff --- .../spark/fast/tests/DataFrameComparer.scala | 2 +- .../spark/fast/tests/DataframeUtil.scala | 5 ++- .../spark/fast/tests/DatasetComparer.scala | 26 ++++++++----- .../spark/fast/tests/SchemaComparer.scala | 26 +++++++------ .../spark/fast/tests/ufansi/Fansi.scala | 9 +---- .../fast/tests/ufansi/FansiExtensions.scala | 2 +- .../fast/tests/DataFrameComparerTest.scala | 39 +++++++++++++++++++ .../fast/tests/DatasetComparerTest.scala | 6 +-- .../spark/fast/tests/SchemaComparerTest.scala | 6 ++- 9 files changed, 85 insertions(+), 36 deletions(-) diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala index 78d0a27..eed2d73 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparer.scala @@ -13,7 +13,7 @@ trait DataFrameComparer extends DatasetComparer { ignoreColumnNames: Boolean = false, orderedComparison: Boolean = true, ignoreColumnOrder: Boolean = false, - truncate: Int = 500, + truncate: Int = 500 ): Unit = { assertSmallDatasetEquality( actualDF, diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala index 1a1285c..6cfde87 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/DataframeUtil.scala @@ -34,7 +34,8 @@ object DataframeUtil { val withEquals = actualSeq .zip(expectedSeq) .map { case (actualRowField, expectedRowField) => - (actualRowField, expectedRowField, actualRowField == expectedRowField) } + (actualRowField, expectedRowField, actualRowField == expectedRowField) + } val allFieldsAreNotEqual = !withEquals.exists(_._3) if (allFieldsAreNotEqual) { List( @@ -45,7 +46,7 @@ object DataframeUtil { val coloredDiff = withEquals .map { - case (actualRowField, expectedRowField, true) => + case (actualRowField, expectedRowField, true) => (DarkGray(actualRowField.toString), DarkGray(expectedRowField.toString)) case (actualRowField, expectedRowField, false) => (Red(actualRowField.toString), Green(expectedRowField.toString)) diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala index 9f9c3cd..4f46329 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/DatasetComparer.scala @@ -29,8 +29,8 @@ Expected DataFrame Row Count: '$expectedCount' } /** - * order ds1 column according to ds2 column order - * */ + * order ds1 column according to ds2 column order + */ def orderColumns[T](ds1: Dataset[T], ds2: Dataset[T]): Dataset[T] = { ds1.select(ds2.columns.map(col).toIndexedSeq: _*).as[T](ds2.encoder) } @@ -53,7 +53,13 @@ Expected DataFrame Row Count: '$expectedCount' assertSmallDatasetContentEquality(actual, expectedDS, orderedComparison, truncate, equals) } - def assertSmallDatasetContentEquality[T](actualDS: Dataset[T], expectedDS: Dataset[T], orderedComparison: Boolean, truncate: Int, equals: (T, T) => Boolean): Unit = { + def assertSmallDatasetContentEquality[T]( + actualDS: Dataset[T], + expectedDS: Dataset[T], + orderedComparison: Boolean, + truncate: Int, + equals: (T, T) => Boolean + ): Unit = { if (orderedComparison) assertSmallDatasetContentEquality(actualDS, expectedDS, truncate, equals) else @@ -100,10 +106,12 @@ Expected DataFrame Row Count: '$expectedCount' assertLargeDatasetContentEquality(actual, expectedDS, equals, orderedComparison) } - def assertLargeDatasetContentEquality[T: ClassTag](actualDS: Dataset[T], - expectedDS: Dataset[T], - equals: (T, T) => Boolean, - orderedComparison: Boolean): Unit = { + def assertLargeDatasetContentEquality[T: ClassTag]( + actualDS: Dataset[T], + expectedDS: Dataset[T], + equals: (T, T) => Boolean, + orderedComparison: Boolean + ): Unit = { if (orderedComparison) { assertLargeDatasetContentEquality(actualDS, expectedDS, equals) } else { @@ -123,7 +131,7 @@ Expected DataFrame Row Count: '$expectedCount' throw DatasetCountMismatch(countMismatchMessage(actualCount, expectedCount)) } val expectedIndexValue = RddHelpers.zipWithIndex(ds1RDD) - val resultIndexValue = RddHelpers.zipWithIndex(ds2RDD) + val resultIndexValue = RddHelpers.zipWithIndex(ds2RDD) val unequalRDD = expectedIndexValue .join(resultIndexValue) .filter { case (_, (o1, o2)) => @@ -169,4 +177,4 @@ Expected DataFrame Row Count: '$expectedCount' object DatasetComparer { val maxUnequalRowsToShow = 10 -} \ No newline at end of file +} diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala index 9ff8eef..19fb719 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/SchemaComparer.scala @@ -26,11 +26,13 @@ object SchemaComparer { .mkString("\n") } - def assertSchemaEqual[T](actualDS: Dataset[T], - expectedDS: Dataset[T], - ignoreNullable: Boolean = false, - ignoreColumnNames: Boolean = false, - ignoreColumnOrder: Boolean = true) = { + def assertSchemaEqual[T]( + actualDS: Dataset[T], + expectedDS: Dataset[T], + ignoreNullable: Boolean = false, + ignoreColumnNames: Boolean = false, + ignoreColumnOrder: Boolean = true + ) = { require((ignoreColumnNames, ignoreColumnOrder) != (true, true), "Cannot set both ignoreColumnNames and ignoreColumnOrder to true.") if (!SchemaComparer.equals(actualDS.schema, expectedDS.schema, ignoreNullable, ignoreColumnNames, ignoreColumnOrder)) { throw DatasetSchemaMismatch( @@ -39,11 +41,13 @@ object SchemaComparer { } } - def equals(s1: StructType, - s2: StructType, - ignoreNullable: Boolean = false, - ignoreColumnNames: Boolean = false, - ignoreColumnOrder: Boolean = true): Boolean = { + def equals( + s1: StructType, + s2: StructType, + ignoreNullable: Boolean = false, + ignoreColumnNames: Boolean = false, + ignoreColumnOrder: Boolean = true + ): Boolean = { if (s1.length != s2.length) { false } else { @@ -68,7 +72,7 @@ object SchemaComparer { equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder) case (true, MapType(kdt1, vdt1, _), MapType(kdt2, vdt2, _)) => equals(kdt1, kdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder) && - equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder) + equals(vdt1, vdt2, ignoreNullable, ignoreColumnNames, ignoreColumnOrder) case _ => dt1 == dt2 } } diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/Fansi.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/Fansi.scala index 1891367..2e51488 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/Fansi.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/Fansi.scala @@ -298,13 +298,8 @@ object Str { * An [[ufansi.Str]]'s `color`s array is filled with Long, each representing the ANSI state of one character encoded in its bits. Each [[Attr]] * belongs to a [[Category]] that occupies a range of bits within each long: * - * 61... 55 54 53 52 51 .... 31 30 29 28 27 26 25 ..... 6 5 4 3 2 1 0 - * \|--------| |-----------------------| |-----------------------| | | |bold - * \| | | | |reversed - * \| | | |underlined - * \| | |foreground-color - * \| |background-color - * \|unused + * 61... 55 54 53 52 51 .... 31 30 29 28 27 26 25 ..... 6 5 4 3 2 1 0 \|--------| |-----------------------| |-----------------------| | | |bold \| | + * \| | |reversed \| | | |underlined \| | |foreground-color \| |background-color \|unused * * The `0000 0000 0000 0000` long corresponds to plain text with no decoration */ diff --git a/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala b/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala index 5af5d3d..702db2b 100644 --- a/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala +++ b/src/main/scala/com/github/mrpowers/spark/fast/tests/ufansi/FansiExtensions.scala @@ -4,4 +4,4 @@ object FansiExtensions { def mkStr(start: Str, sep: Str, end: Str): Str = start ++ c.reduce(_ ++ sep ++ _) ++ end } -} \ No newline at end of file +} diff --git a/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala b/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala index 6bcfddf..d8620b3 100644 --- a/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/fast/tests/DataFrameComparerTest.scala @@ -3,6 +3,7 @@ package com.github.mrpowers.spark.fast.tests import org.apache.spark.sql.types.{DoubleType, IntegerType, StringType} import SparkSessionExt._ import com.github.mrpowers.spark.fast.tests.SchemaComparer.DatasetSchemaMismatch +import com.github.mrpowers.spark.fast.tests.StringExt.StringOps import org.scalatest.freespec.AnyFreeSpec class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with SparkSessionTestWrapper { @@ -39,6 +40,44 @@ class DataFrameComparerTest extends AnyFreeSpec with DataFrameComparer with Spar assert(e.getMessage.indexOf("camila") >= 0) } + "Correctly mark unequal elements" in { + val sourceDF = spark.createDF( + List( + ("bob", 1, "uk"), + ("camila", 5, "peru"), + ("steve", 10, "aus") + ), + List( + ("name", StringType, true), + ("age", IntegerType, true), + ("country", StringType, true) + ) + ) + + val expectedDF = spark.createDF( + List( + ("bob", 1, "france"), + ("camila", 5, "peru"), + ("mark", 11, "usa") + ), + List( + ("name", StringType, true), + ("age", IntegerType, true), + ("country", StringType, true) + ) + ) + + val e = intercept[DatasetContentMismatch] { + assertSmallDataFrameEquality(expectedDF, sourceDF) + } + + val colourGroup = e.getMessage.extractColorGroup + val expectedColourGroup = colourGroup.get(Console.GREEN) + val actualColourGroup = colourGroup.get(Console.RED) + assert(expectedColourGroup.contains(Seq("uk", "[steve,10,aus]"))) + assert(actualColourGroup.contains(Seq("france", "[mark,11,usa]"))) + } + "works well for wide DataFrames" in { val sourceDF = spark.createDF( List( diff --git a/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala b/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala index 3832f2d..0ab6b27 100644 --- a/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/fast/tests/DatasetComparerTest.scala @@ -40,7 +40,7 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes } } - "Correctly mark unequal column" in { + "Correctly mark unequal elements" in { val sourceDS = Seq( Person("juan", 5), Person("bob", 1), @@ -59,9 +59,9 @@ class DatasetComparerTest extends AnyFreeSpec with DatasetComparer with SparkSes assertSmallDatasetEquality(sourceDS, expectedDS) } - val colourGroup = e.getMessage.extractColorGroup + val colourGroup = e.getMessage.extractColorGroup val expectedColourGroup = colourGroup.get(Console.GREEN) - val actualColourGroup = colourGroup.get(Console.RED) + val actualColourGroup = colourGroup.get(Console.RED) assert(expectedColourGroup.contains(Seq("[frank,10]", "lucy"))) assert(actualColourGroup.contains(Seq("[bob,1]", "alice"))) } diff --git a/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala b/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala index 103b330..c09f7b1 100644 --- a/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala +++ b/src/test/scala/com/github/mrpowers/spark/fast/tests/SchemaComparerTest.scala @@ -156,7 +156,8 @@ class SchemaComparerTest extends AnyFreeSpec { StructField("mood", ArrayType(StringType, containsNull = false), true), StructField("something", StringType, false) ) - )), + ) + ), true ) ) @@ -174,7 +175,8 @@ class SchemaComparerTest extends AnyFreeSpec { StructField("something", StringType, false), StructField("mood", ArrayType(StringType, containsNull = false), true) ) - )), + ) + ), true ) )