Add allow_nan_equality option to assert_approx_df_equality #29

Draft · wants to merge 12 commits into main
17 changes: 17 additions & 0 deletions CHANGELOG.md
@@ -0,0 +1,17 @@
# Changelog

All notable changes to this project will be documented in this file.

The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).

## Unreleased

### Changed
- `DataFramesNotEqualError` has been renamed to `RowsNotEqualError` to reflect that it is raised when testing rows for equality.
- The assertion functions `assert_df_equality` and `assert_column_equality` now accept an optional `precision` parameter for testing approximate equality.

### Removed
- Removed `are_dfs_equal` because it has been superseded by other parts of the API.
- Removed `assert_approx_df_equality` as it has been replaced by adding the optional `precision` parameter to `assert_df_equality`.
- Removed `assert_approx_column_equality` as it has been replaced by adding the optional `precision` parameter to `assert_column_equality`.
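
For readers upgrading, here is a minimal sketch of the migration these entries describe (the DataFrames and values are illustrative, not taken from the PR):

```python
from pyspark.sql import SparkSession
from chispa import assert_df_equality

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([(1.1, "a"), (2.2, "b")], ["num", "letter"])
df2 = spark.createDataFrame([(1.05, "a"), (2.13, "b")], ["num", "letter"])

# Before: assert_approx_df_equality(df1, df2, 0.1)
# After: the optional precision parameter replaces the approx variant.
assert_df_equality(df1, df2, precision=0.1)
```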
12 changes: 6 additions & 6 deletions README.md
@@ -307,7 +307,7 @@ def test_approx_col_equality_same():
         (None, None)
     ]
     df = spark.createDataFrame(data, ["num1", "num2"])
-    assert_approx_column_equality(df, "num1", "num2", 0.1)
+    assert_column_equality(df, "num1", "num2", precision=0.1)
```

Here's an example of a test with columns that are not approximately equal.
@@ -321,7 +321,7 @@ def test_approx_col_equality_different():
         (None, None)
     ]
     df = spark.createDataFrame(data, ["num1", "num2"])
-    assert_approx_column_equality(df, "num1", "num2", 0.1)
+    assert_column_equality(df, "num1", "num2", precision=0.1)
```

This failing test will output a readable error message so the issue is easy to debug.
@@ -350,10 +350,10 @@ def test_approx_df_equality_same():
     ]
     df2 = spark.createDataFrame(data2, ["num", "letter"])

-    assert_approx_df_equality(df1, df2, 0.1)
+    assert_df_equality(df1, df2, precision=0.1)
```

-The `assert_approx_df_equality` method is smart and will only perform approximate equality operations for floating point numbers in DataFrames. It'll perform regular equality for strings and other types.
+The `assert_df_equality` method has a `precision` parameter that lets the user control the absolute tolerance for floating point differences accepted by the assertion. It is smart and will only perform approximate equality operations for floating point numbers in DataFrames. It'll perform regular equality for strings and other types.

Let's perform an approximate equality comparison for two DataFrames that are not equal.

@@ -375,7 +375,7 @@ def test_approx_df_equality_different():
     ]
     df2 = spark.createDataFrame(data2, ["num", "letter"])

-    assert_approx_df_equality(df1, df2, 0.1)
+    assert_df_equality(df1, df2, precision=0.1)
```

Here's the pretty error message that's outputted:
@@ -384,7 +384,7 @@ Here's the pretty error message that's outputted:

## Schema mismatch messages

-DataFrame equality messages peform schema comparisons before analyzing the actual content of the DataFrames. DataFrames that don't have the same schemas should error out as fast as possible.
+DataFrame equality messages perform schema comparisons before analyzing the actual content of the DataFrames. DataFrames that don't have the same schemas should error out as fast as possible.

Let's compare a DataFrame that has a string column and an integer column with a DataFrame that has two integer columns to observe the schema mismatch message.

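The README diff above does not yet demonstrate the PR's headline option. Here is a minimal sketch of how `allow_nan_equality` would be used in a column test, assuming the signatures in this diff land as-is:

```python
from pyspark.sql import SparkSession
from chispa import assert_column_equality

spark = SparkSession.builder.getOrCreate()

data = [(float("nan"), float("nan")), (1.0, 1.0)]
df = spark.createDataFrame(data, ["num1", "num2"])

# NaN != NaN by default, so this would fail without allow_nan_equality=True.
assert_column_equality(df, "num1", "num2", allow_nan_equality=True)
```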
5 changes: 3 additions & 2 deletions chispa/__init__.py
@@ -25,5 +25,6 @@
     print("Can't find Apache Spark. Please set environment variable SPARK_HOME to root of installation!")
     exit(-1)

-from .dataframe_comparer import DataFramesNotEqualError, assert_df_equality, assert_approx_df_equality
-from .column_comparer import ColumnsNotEqualError, assert_column_equality, assert_approx_column_equality
+from .dataframe_comparer import assert_df_equality
+from .column_comparer import assert_column_equality, ColumnsNotEqualError
+from .row_comparer import RowsNotEqualError
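
After this change the public import surface would look like this (a sketch implied by the new `__init__.py` and the changelog):

```python
# Approximate variants are gone; precision is now a parameter, and
# RowsNotEqualError replaces DataFramesNotEqualError.
from chispa import assert_df_equality, assert_column_equality
from chispa import ColumnsNotEqualError, RowsNotEqualError
```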
91 changes: 56 additions & 35 deletions chispa/column_comparer.py
@@ -1,53 +1,74 @@
-from chispa.bcolors import *
+from typing import Optional, Any
+
+from pyspark.sql import DataFrame
+from pyspark.sql.types import DataType
+
+from chispa.bcolors import blue
 from chispa.prettytable import PrettyTable
+from chispa.number_helpers import check_equal
 
 
 class ColumnsNotEqualError(Exception):
     """The columns are not equal"""
     pass
 
 
-def assert_column_equality(df, col_name1, col_name2):
-    elements = df.select(col_name1, col_name2).collect()
-    colName1Elements = list(map(lambda x: x[0], elements))
-    colName2Elements = list(map(lambda x: x[1], elements))
-    if colName1Elements != colName2Elements:
-        zipped = list(zip(colName1Elements, colName2Elements))
-        t = PrettyTable([col_name1, col_name2])
-        for elements in zipped:
-            if elements[0] == elements[1]:
-                first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
-                second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
-                t.add_row([first, second])
-            else:
-                t.add_row([str(elements[0]), str(elements[1])])
-        raise ColumnsNotEqualError("\n" + t.get_string())
-
-
-def assert_approx_column_equality(df, col_name1, col_name2, precision):
-    elements = df.select(col_name1, col_name2).collect()
-    colName1Elements = list(map(lambda x: x[0], elements))
-    colName2Elements = list(map(lambda x: x[1], elements))
-    all_rows_equal = True
-    zipped = list(zip(colName1Elements, colName2Elements))
-    t = PrettyTable([col_name1, col_name2])
-    for elements in zipped:
-        first = bcolors.LightBlue + str(elements[0]) + bcolors.LightRed
-        second = bcolors.LightBlue + str(elements[1]) + bcolors.LightRed
-        # when one is None and the other isn't, they're not equal
-        if (elements[0] == None and elements[1] != None) or (elements[0] != None and elements[1] == None):
-            all_rows_equal = False
-            t.add_row([str(elements[0]), str(elements[1])])
-        # when both are None, they're equal
-        elif elements[0] == None and elements[1] == None:
-            t.add_row([first, second])
-        # when the diff is less than the threshhold, they're approximately equal
-        elif abs(elements[0] - elements[1]) < precision:
-            t.add_row([first, second])
-        # otherwise, they're not equal
-        else:
-            all_rows_equal = False
-            t.add_row([str(elements[0]), str(elements[1])])
-    if all_rows_equal == False:
-        raise ColumnsNotEqualError("\n" + t.get_string())
+def assert_column_equality(
+    df: DataFrame,
+    col_name1: str,
+    col_name2: str,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> None:
+    """Assert that two columns in a PySpark DataFrame are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
+
+    """
+    all_rows_equal = True
+    t = PrettyTable([col_name1, col_name2])
+
+    # Zip both columns together for iterating through elements.
+    columns = df.select(col_name1, col_name2).collect()
+    zipped = zip(*[list(map(lambda x: x[i], columns)) for i in [0, 1]])
+
+    for elements in zipped:
+        if are_elements_equal(*elements, precision, allow_nan_equality):
+            t.add_row([blue(e) for e in elements])
+        else:
+            all_rows_equal = False
+            t.add_row([str(e) for e in elements])
+
+    if all_rows_equal == False:
+        raise ColumnsNotEqualError("\n" + t.get_string())
+
+
+def are_elements_equal(
+    e1: DataType,
+    e2: DataType,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> bool:
+    """
+    Return True if both elements are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
+
+    """
+    # If both elements are None they are considered equal.
+    if e1 is None and e2 is None:
+        return True
+    if (e1 is None and e2 is not None) or (e2 is None and e1 is not None):
+        return False
+
+    return check_equal(e1, e2, precision, allow_nan_equality)
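
A few illustrative truth cases for the new helper (a sketch based on the code above, not part of the diff):

```python
from chispa.column_comparer import are_elements_equal

assert are_elements_equal(None, None)                      # both None -> equal
assert not are_elements_equal(None, 1.0)                   # one None -> not equal
assert are_elements_equal(1.05, 1.1, precision=0.1)        # within absolute tolerance
assert not are_elements_equal(float("nan"), float("nan"))  # NaN != NaN by default
assert are_elements_equal(float("nan"), float("nan"), allow_nan_equality=True)
```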
103 changes: 38 additions & 65 deletions chispa/dataframe_comparer.py
@@ -1,78 +1,51 @@
-from chispa.prettytable import PrettyTable
-from chispa.bcolors import *
-from chispa.schema_comparer import assert_schema_equality
-from chispa.row_comparer import *
-import chispa.six as six
 from functools import reduce
+from typing import Callable, Optional
+
+from pyspark.sql import DataFrame
+
+from chispa.schema_comparer import assert_schema_equality
+from chispa.row_comparer import assert_rows_equality
 
 
-class DataFramesNotEqualError(Exception):
-    """The DataFrames are not equal"""
-    pass
-
-
-def assert_df_equality(df1, df2, ignore_nullable=False, transforms=None, allow_nan_equality=False, ignore_column_order=False, ignore_row_order=False):
+def assert_df_equality(
+    df1: DataFrame,
+    df2: DataFrame,
+    precision: Optional[float] = None,
+    ignore_nullable: bool = False,
+    allow_nan_equality: bool = False,
+    ignore_column_order: bool = False,
+    ignore_row_order: bool = False,
+    transforms: Callable[[DataFrame], DataFrame] = None,
+) -> None:
+    """Assert that two PySpark DataFrames are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    ignore_nullable : bool, default False
+        Ignore nullable option when comparing schemas.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
+    ignore_column_order : bool, default False
+        When True, sorts columns before comparing.
+    ignore_row_order : bool, default False
+        When True, sorts all rows before comparing.
+    transforms : callable
+        Additional transforms to make to DataFrame before comparison.
+
+    """
+    # Apply row and column order transforms + custom transforms.
     if transforms is None:
         transforms = []
     if ignore_column_order:
         transforms.append(lambda df: df.select(sorted(df.columns)))
     if ignore_row_order:
         transforms.append(lambda df: df.sort(df.columns))
 
     df1 = reduce(lambda acc, fn: fn(acc), transforms, df1)
     df2 = reduce(lambda acc, fn: fn(acc), transforms, df2)
-    assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
-    if allow_nan_equality:
-        assert_generic_rows_equality(df1, df2, are_rows_equal_enhanced, [True])
-    else:
-        assert_basic_rows_equality(df1, df2)
-
-
-def are_dfs_equal(df1, df2):
-    if df1.schema != df2.schema:
-        return False
-    if df1.collect() != df2.collect():
-        return False
-    return True
-
-
-def assert_approx_df_equality(df1, df2, precision, ignore_nullable=False):
-    assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
-    assert_generic_rows_equality(df1, df2, are_rows_approx_equal, [precision])
-
-
-def assert_generic_rows_equality(df1, df2, row_equality_fun, row_equality_fun_args):
-    df1_rows = df1.collect()
-    df2_rows = df2.collect()
-    zipped = list(six.moves.zip_longest(df1_rows, df2_rows))
-    t = PrettyTable(["df1", "df2"])
-    allRowsEqual = True
-    for r1, r2 in zipped:
-        # rows are not equal when one is None and the other isn't
-        if (r1 is not None and r2 is None) or (r2 is not None and r1 is None):
-            allRowsEqual = False
-            t.add_row([r1, r2])
-        # rows are equal
-        elif row_equality_fun(r1, r2, *row_equality_fun_args):
-            first = bcolors.LightBlue + str(r1) + bcolors.LightRed
-            second = bcolors.LightBlue + str(r2) + bcolors.LightRed
-            t.add_row([first, second])
-        # otherwise, rows aren't equal
-        else:
-            allRowsEqual = False
-            t.add_row([r1, r2])
-    if allRowsEqual == False:
-        raise DataFramesNotEqualError("\n" + t.get_string())
-
-
-def assert_basic_rows_equality(df1, df2):
-    rows1 = df1.collect()
-    rows2 = df2.collect()
-    if rows1 != rows2:
-        t = PrettyTable(["df1", "df2"])
-        zipped = list(six.moves.zip_longest(rows1, rows2))
-        for r1, r2 in zipped:
-            if r1 == r2:
-                t.add_row([blue(r1), blue(r2)])
-            else:
-                t.add_row([r1, r2])
-        raise DataFramesNotEqualError("\n" + t.get_string())
+
+    # Check schema and row equality.
+    assert_schema_equality(df1.schema, df2.schema, ignore_nullable)
+    assert_rows_equality(df1, df2, precision, allow_nan_equality)
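
A sketch of the consolidated API in use, combining the new options (illustrative; it assumes `assert_rows_equality` in `row_comparer` honours `precision` and `allow_nan_equality` as the signature suggests):

```python
from pyspark.sql import SparkSession
from chispa import assert_df_equality

spark = SparkSession.builder.getOrCreate()

df1 = spark.createDataFrame([(1.1, float("nan")), (2.2, 3.3)], ["a", "b"])
df2 = spark.createDataFrame([(2.24, 3.3), (1.05, float("nan"))], ["a", "b"])

# Rows are sorted first, floats are compared with absolute tolerance 0.1,
# and two NaNs are treated as equal.
assert_df_equality(df1, df2, precision=0.1,
                   allow_nan_equality=True, ignore_row_order=True)
```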
27 changes: 25 additions & 2 deletions chispa/number_helpers.py
@@ -1,4 +1,5 @@
 import math
+from typing import Optional
 
 
@@ -8,5 +9,27 @@ def isnan(x):
     return False
 
 
-def nan_safe_equality(x, y) -> bool:
-    return (x == y) or (isnan(x) and isnan(y))
+def check_equal(
+    x, y,
+    precision: Optional[float] = None,
+    allow_nan_equality: bool = False,
+) -> bool:
+    """Return True if x and y are equal.
+
+    Parameters
+    ----------
+    precision : float, optional
+        Absolute tolerance when checking for equality.
+    allow_nan_equality : bool, default False
+        When True, treats two NaN values as equal.
+
+    """
+    both_floats = (isinstance(x, float) & isinstance(y, float))
+    if (precision is not None) & both_floats:
+        both_equal = abs(x - y) < precision
+    else:
+        both_equal = (x == y)
+
+    both_nan = (isnan(x) and isnan(y)) if allow_nan_equality else False
+
+    return both_equal or both_nan

Review thread on the `both_floats` line:

MrPowers (Owner): I don't think we need to do this cause we know the types from the schema.

mitches-got-glitches (Contributor, Author): Hi @MrPowers, can you have a look at my draft pull request here to see if I'm going along the right lines with your 3rd objective?

To summarise, rather than the following line in check_equal:

    both_floats = (isinstance(x, float) & isinstance(y, float))

I'm doing this:

    is_float_type = (dtype_name in ['float', 'double', 'decimal'])

Where dtype gets passed into the function, and is created with the following line:

    dtypes = [field.dataType.typeName() for field in df1.schema]

Because we've already compared the schemas, we know that they're equal, so we just use df1.schema.

Will using isinstance really have a noticeable effect on the speed of the tests? I feel like it might be the safer option.

mitches-got-glitches (Contributor, Author): I just did some further testing too, and my current solution would not work for PySpark columns with DecimalType, as the Python type is decimal.Decimal, not float. I'd need to change that if the original implementation is selected.

mitches-got-glitches (Contributor, Author, May 17, 2021): Any time to give me a steer this week @MrPowers? As in, point me in the right direction for the solution (sorry, realised "giving a steer" might be a localised phrase) - where are you from anyway?
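
For reference, a minimal sketch of the schema-based alternative the author describes in the thread above (hypothetical helper name; not part of this diff):

```python
from typing import List

from pyspark.sql import DataFrame

# Float-like Spark type names that should get approximate comparison.
FLOAT_TYPE_NAMES = ["float", "double", "decimal"]

def float_type_flags(df1: DataFrame) -> List[bool]:
    # One flag per column: True when approximate comparison should apply.
    # Inspecting df1.schema alone is safe because the schemas were already
    # compared for equality before the rows are checked.
    return [field.dataType.typeName() in FLOAT_TYPE_NAMES for field in df1.schema]
```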