Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added new parameter return_bool to validate dataframe methods #265

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
52 changes: 37 additions & 15 deletions quinn/dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
from __future__ import annotations

import copy
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Union

if TYPE_CHECKING:
from pyspark.sql import DataFrame
Expand All @@ -33,29 +33,37 @@ class DataFrameProhibitedColumnError(ValueError):
"""Raise this when a DataFrame includes prohibited columns."""


def validate_presence_of_columns(df: DataFrame, required_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
    """Validate the presence of column names in a DataFrame.

    :param df: A spark DataFrame.
    :type df: DataFrame
    :param required_col_names: List of the required column names for the DataFrame.
    :type required_col_names: list[str]
    :param return_bool: If True, return a boolean instead of raising an exception.
    :type return_bool: bool
    :return: None if return_bool is False, otherwise a boolean indicating if validation passed.
    :raises DataFrameMissingColumnError: if any of the requested column names are
        not present in the DataFrame and return_bool is False.
    """
    all_col_names = df.columns
    missing_col_names = [x for x in required_col_names if x not in all_col_names]

    if missing_col_names:
        # Build the message only on the failure path; in return_bool mode we
        # report the failure as a value instead of raising.
        error_message = f"The {missing_col_names} columns are not included in the DataFrame with the following columns {all_col_names}"
        if return_bool:
            return False
        raise DataFrameMissingColumnError(error_message)

    return True if return_bool else None


def validate_schema(
df: DataFrame,
required_schema: StructType,
ignore_nullable: bool = False,
) -> None:
return_bool: bool = False,
) -> Union[None, bool]:
"""Function that validate if a given DataFrame has a given StructType as its schema.

:param df: DataFrame to validate
Expand All @@ -65,9 +73,11 @@ def validate_schema(
:param ignore_nullable: (Optional) A flag for if nullable fields should be
ignored during validation
:type ignore_nullable: bool, optional

:param return_bool: If True, return a boolean instead of raising an exception.
:type return_bool: bool
:return: None if return_bool is False, otherwise a boolean indicating if validation passed.
:raises DataFrameMissingStructFieldError: if any StructFields from the required
schema are not included in the DataFrame schema
schema are not included in the DataFrame schema and return_bool is False.
"""
_all_struct_fields = copy.deepcopy(df.schema)
_required_schema = copy.deepcopy(required_schema)
Expand All @@ -80,22 +90,34 @@ def validate_schema(
x.nullable = None

missing_struct_fields = [x for x in _required_schema if x not in _all_struct_fields]
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"


if missing_struct_fields:
error_message = f"The {missing_struct_fields} StructFields are not included in the DataFrame with the following StructFields {_all_struct_fields}"
if return_bool:
return False
raise DataFrameMissingStructFieldError(error_message)

return True if return_bool else None


def validate_absence_of_columns(df: DataFrame, prohibited_col_names: list[str], return_bool: bool = False) -> Union[None, bool]:
    """Validate that none of the prohibited column names are present among specified DataFrame columns.

    :param df: DataFrame containing columns to be checked.
    :param prohibited_col_names: List of prohibited column names.
    :param return_bool: If True, return a boolean instead of raising an exception.
    :type return_bool: bool
    :return: None if return_bool is False, otherwise a boolean indicating if validation passed.
    :raises DataFrameProhibitedColumnError: If the prohibited column names are
        present among the specified DataFrame columns and return_bool is False.
    """
    all_col_names = df.columns
    extra_col_names = [x for x in all_col_names if x in prohibited_col_names]

    if extra_col_names:
        # Build the message only on the failure path; in return_bool mode we
        # report the failure as a value instead of raising.
        error_message = f"The {extra_col_names} columns are not allowed to be included in the DataFrame with the following columns {all_col_names}"
        if return_bool:
            return False
        raise DataFrameProhibitedColumnError(error_message)

    return True if return_bool else None
74 changes: 61 additions & 13 deletions tests/test_dataframe_validator.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,24 +7,36 @@


def describe_validate_presence_of_columns():
    # Exception path: return_bool=False keeps the original raising behavior.
    def it_raises_if_a_required_column_is_missing_and_return_bool_is_false():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        with pytest.raises(quinn.DataFrameMissingColumnError) as excinfo:
            quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], False)
        assert (
            excinfo.value.args[0]
            == "The ['fun'] columns are not included in the DataFrame with the following columns ['name', 'age']"
        )

    # Success path with return_bool=False: no exception, implicit None result.
    def it_does_nothing_if_all_required_columns_are_present_and_return_bool_is_false():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        quinn.validate_presence_of_columns(source_df, ["name"], False)

    # Failure path with return_bool=True: False is returned instead of raising.
    def it_returns_false_if_a_required_column_is_missing_and_return_bool_is_true():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        result = quinn.validate_presence_of_columns(source_df, ["name", "age", "fun"], True)
        assert result is False

    # Success path with return_bool=True: True is returned.
    def it_returns_true_if_all_required_columns_are_present_and_return_bool_is_true():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        result = quinn.validate_presence_of_columns(source_df, ["name"], True)
        assert result is True


def describe_validate_schema():
def it_raises_when_struct_field_is_missing1():
def it_raises_when_struct_field_is_missing_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
Expand All @@ -34,7 +46,7 @@ def it_raises_when_struct_field_is_missing1():
]
)
with pytest.raises(quinn.DataFrameMissingStructFieldError) as excinfo:
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(source_df, required_schema, return_bool = False)

current_spark_version = semver.Version.parse(spark.version)
spark_330 = semver.Version.parse("3.3.0")
Expand All @@ -44,7 +56,7 @@ def it_raises_when_struct_field_is_missing1():
expected_error_message = "The [StructField(city,StringType,true)] StructFields are not included in the DataFrame with the following StructFields StructType(List(StructField(name,StringType,true),StructField(age,LongType,true)))" # noqa
assert excinfo.value.args[0] == expected_error_message

def it_does_nothing_when_the_schema_matches():
def it_does_nothing_when_the_schema_matches_and_return_bool_is_false():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
Expand All @@ -53,7 +65,31 @@ def it_does_nothing_when_the_schema_matches():
StructField("age", LongType(), True),
]
)
quinn.validate_schema(source_df, required_schema)
quinn.validate_schema(source_df, required_schema, return_bool = False)

def it_returns_false_when_struct_field_is_missing_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
[
StructField("name", StringType(), True),
StructField("city", StringType(), True),
]
)
result = quinn.validate_schema(source_df, required_schema, return_bool = True)
assert result is False

def it_returns_true_when_the_schema_matches_and_return_bool_is_true():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
source_df = spark.createDataFrame(data, ["name", "age"])
required_schema = StructType(
[
StructField("name", StringType(), True),
StructField("age", LongType(), True),
]
)
result = quinn.validate_schema(source_df, required_schema, return_bool = True)
assert result is True

def nullable_column_mismatches_are_ignored():
data = [("jose", 1), ("li", 2), ("luisa", 3)]
Expand All @@ -64,21 +100,33 @@ def nullable_column_mismatches_are_ignored():
StructField("age", LongType(), False),
]
)
quinn.validate_schema(source_df, required_schema, ignore_nullable=True)
quinn.validate_schema(source_df, required_schema, ignore_nullable=True, return_bool = False)


def describe_validate_absence_of_columns():
    # Exception path: return_bool=False keeps the original raising behavior.
    def it_raises_when_a_unallowed_column_is_present_and_return_bool_is_false():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        with pytest.raises(quinn.DataFrameProhibitedColumnError) as excinfo:
            quinn.validate_absence_of_columns(source_df, ["age", "cool"], False)
        assert (
            excinfo.value.args[0]
            == "The ['age'] columns are not allowed to be included in the DataFrame with the following columns ['name', 'age']"  # noqa
        )

    # Success path with return_bool=False: no exception, implicit None result.
    def it_does_nothing_when_no_unallowed_columns_are_present_and_return_bool_is_false():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        quinn.validate_absence_of_columns(source_df, ["favorite_color"], False)

    # Failure path with return_bool=True: False is returned instead of raising.
    def it_returns_false_when_a_unallowed_column_is_present_and_return_bool_is_true():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        result = quinn.validate_absence_of_columns(source_df, ["age", "cool"], True)
        assert result is False

    # Success path with return_bool=True: True is returned.
    def it_returns_true_when_no_unallowed_columns_are_present_and_return_bool_is_true():
        data = [("jose", 1), ("li", 2), ("luisa", 3)]
        source_df = spark.createDataFrame(data, ["name", "age"])
        result = quinn.validate_absence_of_columns(source_df, ["favorite_color"], True)
        assert result is True
Loading