diff --git a/README.md b/README.md
index 62d96b8d..af642710 100644
--- a/README.md
+++ b/README.md
@@ -230,10 +230,10 @@ Converts all the column names in a DataFrame to snake_case. It's annoying to wri
 **sort_columns()**
 
 ```python
-quinn.sort_columns(df=source_df, sort_order="asc", sort_nested=True)
+quinn.sort_columns(source_df, "asc")
 ```
 
-Sorts the DataFrame columns in alphabetical order, including nested columns if sort_nested is set to True. Wide DataFrames are easier to navigate when they're sorted alphabetically.
+Sorts the DataFrame columns in alphabetical order. Wide DataFrames are easier to navigate when they're sorted alphabetically.
 
 ### DataFrame Helpers
 
diff --git a/quinn/transformations.py b/quinn/transformations.py
index 31f59212..7b64faa3 100644
--- a/quinn/transformations.py
+++ b/quinn/transformations.py
@@ -1,9 +1,10 @@
 import re
-import pyspark.sql.functions as F  # noqa: N812
-from __future__ import annotations
 from collections.abc import Callable
+
+import pyspark.sql.functions as F  # noqa: N812
 from pyspark.sql import DataFrame
-from pyspark.sql.types import ArrayType, MapType, StructField, StructType
+from pyspark.sql.types import ArrayType, MapType, StructType
+
 from quinn.schema_helpers import complex_fields
 
 
@@ -83,161 +84,30 @@ def to_snake_case(s: str) -> str:
     return s.lower().replace(" ", "_")
 
 
-def sort_columns(
-    df: DataFrame, sort_order: str, sort_nested: bool = False
-) -> DataFrame:
-    """This function sorts the columns of a given DataFrame based on a given sort
-    order. The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
+def sort_columns(df: DataFrame, sort_order: str) -> DataFrame:
+    """Sorts the columns of a given DataFrame based on a given sort order.
+
+    The ``sort_order`` parameter can either be ``asc`` or ``desc``, which correspond to
     ascending and descending order, respectively. If any other value is provided for
     the ``sort_order`` parameter, a ``ValueError`` will be raised.
 
     :param df: A DataFrame
     :type df: pyspark.sql.DataFrame
     :param sort_order: The order in which to sort the columns in the DataFrame
     :type sort_order: str
-    :param sort_nested: Whether to sort nested structs or not. Defaults to false.
-    :type sort_nested: bool
     :return: A DataFrame with the columns sorted in the chosen order
     :rtype: pyspark.sql.DataFrame
     """
-
-    def sort_nested_cols(schema, is_reversed, base_field="") -> list[str]:
-        # recursively check nested fields and sort them
-        # https://stackoverflow.com/questions/57821538/how-to-sort-columns-of-nested-structs-alphabetically-in-pyspark
-        # Credits: @pault for logic
-
-        def parse_fields(
-            fields_to_sort: list, parent_struct, is_reversed: bool
-        ) -> list:
-            sorted_fields: list = sorted(
-                fields_to_sort,
-                key=lambda x: x["name"],
-                reverse=is_reversed,
-            )
-
-            results = []
-            for field in sorted_fields:
-                new_struct = StructType([StructField.fromJson(field)])
-                new_base_field = parent_struct.name
-                if base_field:
-                    new_base_field = base_field + "."
+ new_base_field - - results.extend( - sort_nested_cols(new_struct, is_reversed, base_field=new_base_field) - ) - return results - - select_cols = [] - for parent_struct in sorted(schema, key=lambda x: x.name, reverse=is_reversed): - field_type = parent_struct.dataType - if isinstance(field_type, ArrayType): - array_parent = parent_struct.jsonValue()["type"]["elementType"] - base_str = f"transform({parent_struct.name}" - suffix_str = f") AS {parent_struct.name}" - - # if struct in array, create mapping to struct - if array_parent["type"] == "struct": - array_parent = array_parent["fields"] - base_str = f"{base_str}, x -> struct(" - suffix_str = f"){suffix_str}" - - array_elements = parse_fields(array_parent, parent_struct, is_reversed) - element_names = [i.split(".")[-1] for i in array_elements] - array_elements_formatted = [f"x.{i} as {i}" for i in element_names] - - # create a string representation of the sorted array - # ex: transform(phone_numbers, x -> struct(x.number as number, x.type as type)) AS phone_numbers - result = f"{base_str}{', '.join(array_elements_formatted)}{suffix_str}" - - elif isinstance(field_type, StructType): - field_list = parent_struct.jsonValue()["type"]["fields"] - sub_fields = parse_fields(field_list, parent_struct, is_reversed) - - # create a string representation of the sorted struct - # ex: struct(address.zip.first5, address.zip.last4) AS zip - result = f"struct({', '.join(sub_fields)}) AS {parent_struct.name}" - - else: - if base_field: - result = f"{base_field}.{parent_struct.name}" - else: - result = parent_struct.name - select_cols.append(result) - - return select_cols - - def get_original_nullability(field: StructField, result_dict: dict) -> None: - if hasattr(field, "nullable"): - result_dict[field.name] = field.nullable - else: - result_dict[field.name] = True - - if not isinstance(field.dataType, StructType) and not isinstance( - field.dataType, ArrayType - ): - return - - if isinstance(field.dataType, ArrayType): - result_dict[f"{field.name}_element"] = field.dataType.containsNull - children = field.dataType.elementType.fields - else: - children = field.dataType.fields - for i in children: - get_original_nullability(i, result_dict) - - def fix_nullability(field: StructField, result_dict: dict) -> None: - field.nullable = result_dict[field.name] - if not isinstance(field.dataType, StructType) and not isinstance( - field.dataType, ArrayType - ): - return - - if isinstance(field.dataType, ArrayType): - # save the containsNull property of the ArrayType - field.dataType.containsNull = result_dict[f"{field.name}_element"] - children = field.dataType.elementType.fields - else: - children = field.dataType.fields - - for i in children: - fix_nullability(i, result_dict) - - if sort_order not in ["asc", "desc"]: + sorted_col_names = None + if sort_order == "asc": + sorted_col_names = sorted(df.columns) + elif sort_order == "desc": + sorted_col_names = sorted(df.columns, reverse=True) + else: msg = f"['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'" raise ValueError( msg, ) - reverse_lookup = { - "asc": False, - "desc": True, - } - - is_reversed: bool = reverse_lookup[sort_order] - top_level_sorted_df = df.select(*sorted(df.columns, reverse=is_reversed)) - if not sort_nested: - return top_level_sorted_df - - is_nested: bool = any( - [ - isinstance(i.dataType, StructType) or isinstance(i.dataType, ArrayType) - for i in top_level_sorted_df.schema - ] - ) - - if not is_nested: - return top_level_sorted_df - - 
fully_sorted_schema = sort_nested_cols(top_level_sorted_df.schema, is_reversed) - output = df.selectExpr(fully_sorted_schema) - result_dict = {} - for field in df.schema: - get_original_nullability(field, result_dict) - - for field in output.schema: - fix_nullability(field, result_dict) - - final_df = output.sparkSession.createDataFrame(output.rdd, output.schema) - return final_df return df.select(*sorted_col_names) @@ -380,4 +250,5 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame: ] df = df.toDF(*sanitized_columns) # noqa: PD901 - return df \ No newline at end of file + return df + diff --git a/tests/test_transformations.py b/tests/test_transformations.py index 83c4d904..0346ce5d 100644 --- a/tests/test_transformations.py +++ b/tests/test_transformations.py @@ -1,10 +1,10 @@ import pytest -import chispa +from pyspark.sql.types import StructType, StructField, StringType, IntegerType, MapType + import quinn -from pyspark.sql import DataFrame -from pyspark.sql.types import StructType, StructField, StringType, IntegerType, ArrayType, MapType -from quinn.transformations import flatten_struct, flatten_map, flatten_dataframe from tests.conftest import auto_inject_fixtures +import chispa +from quinn.transformations import flatten_struct, flatten_map, flatten_dataframe @auto_inject_fixtures("spark") @@ -224,501 +224,6 @@ def it_throws_an_error_if_the_sort_order_is_invalid(spark): == "['asc', 'desc'] are the only valid sort orders and you entered a sort order of 'cats'" ) -def _test_sort_struct_flat(spark, sort_order: str): - def _get_simple_test_dataframes(sort_order) -> tuple[(DataFrame, DataFrame)]: - col_a = 1 - col_b = 2 - col_c = 3 - - unsorted_fields = StructType( - [ - StructField("b", IntegerType()), - StructField("c", IntegerType()), - StructField("a", IntegerType()), - ] - ) - unsorted_data = [ - (col_b, col_c, col_a), - ] - if sort_order == "asc": - sorted_fields = StructType( - [ - StructField("a", IntegerType()), - StructField("b", IntegerType()), - StructField("c", IntegerType()), - ] - ) - - sorted_data = [ - (col_a, col_b, col_c), - ] - elif sort_order == "desc": - sorted_fields = StructType( - [ - StructField("c", IntegerType()), - StructField("b", IntegerType()), - StructField("a", IntegerType()), - ] - ) - - sorted_data = [ - (col_c, col_b, col_a), - ] - else: - raise ValueError( - "['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'".format( - sort_order=sort_order - ) - ) - - unsorted_df = spark.createDataFrame(unsorted_data, unsorted_fields) - expected_df = spark.createDataFrame(sorted_data, sorted_fields) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_simple_test_dataframes(sort_order=sort_order) - sorted_df = quinn.sort_columns(unsorted_df, sort_order) - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=True - ) - - -def test_sort_struct_flat(spark): - _test_sort_struct_flat(spark, "asc") - - -def test_sort_struct_flat_desc(spark): - _test_sort_struct_flat(spark, "desc") - - -def _get_test_dataframes_schemas() -> dict: - elements = { - "_id": (StructField("_id", StringType(), nullable=False)), - "first_name": (StructField("first_name", StringType(), nullable=False)), - "city": (StructField("city", StringType(), nullable=False)), - "last4": (StructField("last4", IntegerType(), nullable=True)), - "first5": (StructField("first5", IntegerType(), nullable=True)), - "type": (StructField("type", StringType(), nullable=True)), - "number": 
(StructField("number", StringType(), nullable=True)), - } - - return elements - - -def _get_test_dataframes_data() -> tuple[(str, str, int, int, str)]: - _id = "12345" - city = "Fake City" - zip_first5 = 54321 - zip_last4 = 12345 - first_name = "John" - - return _id, city, zip_first5, zip_last4, first_name - - -def _get_unsorted_nested_struct_fields(elements: dict): - unsorted_fields = [ - elements["_id"], - elements["first_name"], - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType([elements["last4"], elements["first5"]]), - nullable=True, - ), - elements["city"], - ] - ), - nullable=True, - ), - ] - return unsorted_fields - - -def _test_sort_struct_nested(spark, ignore_nullable: bool): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - elements = _get_test_dataframes_schemas() - unsorted_fields = _get_unsorted_nested_struct_fields(elements) - sorted_fields = [ - elements["_id"], - StructField( - "address", - StructType( - [ - elements["city"], - StructField( - "zip", - StructType([elements["first5"], elements["last4"]]), - nullable=True, - ), - ] - ), - nullable=True, - ), - elements["first_name"], - ] - - _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - unsorted_data = [ - (_id, first_name, (((zip_last4, zip_first5)), city)), - (_id, first_name, (((None, zip_first5)), city)), - (_id, first_name, (None)), - ] - - sorted_data = [ - (_id, ((city, (zip_first5, zip_last4))), first_name), - (_id, ((city, (zip_first5, None))), first_name), - (_id, (None), first_name), - ] - - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) - expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable - ) - - -def _test_sort_struct_nested_desc(spark, ignore_nullable: bool): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - elements = _get_test_dataframes_schemas() - unsorted_fields = _get_unsorted_nested_struct_fields(elements) - - sorted_fields = [ - elements["first_name"], - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType([elements["last4"], elements["first5"]]), - nullable=True, - ), - elements["city"], - ] - ), - nullable=True, - ), - elements["_id"], - ] - - _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - - unsorted_data = [(_id, first_name, (((zip_last4, zip_first5)), city))] - sorted_data = [ - ( - first_name, - ((zip_first5, zip_last4), city), - _id, - ) - ] - - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) - expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "desc") - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable - ) - - -def _get_unsorted_nested_array_fields(elements: dict) -> list: - unsorted_fields = [ - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType( - [ - elements["last4"], - elements["first5"], - ] - ), - nullable=False, - ), - elements["city"], - ] - ), - nullable=False, - ), - StructField( - "phone_numbers", - 
ArrayType(StructType([elements["type"], elements["number"]])), - nullable=True, - ), - elements["_id"], - elements["first_name"], - ] - return unsorted_fields - - -def _test_sort_struct_nested_with_arraytypes(spark, ignore_nullable: bool): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - elements = _get_test_dataframes_schemas() - unsorted_fields = _get_unsorted_nested_array_fields(elements) - - sorted_fields = [ - elements["_id"], - StructField( - "address", - StructType( - [ - elements["city"], - StructField( - "zip", - StructType([elements["first5"], elements["last4"]]), - nullable=False, - ), - ] - ), - nullable=False, - ), - elements["first_name"], - StructField( - "phone_numbers", - ArrayType(StructType([elements["number"], elements["type"]])), - nullable=True, - ), - ] - - _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - phone_type = "home" - phone_number = "555-555-5555" - - unsorted_data = [ - ( - ((zip_last4, zip_first5), city), - [(phone_type, phone_number)], - _id, - first_name, - ), - (((zip_last4, zip_first5), city), [(phone_type, None)], _id, first_name), - (((None, None), city), None, _id, first_name), - ] - sorted_data = [ - ( - _id, - (city, (zip_last4, zip_first5)), - first_name, - [(phone_type, phone_number)], - ), - (_id, (city, (zip_last4, zip_first5)), first_name, [(phone_type, None)]), - (_id, (city, (None, None)), first_name, None), - ] - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) - expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "asc", sort_nested=True) - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable - ) - - -def _test_sort_struct_nested_with_arraytypes_desc(spark, ignore_nullable: bool): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - elements = _get_test_dataframes_schemas() - unsorted_fields = _get_unsorted_nested_array_fields(elements) - - sorted_fields = [ - StructField( - "phone_numbers", - ArrayType(StructType([elements["type"], elements["number"]])), - nullable=True, - ), - elements["first_name"], - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType([elements["last4"], elements["first5"]]), - nullable=False, - ), - elements["city"], - ] - ), - nullable=False, - ), - elements["_id"], - ] - - _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - phone_type = "home" - phone_number = "555-555-5555" - - unsorted_data = [ - ( - ((zip_last4, zip_first5), city), - [(phone_type, phone_number)], - _id, - first_name, - ), - ] - sorted_data = [ - ( - [(phone_type, phone_number)], - first_name, - ((zip_last4, zip_first5), city), - _id, - ), - ] - - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) - expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested=True) - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable - ) - - -def _test_sort_struct_nested_in_arraytypes(spark, ignore_nullable: bool): - def _get_test_dataframes() -> tuple[(DataFrame, DataFrame)]: - elements = _get_test_dataframes_schemas() - unsorted_fields = 
_get_unsorted_nested_array_fields(elements) - - # extensions = StructType( - # [ - # StructField("extension_code", StringType(), nullable=True), - # StructField( - # "extension_numbers", - # StructType( - # [ - # StructField("extension_number_one", IntegerType()), - # StructField("extension_number_two", IntegerType()), - # ] - # ), - # ), - # ] - # ) - - sorted_fields = [ - StructField( - "phone_numbers", - ArrayType(StructType([elements["type"], elements["number"]])), - ), - StructField( - "extensions", - ArrayType( - StructType( - [ - StructField("extension_number_one", IntegerType()), - StructField("extension_number_two", IntegerType()), - ] - ), - StructField("extension_code", StringType(), nullable=True), - ), - ), - elements["first_name"], - StructField( - "address", - StructType( - [ - StructField( - "zip", - StructType([elements["last4"], elements["first5"]]), - nullable=False, - ), - elements["city"], - ] - ), - nullable=False, - ), - elements["_id"], - ] - - _id, city, zip_first5, zip_last4, first_name = _get_test_dataframes_data() - phone_type = "home" - phone_number = "555-555-5555" - extension_code = "test" - extension_number_one = 1 - extension_number_two = 2 - - unsorted_data = [ - ( - ((zip_last4, zip_first5), city), - [(phone_type, phone_number)], - _id, - first_name, - ), - ] - sorted_data = [ - ( - [(phone_type, phone_number)], - [(extension_number_one, extension_number_two), extension_code], - first_name, - ((zip_last4, zip_first5), city), - _id, - ), - ] - - expected_df = spark.createDataFrame(sorted_data, StructType(sorted_fields)) - unsorted_df = spark.createDataFrame(unsorted_data, StructType(unsorted_fields)) - - return unsorted_df, expected_df - - unsorted_df, expected_df = _get_test_dataframes() - sorted_df = quinn.sort_columns(unsorted_df, "desc", sort_nested=True) - - chispa.schema_comparer.assert_schema_equality( - sorted_df.schema, expected_df.schema, ignore_nullable=ignore_nullable - ) - - -def test_sort_struct_nested(spark): - _test_sort_struct_nested(spark, True) - - -def test_sort_struct_nested_desc(spark): - _test_sort_struct_nested_desc(spark, True) - - -def test_sort_struct_nested_with_arraytypes(spark): - _test_sort_struct_nested_with_arraytypes(spark, True) - - -def test_sort_struct_nested_with_arraytypes_desc(spark): - _test_sort_struct_nested_with_arraytypes_desc(spark, True) - - -def test_sort_struct_nested_nullable(spark): - _test_sort_struct_nested(spark, False) - - -def test_sort_struct_nested_nullable_desc(spark): - _test_sort_struct_nested_desc(spark, False) - - -def test_sort_struct_nested_with_arraytypes_nullable(spark): - _test_sort_struct_nested_with_arraytypes(spark, False) - - -def test_sort_struct_nested_with_arraytypes_nullable_desc(spark): - _test_sort_struct_nested_with_arraytypes_desc(spark, False) def test_flatten_struct(spark): data = [ @@ -848,4 +353,4 @@ def test_flatten_dataframe(spark): ) expected_df = spark.createDataFrame(expected_data, expected_schema) result_df = flatten_dataframe(df) - chispa.assert_df_equality(result_df, expected_df) \ No newline at end of file + chispa.assert_df_equality(result_df, expected_df)
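For quick reference, here is a minimal usage sketch of the simplified `sort_columns` API that this patch leaves behind: top-level columns only, `asc` or `desc`, and a `ValueError` for anything else. It assumes a local SparkSession and an installed `quinn` package; the toy DataFrame and column names are illustrative, not taken from the patch.

```python
from pyspark.sql import SparkSession

import quinn

# Local session purely for demonstration purposes.
spark = (
    SparkSession.builder.master("local[1]").appName("sort-columns-demo").getOrCreate()
)

# Toy DataFrame with deliberately unordered column names.
df = spark.createDataFrame([(2, 3, 1)], ["b", "c", "a"])

# Ascending sort of the top-level columns.
print(quinn.sort_columns(df, "asc").columns)   # ['a', 'b', 'c']

# Descending sort.
print(quinn.sort_columns(df, "desc").columns)  # ['c', 'b', 'a']

# Any other sort order raises a ValueError, as the tests above expect.
try:
    quinn.sort_columns(df, "cats")
except ValueError as err:
    print(err)
```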