Refactor flatten DataFrame code
MrPowers committed Oct 8, 2023
1 parent 6a7c3ef commit 92932e7
Showing 8 changed files with 342 additions and 317 deletions.
2 changes: 1 addition & 1 deletion quinn/extensions/column_ext.py
@@ -63,7 +63,7 @@ def isNullOrBlank(self: Column) -> Column:
blank characters, or ``False`` otherwise.
:rtype: Column
"""
return (self.isNull()) | (trim(self) == "")
return (self.isNull()) | (trim(self) == "") # noqa: PLC1901

Check failure (GitHub Actions / ruff): quinn/extensions/column_ext.py:66:50: RUF100 Unused `noqa` directive (non-enabled: `PLC1901`)


def isNotIn(self: Column, _list: list[Any]) -> Column:
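For context, a minimal usage sketch of the extension touched above. It is illustrative only and assumes an active SparkSession named `spark` with quinn's Column extensions attached (how the extensions are attached is not shown in this diff):

df = spark.createDataFrame([("  ",), (None,), ("quinn",)], ["word"])
df.select(df.word.isNullOrBlank().alias("word_is_blank")).show()
# "  " and None evaluate to True; "quinn" evaluates to False.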
167 changes: 3 additions & 164 deletions quinn/functions.py
@@ -3,23 +3,22 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from collections.abc import Callable
from numbers import Number

from pyspark.sql import Column, DataFrame
from pyspark.sql import Column
from pyspark.sql.functions import udf


import re
import uuid
from typing import Any, Callable
from typing import Any

import pyspark.sql.functions as F # noqa: N812
from pyspark.sql.types import (
ArrayType,
BooleanType,
MapType,
StringType,
StructType,
)


@@ -243,166 +242,6 @@ def regexp_extract_all(s: Column, regexp: Column) -> Column:
return None if s is None else re.findall(regexp, s)


def sanitize_column_name(name: str, replace_char: str = "_") -> str:
"""Sanitizes column names by replacing special characters with the specified character.
:param name: The original column name.
:type name: str
:param replace_char: The character to replace special characters with, defaults to '_'.
:type replace_char: str, optional
:return: The sanitized column name.
:rtype: str
"""
return re.sub(r"[^a-zA-Z0-9_]", replace_char, name)

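# Illustrative usage sketch (not part of the module above): every character
# outside [a-zA-Z0-9_], including spaces, is replaced one-for-one.
assert sanitize_column_name("order id#") == "order_id_"
assert sanitize_column_name("order id#", replace_char="-") == "order-id-"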

def _get_complex_fields(df: DataFrame) -> dict[str, object]:
"""Returns a dictionary of complex field names and their data types from the input DataFrame's schema.
:param df: The input PySpark DataFrame.
:type df: DataFrame
:return: A dictionary with complex field names as keys and their respective data types as values.
:rtype: Dict[str, object]
"""
return {
field.name: field.dataType
for field in df.schema.fields
if isinstance(field.dataType, (ArrayType, StructType, MapType))
}


def flatten_struct(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
"""Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.
:param df: The input PySpark DataFrame.
:type df: DataFrame
:param col_name: The column name of the StructType to be flattened.
:type col_name: str
:param sep: The separator to use in the resulting flattened column names, defaults to ':'.
:type sep: str, optional
:return: The DataFrame with the flattened StructType column.
:rtype: DataFrame
"""
struct_type = _get_complex_fields(df)[col_name]
expanded = [
F.col(f"`{col_name}`.`{k}`").alias(col_name + sep + k)
for k in [n.name for n in struct_type.fields]
]
return df.select("*", *expanded).drop(F.col(f"`{col_name}`"))

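# Illustrative usage sketch (not part of the module above; assumes an active
# SparkSession named `spark`): each struct field becomes a top-level column
# named `<struct_col><sep><field>`.
_people = spark.createDataFrame(
    [(1, ("Alice", 25))],
    "id INT, person STRUCT<name: STRING, age: INT>",
)
_flat_people = flatten_struct(_people, "person")
# _flat_people.columns -> ["id", "person:name", "person:age"]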

def explode_array(df: DataFrame, col_name: str) -> DataFrame:
"""Explodes the specified ArrayType column in the input DataFrame and returns a new DataFrame with the exploded column.
:param df: The input PySpark DataFrame.
:type df: DataFrame
:param col_name: The column name of the ArrayType to be exploded.
:type col_name: str
:return: The DataFrame with the exploded ArrayType column.
:rtype: DataFrame
"""
return df.select("*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)).drop(
col_name,
)

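# Illustrative usage sketch (not part of the module above; assumes an active
# SparkSession named `spark`): explode_outer yields one output row per array
# element.
_baskets = spark.createDataFrame([(1, ["apple", "banana"])], "id INT, fruit ARRAY<STRING>")
_exploded = explode_array(_baskets, "fruit")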

def flatten_map(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
"""Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.
:param df: The input PySpark DataFrame.
:type df: DataFrame
:param col_name: The column name of the MapType to be flattened.
:type col_name: str
:param sep: The separator to use in the resulting flattened column names, defaults to ":".
:type sep: str, optional
:return: The DataFrame with the flattened MapType column.
:rtype: DataFrame
"""
keys_df = df.select(F.explode_outer(F.map_keys(F.col(f"`{col_name}`")))).distinct()
keys = [row[0] for row in keys_df.collect()]
key_cols = [
F.col(f"`{col_name}`").getItem(k).alias(col_name + sep + k) for k in keys
]
return df.select(
[F.col(f"`{col}`") for col in df.columns if col != col_name] + key_cols,
)

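# Illustrative usage sketch (not part of the module above; assumes an active
# SparkSession named `spark`): every distinct map key becomes its own column.
_scores = spark.createDataFrame([(1, {"math": 90, "art": 80})], "id INT, score MAP<STRING, INT>")
_flat_scores = flatten_map(_scores, "score")
# _flat_scores.columns -> ["id", "score:math", "score:art"] (key column order may vary)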

def flatten_dataframe(
df: DataFrame,
sep: str = ":",
replace_char: str = "_",
sanitized_columns: bool = False, # noqa: FBT001, FBT002
) -> DataFrame:
"""Flattens the complex columns in the DataFrame.
:param df: The input PySpark DataFrame.
:type df: DataFrame
:param sep: The separator to use in the resulting flattened column names, defaults to ":".
:type sep: str, optional
:param replace_char: The character to replace special characters with in column names, defaults to "_".
:type replace_char: str, optional
:param sanitized_columns: Whether to sanitize column names, defaults to False.
:type sanitized_columns: bool, optional
:return: The DataFrame with all complex data types flattened.
:rtype: DataFrame
.. note:: This function assumes the input DataFrame has a consistent schema across all rows. If you have files with
different schemas, process each separately instead.
.. example:: Example usage:
>>> data = [
(
1,
("Alice", 25),
{"A": 100, "B": 200},
["apple", "banana"],
{"key": {"nested_key": 10}},
{"A#": 1000, "B@": 2000},
),
(
2,
("Bob", 30),
{"A": 150, "B": 250},
["orange", "grape"],
{"key": {"nested_key": 20}},
{"A#": 1500, "B@": 2500},
),
]
>>> df = spark.createDataFrame(data)
>>> flattened_df = flatten_dataframe(df)
>>> flattened_df.show()
>>> flattened_df_with_hyphen = flatten_dataframe(df, replace_char="-")
>>> flattened_df_with_hyphen.show()
"""
complex_fields = _get_complex_fields(df)

while len(complex_fields) != 0:
col_name = next(iter(complex_fields.keys()))

if isinstance(complex_fields[col_name], StructType):
df = flatten_struct(df, col_name, sep) # noqa: PD901

elif isinstance(complex_fields[col_name], ArrayType):
df = explode_array(df, col_name) # noqa: PD901

elif isinstance(complex_fields[col_name], MapType):
df = flatten_map(df, col_name, sep) # noqa: PD901

complex_fields = _get_complex_fields(df)

# Sanitize column names with the specified replace_char
if sanitized_columns:
sanitized_columns = [
sanitize_column_name(col_name, replace_char) for col_name in df.columns
]
df = df.toDF(*sanitized_columns) # noqa: PD901

return df

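# Illustrative usage sketch (not part of the module above; assumes an active
# SparkSession named `spark`): structs are flattened into `<col><sep><field>`
# columns and arrays are exploded into one row per element.
_orders = spark.createDataFrame(
    [(1, ("Alice", 25), ["apple", "banana"])],
    "id INT, person STRUCT<name: STRING, age: INT>, fruit ARRAY<STRING>",
)
_flat_orders = flatten_dataframe(_orders, sep="_")
# _flat_orders has `person_name` and `person_age` columns and two rows,
# one per element of `fruit`.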

def business_days_between(start_date: Column, end_date: Column) -> Column: # noqa: ARG001
"""Function takes two Spark `Columns` and returns a `Column` with the number of business days between the start and the end date.
24 changes: 22 additions & 2 deletions quinn/schema_helpers.py
@@ -1,7 +1,10 @@
from __future__ import annotations

import json

from pyspark.sql import SparkSession
from pyspark.sql import types as T # noqa: N812
from typing import Union

Check failure (GitHub Actions / ruff): quinn/schema_helpers.py:7:20: F401 `typing.Union` imported but unused


def print_schema_as_code(dtype: T.DataType) -> str:

Check failure (GitHub Actions / ruff): quinn/schema_helpers.py:1:1: I001 Import block is un-sorted or un-formatted
@@ -40,8 +43,9 @@ def print_schema_as_code(dtype: T.DataType) -> str:
elif isinstance(dtype, T.DecimalType):
res.append(f"DecimalType({dtype.precision}, {dtype.scale})")

else: # noqa: PLR5501
if str(dtype).endswith("()"): # PySpark 3.3+
else:
# PySpark 3.3+
if str(dtype).endswith("()"): # noqa: PLR5501

Check failure (GitHub Actions / ruff): quinn/schema_helpers.py:46:5: PLR5501 Use `elif` instead of `else` then `if`, to reduce indentation
Check failure (GitHub Actions / ruff): quinn/schema_helpers.py:48:40: RUF100 Unused `noqa` directive (unused: `PLR5501`)
res.append(str(dtype))
else:
res.append(f"{dtype}()")
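Based on the branches shown above, a quick sketch of what print_schema_as_code returns for scalar types (illustrative; assumes the helper is imported from quinn.schema_helpers):

from pyspark.sql import types as T
from quinn.schema_helpers import print_schema_as_code

print_schema_as_code(T.DecimalType(10, 2))  # -> "DecimalType(10, 2)"
print_schema_as_code(T.StringType())        # -> "StringType()"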
@@ -149,3 +153,19 @@ def _convert_nullable(null_str: str) -> bool:
fields.append(field)

return T.StructType(fields=fields)


def complex_fields(schema: T.StructType) -> dict[str, object]:
"""Returns a dictionary of complex field names and their data types from the input DataFrame's schema.
:param schema: The input StructType schema.
:type schema: StructType
:return: A dictionary with complex field names as keys and their respective data types as values.
:rtype: Dict[str, object]
"""
return {
field.name: field.dataType
for field in schema.fields
if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
}

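A minimal sketch of the new complex_fields helper (illustrative; the schema below is made up for the example):

from pyspark.sql import types as T
from quinn.schema_helpers import complex_fields

schema = T.StructType(
    [
        T.StructField("id", T.IntegerType()),
        T.StructField("person", T.StructType([T.StructField("name", T.StringType())])),
        T.StructField("tags", T.ArrayType(T.StringType())),
    ]
)
complex_fields(schema)
# -> {"person": StructType(...), "tags": ArrayType(StringType(), True)}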
4 changes: 2 additions & 2 deletions quinn/split_columns.py
@@ -68,7 +68,7 @@ def _num_delimiter(col_value1: str) -> int:

# If the length of split_value is same as new_col_names, check if any of the split values is None or empty string
elif any( # noqa: RET506
x is None or x.strip() == "" for x in split_value[: len(new_col_names)]
x is None or x.strip() == "" for x in split_value[: len(new_col_names)] # noqa: PLC1901

Check failure (GitHub Actions / ruff): quinn/split_columns.py:71:90: RUF100 Unused `noqa` directive (non-enabled: `PLC1901`)
):
msg = "Null or empty values are not accepted for columns in strict mode"
raise ValueError(
@@ -93,7 +93,7 @@ def _num_delimiter(col_value1: str) -> int:
if mode == "strict":
# Create an array of select expressions to create new columns from the split values
select_exprs = [
when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias(
when(split_col_expr.getItem(i) != "", split_col_expr.getItem(i)).alias( # noqa: PLC1901

Check failure (GitHub Actions / ruff): quinn/split_columns.py:96:86: RUF100 Unused `noqa` directive (non-enabled: `PLC1901`)
new_col_names[i],
)
for i in range(len(new_col_names))
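For orientation, a hypothetical call to the function these hunks belong to. The signature is not shown in this diff, so the function and parameter names below are inferred from the variables visible above and may not match the actual API:

split_columns(
    df,
    col_name="full_name",              # column to split (assumed parameter name)
    delimiter="#",                     # separator (assumed parameter name)
    new_col_names=["first", "last"],
    mode="strict",                     # strict mode rejects null/empty split values
)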