Normalize semi-structured data into tabular tables. (#64)
* 1

* add some tests

* refactored

* _  -> :

* more

* updates

* operate on column names using backticks

* remove function read_and_flatten_nested_files

* Update note

* make get_complex_fields private

* update tests

* update tests

---------

Co-authored-by: Matthew Powers <[email protected]>
bjornjorgensen and MrPowers authored Oct 7, 2023
1 parent 141355e commit 074ce5f
Showing 2 changed files with 321 additions and 0 deletions.
158 changes: 158 additions & 0 deletions quinn/functions.py
@@ -20,6 +20,9 @@
StringType,
)

from pyspark.sql import Column, DataFrame
from typing import Dict, List
from pathlib import Path

def single_space(col: Column) -> Column:
"""Function takes a column and replaces all the multiple white spaces with a single space.
@@ -225,6 +228,161 @@ def array_choice(col: Column, seed: int | None = None) -> Column:


@F.udf(returnType=ArrayType(StringType()))
def regexp_extract_all(s, regexp):
    return None if s is None else re.findall(regexp, s)


def sanitize_column_name(name: str, replace_char: str = "_") -> str:
    """
    Sanitizes column names by replacing special characters with the specified character.

    :param name: The original column name.
    :type name: str
    :param replace_char: The character to replace special characters with, defaults to '_'.
    :type replace_char: str, optional
    :return: The sanitized column name.
    :rtype: str
    """
    return re.sub(r"[^a-zA-Z0-9_]", replace_char, name)
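
# Usage sketch (illustrative): any character outside [a-zA-Z0-9_] is replaced
# with `replace_char`, e.g.:
#   sanitize_column_name("first name#1")                    # 'first_name_1'
#   sanitize_column_name("first name#1", replace_char="-")  # 'first-name-1'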

def _get_complex_fields(df: DataFrame) -> Dict[str, object]:
    """
    Returns a dictionary of complex field names and their data types from the input DataFrame's schema.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :return: A dictionary with complex field names as keys and their respective data types as values.
    :rtype: Dict[str, object]
    """
    return {
        field.name: field.dataType
        for field in df.schema.fields
        if isinstance(field.dataType, (ArrayType, StructType, MapType))
    }
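
# Usage sketch (illustrative; assumes an active SparkSession `spark`): only
# ArrayType, StructType, and MapType columns are reported, keyed by name.
#   df = spark.createDataFrame([(1, ("Alice", 25))], "id INT, details STRUCT<name: STRING, age: INT>")
#   _get_complex_fields(df)  # {'details': StructType([StructField('name', ...), StructField('age', ...)])}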

def flatten_struct(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
    """
    Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param col_name: The column name of the StructType to be flattened.
    :type col_name: str
    :param sep: The separator to use in the resulting flattened column names, defaults to ':'.
    :type sep: str, optional
    :return: The DataFrame with the flattened StructType column.
    :rtype: DataFrame
    """
    struct_type = _get_complex_fields(df)[col_name]
    expanded = [
        F.col(f"`{col_name}`.`{k}`").alias(col_name + sep + k)
        for k in [n.name for n in struct_type.fields]
    ]
    return df.select("*", *expanded).drop(F.col(f"`{col_name}`"))

def explode_array(df: DataFrame, col_name: str) -> DataFrame:
    """
    Explodes the specified ArrayType column in the input DataFrame and returns a new DataFrame with the exploded column.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param col_name: The column name of the ArrayType to be exploded.
    :type col_name: str
    :return: The DataFrame with the exploded ArrayType column.
    :rtype: DataFrame
    """
    return df.select("*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)).drop(col_name)

def flatten_map(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
    """
    Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param col_name: The column name of the MapType to be flattened.
    :type col_name: str
    :param sep: The separator to use in the resulting flattened column names, defaults to ":".
    :type sep: str, optional
    :return: The DataFrame with the flattened MapType column.
    :rtype: DataFrame
    """
    keys_df = df.select(F.explode_outer(F.map_keys(F.col(f"`{col_name}`")))).distinct()
    keys = [row[0] for row in keys_df.collect()]
    key_cols = [F.col(f"`{col_name}`").getItem(k).alias(col_name + sep + k) for k in keys]
    return df.select([F.col(f"`{col}`") for col in df.columns if col != col_name] + key_cols)
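
# Usage sketch (illustrative; assumes an active SparkSession `spark`): one
# column per distinct map key; the order of the key columns may vary because
# the keys are collected with distinct().
#   df = spark.createDataFrame([(1, {"a": 1, "b": 2})], "id INT, props MAP<STRING, INT>")
#   flatten_map(df, "props").columns  # ['id', 'props:a', 'props:b']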

def flatten_dataframe(df: DataFrame, sep: str = ":", replace_char: str = "_", sanitized_columns: bool = False) -> DataFrame:
    """
    Flattens all complex data types (StructType, ArrayType, and MapType) in the input DataFrame and returns a
    new DataFrame with the flattened columns.

    :param df: The input PySpark DataFrame.
    :type df: DataFrame
    :param sep: The separator to use in the resulting flattened column names, defaults to ":".
    :type sep: str, optional
    :param replace_char: The character to replace special characters with in column names, defaults to "_".
    :type replace_char: str, optional
    :param sanitized_columns: Whether to sanitize column names, defaults to False.
    :type sanitized_columns: bool, optional
    :return: The DataFrame with all complex data types flattened.
    :rtype: DataFrame

    .. note:: This function assumes the input DataFrame has a consistent schema across all rows. If you have files
        with different schemas, process each separately instead.

    .. example:: Example usage:

        >>> data = [
                (
                    1,
                    ("Alice", 25),
                    {"A": 100, "B": 200},
                    ["apple", "banana"],
                    {"key": {"nested_key": 10}},
                    {"A#": 1000, "B@": 2000},
                ),
                (
                    2,
                    ("Bob", 30),
                    {"A": 150, "B": 250},
                    ["orange", "grape"],
                    {"key": {"nested_key": 20}},
                    {"A#": 1500, "B@": 2500},
                ),
            ]
        >>> df = spark.createDataFrame(data)
        >>> flattened_df = flatten_dataframe(df)
        >>> flattened_df.show()
        >>> flattened_df_with_hyphen = flatten_dataframe(df, replace_char="-")
        >>> flattened_df_with_hyphen.show()
    """
    complex_fields = _get_complex_fields(df)

    while len(complex_fields) != 0:
        col_name = list(complex_fields.keys())[0]

        if isinstance(complex_fields[col_name], StructType):
            df = flatten_struct(df, col_name, sep)

        elif isinstance(complex_fields[col_name], ArrayType):
            df = explode_array(df, col_name)

        elif isinstance(complex_fields[col_name], MapType):
            df = flatten_map(df, col_name, sep)

        complex_fields = _get_complex_fields(df)

    # Sanitize column names with the specified replace_char
    if sanitized_columns:
        sanitized_names = [
            sanitize_column_name(col_name, replace_char) for col_name in df.columns
        ]
        df = df.toDF(*sanitized_names)

    return df
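
# Usage sketch (illustrative; assumes an active SparkSession `spark`): with
# sanitized_columns=True, the separator and any other special characters in the
# flattened names are replaced by `replace_char`.
#   df = spark.createDataFrame([(1, ("Alice", 25))], "id INT, details STRUCT<name: STRING, age: INT>")
#   flatten_dataframe(df, sanitized_columns=True).columns  # ['id', 'details_name', 'details_age']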


def regexp_extract_all(s: str, regexp: str) -> list[re.Match] | None:
"""Function uses the Python `re` library to extract regular expressions from a string (`s`) using a regex pattern (`regexp`).
163 changes: 163 additions & 0 deletions tests/test_functions.py
@@ -344,6 +344,169 @@ def test_regexp_extract_all(spark):
    chispa.assert_column_equality(actual_df, "all_numbers", "expected")


def test_flatten_struct(spark):
    data = [
        (1, ("name1", "address1", 20)),
        (2, ("name2", "address2", 30)),
        (3, ("name3", "address3", 40)),
    ]
    schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField(
                "details",
                StructType(
                    [
                        StructField("name", StringType(), True),
                        StructField("address", StringType(), True),
                        StructField("age", IntegerType(), True),
                    ]
                ),
                True,
            ),
        ]
    )
    df = spark.createDataFrame(data, schema)
    complex_fields = {"details": StructType([StructField("name", StringType(), True), StructField("address", StringType(), True), StructField("age", IntegerType(), True)])}
    expected_schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("details:name", StringType(), True),
            StructField("details:address", StringType(), True),
            StructField("details:age", IntegerType(), True),
        ]
    )
    expected_data = [
        (1, "name1", "address1", 20),
        (2, "name2", "address2", 30),
        (3, "name3", "address3", 40),
    ]
    expected_df = spark.createDataFrame(expected_data, expected_schema)

    flattened_df = flatten_struct(df, "details")
    assert flattened_df.schema == expected_schema
    assert flattened_df.collect() == expected_df.collect()


def test_flatten_map(spark):
    data = [
        (1, {"name": "Alice", "age": 25}),
        (2, {"name": "Bob", "age": 30}),
        (3, {"name": "Charlie", "age": 35}),
    ]
    schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("details", MapType(StringType(), StringType()), True),
        ]
    )
    df = spark.createDataFrame(data, schema)
    expected_schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("details:name", StringType(), True),
            StructField("details:age", StringType(), True),
        ]
    )
    expected_data = [
        (1, "Alice", "25"),
        (2, "Bob", "30"),
        (3, "Charlie", "35"),
    ]
    expected_df = spark.createDataFrame(expected_data, expected_schema)

    flattened_df = flatten_map(df, "details")
    assert flattened_df.schema == expected_schema
    assert flattened_df.collect() == expected_df.collect()


def test_flatten_dataframe(spark):
    # Define input data
    data = [
        (
            1,
            "John",
            {"age": 30, "gender": "M", "address": {"city": "New York", "state": "NY"}},
            [
                {"type": "home", "number": "555-1234"},
                {"type": "work", "number": "555-5678"},
            ],
        ),
        (
            2,
            "Jane",
            {"age": 25, "gender": "F", "address": {"city": "San Francisco", "state": "CA"}},
            [{"type": "home", "number": "555-4321"}],
        ),
    ]
    schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("name", StringType(), True),
            StructField(
                "details",
                StructType(
                    [
                        StructField("age", IntegerType(), True),
                        StructField("gender", StringType(), True),
                        StructField(
                            "address",
                            StructType(
                                [
                                    StructField("city", StringType(), True),
                                    StructField("state", StringType(), True),
                                ]
                            ),
                            True,
                        ),
                    ]
                ),
                True,
            ),
            StructField(
                "phone_numbers",
                ArrayType(
                    StructType(
                        [
                            StructField("type", StringType(), True),
                            StructField("number", StringType(), True),
                        ]
                    ),
                    True,
                ),
                True,
            ),
        ]
    )
    df = spark.createDataFrame(data, schema)

    # Define expected output
    expected_data = [
        (1, "John", 30, "M", "New York", "NY", "home", "555-1234"),
        (1, "John", 30, "M", "New York", "NY", "work", "555-5678"),
        (2, "Jane", 25, "F", "San Francisco", "CA", "home", "555-4321"),
    ]
    expected_schema = StructType(
        [
            StructField("id", IntegerType(), True),
            StructField("name", StringType(), True),
            StructField("details:age", IntegerType(), True),
            StructField("details:gender", StringType(), True),
            StructField("details:address:city", StringType(), True),
            StructField("details:address:state", StringType(), True),
            StructField("phone_numbers:type", StringType(), True),
            StructField("phone_numbers:number", StringType(), True),
        ]
    )
    expected_df = spark.createDataFrame(expected_data, expected_schema)

    # Apply function to input data
    result_df = flatten_dataframe(df)

    # Check if result matches expected output
    assert result_df.collect() == expected_df.collect()


def test_business_days_between(spark):
    df = quinn.create_df(
        spark,
