Skip to content

Commit

Permalink
Revert "Sort struct columns"
Browse files Browse the repository at this point in the history
  • Loading branch information
jeffbrennan authored Nov 5, 2023
1 parent ad271eb commit 60c60c9
Show file tree
Hide file tree
Showing 3 changed files with 25 additions and 649 deletions.
4 changes: 2 additions & 2 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -230,10 +230,10 @@ Converts all the column names in a DataFrame to snake_case. It's annoying to wri
**sort_columns()**

```python
quinn.sort_columns(df=source_df, sort_order="asc", sort_nested=True)
quinn.sort_columns(source_df, "asc")
```

Sorts the DataFrame columns in alphabetical order, including nested columns if sort_nested is set to True. Wide DataFrames are easier to navigate when they're sorted alphabetically.
Sorts the DataFrame columns in alphabetical order. Wide DataFrames are easier to navigate when they're sorted alphabetically.

### DataFrame Helpers

Expand Down
165 changes: 18 additions & 147 deletions quinn/transformations.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,10 @@
import re
import pyspark.sql.functions as F # noqa: N812
from __future__ import annotations
from collections.abc import Callable

import pyspark.sql.functions as F # noqa: N812
from pyspark.sql import DataFrame
from pyspark.sql.types import ArrayType, MapType, StructField, StructType
from pyspark.sql.types import ArrayType, MapType, StructType

from quinn.schema_helpers import complex_fields


Expand Down Expand Up @@ -83,161 +84,30 @@ def to_snake_case(s: str) -> str:
return s.lower().replace(" ", "_")


def sort_columns(df: DataFrame, sort_order: str) -> DataFrame:
    """Sort the top-level columns of a DataFrame alphabetically.

    The ``sort_order`` parameter can either be ``asc`` or ``desc``, which
    correspond to ascending and descending order, respectively. If any other
    value is provided for the ``sort_order`` parameter, a ``ValueError`` will
    be raised.

    :param df: A DataFrame
    :type df: pyspark.sql.DataFrame
    :param sort_order: The order in which to sort the columns in the DataFrame
    :type sort_order: str
    :return: A DataFrame with the columns sorted in the chosen order
    :rtype: pyspark.sql.DataFrame
    :raises ValueError: if ``sort_order`` is neither ``"asc"`` nor ``"desc"``
    """
    if sort_order == "asc":
        sorted_col_names = sorted(df.columns)
    elif sort_order == "desc":
        sorted_col_names = sorted(df.columns, reverse=True)
    else:
        # Reject anything other than the two supported orders up front,
        # echoing the offending value back to the caller.
        msg = f"['asc', 'desc'] are the only valid sort orders and you entered a sort order of '{sort_order}'"
        raise ValueError(msg)
    return df.select(*sorted_col_names)


Expand Down Expand Up @@ -380,4 +250,5 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
]
df = df.toDF(*sanitized_columns) # noqa: PD901

return df
return df

Loading

0 comments on commit 60c60c9

Please sign in to comment.