Commit

Format code and fix linter
MrPowers committed Oct 8, 2023
1 parent d9c4cb7 commit 6a7c3ef
Showing 9 changed files with 168 additions and 109 deletions.
10 changes: 8 additions & 2 deletions quinn/__init__.py
@@ -1,7 +1,6 @@
"""quinn API."""

from quinn.append_if_schema_identical import append_if_schema_identical
from quinn.split_columns import split_col
from quinn.dataframe_helpers import (
column_to_list,
create_df,
@@ -34,7 +33,14 @@
week_start_date,
)
from quinn.schema_helpers import print_schema_as_code
from quinn.transformations import snake_case_col_names, sort_columns, to_snake_case, with_columns_renamed, with_some_columns_renamed
from quinn.split_columns import split_col
from quinn.transformations import (
snake_case_col_names,
sort_columns,
to_snake_case,
with_columns_renamed,
with_some_columns_renamed,
)

# Use __all__ to let developers know what is part of the public API.
__all__ = [
12 changes: 8 additions & 4 deletions quinn/dataframe_helpers.py
@@ -23,7 +23,9 @@ def column_to_list(df: DataFrame, col_name: str) -> list[Any]:


def two_columns_to_dictionary(
df: DataFrame, key_col_name: str, value_col_name: str,
df: DataFrame,
key_col_name: str,
value_col_name: str,
) -> dict[str, Any]:
"""Collect two columns as dictionary when first column is key and second is value.
@@ -48,11 +50,13 @@ def to_list_of_dictionaries(df: DataFrame) -> list[dict[str, Any]]:
:return: A list of dictionaries representing the rows in the DataFrame.
:rtype: List[Dict[str, Any]]
"""
return list(map(lambda r: r.asDict(), df.collect())) # noqa: C417
return list(map(lambda r: r.asDict(), df.collect())) # noqa: C417


def print_athena_create_table(
df: DataFrame, athena_table_name: str, s3location: str,
df: DataFrame,
athena_table_name: str,
s3location: str,
) -> None:
"""Generate the Athena create table statement for a given DataFrame.
@@ -110,5 +114,5 @@ def create_df(spark: SparkSession, rows_data, col_specs) -> DataFrame:  # noqa:
:return: a new DataFrame
:rtype: DataFrame
"""
struct_fields = list(map(lambda x: StructField(*x), col_specs)) # noqa: C417
struct_fields = list(map(lambda x: StructField(*x), col_specs)) # noqa: C417
return spark.createDataFrame(data=rows_data, schema=StructType(struct_fields))
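
A usage sketch for the helpers reformatted above; the quinn signatures come from this file, while the SparkSession and sample data are invented for illustration:

from pyspark.sql import SparkSession
from pyspark.sql.types import IntegerType, StringType
from quinn.dataframe_helpers import create_df, to_list_of_dictionaries, two_columns_to_dictionary

spark = SparkSession.builder.getOrCreate()

# col_specs entries are unpacked into StructField(name, type, nullable)
df = create_df(
    spark,
    [("apple", 1), ("banana", 2)],
    [("fruit", StringType(), True), ("qty", IntegerType(), True)],
)
lookup = two_columns_to_dictionary(df, "fruit", "qty")  # {"apple": 1, "banana": 2}
rows = to_list_of_dictionaries(df)  # [{"fruit": "apple", "qty": 1}, ...]
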
5 changes: 3 additions & 2 deletions quinn/dataframe_validator.py
@@ -8,7 +8,6 @@
from pyspark.sql.types import StructType



class DataFrameMissingColumnError(ValueError):
"""Raise this when there's a DataFrame column error."""

@@ -40,7 +39,9 @@ def validate_presence_of_columns(df: DataFrame, required_col_names: list[str]) -


def validate_schema(
df: DataFrame, required_schema: StructType, ignore_nullable: bool = False, # noqa: FBT001,FBT002
df: DataFrame,
required_schema: StructType,
ignore_nullable: bool = False, # noqa: FBT001,FBT002
) -> None:
"""Function that validate if a given DataFrame has a given StructType as its schema.
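
A sketch of calling the validators touched above, assuming an existing DataFrame df and an invented required schema; quinn's validation errors subclass ValueError, as the DataFrameMissingColumnError definition above shows:

from pyspark.sql import types as T
from quinn.dataframe_validator import validate_presence_of_columns, validate_schema

required_schema = T.StructType([
    T.StructField("id", T.LongType(), True),
    T.StructField("name", T.StringType(), True),
])

try:
    validate_presence_of_columns(df, ["id", "name"])
    # ignore_nullable=True relaxes the nullability part of the comparison
    validate_schema(df, required_schema, ignore_nullable=True)
except ValueError as err:  # quinn's validation errors subclass ValueError
    print(f"validation failed: {err}")
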
16 changes: 15 additions & 1 deletion quinn/extensions/column_ext.py
@@ -67,15 +67,29 @@ def isNullOrBlank(self: Column) -> Column:


def isNotIn(self: Column, _list: list[Any]) -> Column:
"""To see if a value is not in a list of values.
:param self: Column object
:_list: list[Any]
:rtype: Column
"""
return ~(self.isin(_list))


def nullBetween(self: Column, lower: Column, upper: Column) -> Column:
"""To see if a value is between two values in a null friendly way.
:param self: Column object
:lower: Column
:upper: Column
:rtype: Column
"""
return when(lower.isNull() & upper.isNull(), False).otherwise(
when(self.isNull(), False).otherwise(
when(lower.isNull() & upper.isNotNull() & (self <= upper), True).otherwise(
when(
lower.isNotNull() & upper.isNull() & (self >= lower), True,
lower.isNotNull() & upper.isNull() & (self >= lower),
True,
).otherwise(self.between(lower, upper)),
),
),
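
The extension helpers above take a Column as their first argument, so they can be exercised as plain functions; whether they are also attached to Column elsewhere in quinn.extensions is not shown in this diff. A sketch with an invented DataFrame:

import pyspark.sql.functions as F
from quinn.extensions.column_ext import isNotIn, nullBetween

# True when status is outside the blocklist
df = df.withColumn("allowed", isNotIn(F.col("status"), ["banned", "suspended"]))

# Null-friendly range check: a null lower or upper bound acts as an open end
df = df.withColumn("in_range", nullBetween(F.col("age"), F.col("min_age"), F.col("max_age")))
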
122 changes: 64 additions & 58 deletions quinn/functions.py
@@ -3,10 +3,10 @@
from typing import TYPE_CHECKING

if TYPE_CHECKING:
from numbers import Number
from numbers import Number

from pyspark.sql import Column
from pyspark.sql.functions import udf
from pyspark.sql import Column, DataFrame
from pyspark.sql.functions import udf


import re
@@ -15,16 +15,13 @@

import pyspark.sql.functions as F # noqa: N812
from pyspark.sql.types import (
ArrayType,
BooleanType,
StringType,
StructType,
MapType,
ArrayType,
BooleanType,
MapType,
StringType,
StructType,
)

from pyspark.sql import Column, DataFrame
from typing import Dict, List
from pathlib import Path

def single_space(col: Column) -> Column:
"""Function takes a column and replaces all the multiple white spaces with a single space.
@@ -108,7 +105,7 @@ def forall(f: Callable[[Any], bool]) -> udf:
:return: A spark UDF which accepts a list of arguments and returns True if all
elements pass through the given boolean function, False otherwise.
:rtype: UserDefinedFunction
"""
"""

def temp_udf(list_: list) -> bool:
return all(map(f, list_))
@@ -126,7 +123,7 @@ def multi_equals(value: Any) -> udf:  # noqa: ANN401
"""

def temp_udf(*cols) -> bool: # noqa: ANN002
return all(map(lambda col: col == value, cols)) # noqa: C417
return all(map(lambda col: col == value, cols)) # noqa: C417

return F.udf(temp_udf, BooleanType())
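
A hedged usage sketch for forall and multi_equals as shown above; the DataFrame and column names are made up:

import pyspark.sql.functions as F
from quinn.functions import forall, multi_equals

# forall builds a UDF that is True when every array element passes the predicate
df = df.withColumn("all_passing", forall(lambda n: n >= 60)(F.col("scores")))

# multi_equals builds a UDF that is True when every listed column equals the value
df = df.withColumn("fully_done", multi_equals("done")(F.col("stage1"), F.col("stage2")))
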

@@ -191,14 +188,17 @@ def week_end_date(col: Column, week_end_day: str = "Sat") -> Column:
"Sat": 7,
}
return F.when(
F.dayofweek(col).eqNullSafe(F.lit(day_of_week_mapping[week_end_day])), col,
F.dayofweek(col).eqNullSafe(F.lit(day_of_week_mapping[week_end_day])),
col,
).otherwise(F.next_day(col, week_end_day))


def _raise_if_invalid_day(day: str) -> None:
valid_days = ["Mon", "Tue", "Wed", "Thu", "Fri", "Sat", "Sun"]
if day not in valid_days:
message = "The day you entered '{}' is not valid. Here are the valid days: [{}]".format(day, ",".join(valid_days))
message = "The day you entered '{}' is not valid. Here are the valid days: [{}]".format(
day, ",".join(valid_days),
)
raise ValueError(message)
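
As a quick illustration of week_end_date (invented DataFrame and date column; the valid day abbreviations are the ones listed in _raise_if_invalid_day above):

import pyspark.sql.functions as F
from quinn.functions import week_end_date

# Date of the Saturday that closes the week containing order_date
df = df.withColumn("week_end", week_end_date(F.col("order_date"), "Sat"))
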


@@ -230,13 +230,21 @@ def array_choice(col: Column, seed: int | None = None) -> Column:


@F.udf(returnType=ArrayType(StringType()))
def regexp_extract_all(s, regexp):
return None if s == None else re.findall(regexp, s)
def regexp_extract_all(s: Column, regexp: Column) -> Column:
"""Function uses the Python `re` library to extract regular expressions from a string (`s`) using a regex pattern (`regexp`).
It returns a list of all matches, or `None` if `s` is `None`.
:param s: input string (`Column`)
:type s: str
:param regexp: string `re` pattern
:rtype: Column
"""
return None if s is None else re.findall(regexp, s)


def sanitize_column_name(name: str, replace_char: str = "_") -> str:
"""
Sanitizes column names by replacing special characters with the specified character.
def sanitize_column_name(name: str, replace_char: str = "_") -> str:
"""Sanitizes column names by replacing special characters with the specified character.
:param name: The original column name.
:type name: str
@@ -248,9 +256,8 @@ def sanitize_column_name(name: str, replace_char: str = "_") -> str:
return re.sub(r"[^a-zA-Z0-9_]", replace_char, name)


def _get_complex_fields(df: DataFrame) -> Dict[str, object]:
"""
Returns a dictionary of complex field names and their data types from the input DataFrame's schema.
def _get_complex_fields(df: DataFrame) -> dict[str, object]:
"""Returns a dictionary of complex field names and their data types from the input DataFrame's schema.
:param df: The input PySpark DataFrame.
:type df: DataFrame
@@ -265,8 +272,7 @@ def _get_complex_fields(df: DataFrame) -> Dict[str, object]:


def flatten_struct(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
"""
Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.
"""Flattens the specified StructType column in the input DataFrame and returns a new DataFrame with the flattened columns.
:param df: The input PySpark DataFrame.
:type df: DataFrame
@@ -284,9 +290,9 @@ def flatten_struct(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
]
return df.select("*", *expanded).drop(F.col(f"`{col_name}`"))


def explode_array(df: DataFrame, col_name: str) -> DataFrame:
"""
Explodes the specified ArrayType column in the input DataFrame and returns a new DataFrame with the exploded column.
"""Explodes the specified ArrayType column in the input DataFrame and returns a new DataFrame with the exploded column.
:param df: The input PySpark DataFrame.
:type df: DataFrame
@@ -295,11 +301,13 @@ def explode_array(df: DataFrame, col_name: str) -> DataFrame:
:return: The DataFrame with the exploded ArrayType column.
:rtype: DataFrame
"""
return df.select("*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)).drop(col_name)
return df.select("*", F.explode_outer(F.col(f"`{col_name}`")).alias(col_name)).drop(
col_name,
)


def flatten_map(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
"""
Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.
"""Flattens the specified MapType column in the input DataFrame and returns a new DataFrame with the flattened columns.
:param df: The input PySpark DataFrame.
:type df: DataFrame
@@ -312,13 +320,21 @@ def flatten_map(df: DataFrame, col_name: str, sep: str = ":") -> DataFrame:
"""
keys_df = df.select(F.explode_outer(F.map_keys(F.col(f"`{col_name}`")))).distinct()
keys = [row[0] for row in keys_df.collect()]
key_cols = [F.col(f"`{col_name}`").getItem(k).alias(col_name + sep + k) for k in keys]
return df.select([F.col(f"`{col}`") for col in df.columns if col != col_name] + key_cols)
key_cols = [
F.col(f"`{col_name}`").getItem(k).alias(col_name + sep + k) for k in keys
]
return df.select(
[F.col(f"`{col}`") for col in df.columns if col != col_name] + key_cols,
)

def flatten_dataframe(df: DataFrame, sep: str = ":", replace_char: str = "_", sanitized_columns: bool = False) -> DataFrame:
"""
Flattens all complex data types (StructType, ArrayType, and MapType) in the input DataFrame and returns a
new DataFrame with the flattened columns.

def flatten_dataframe(
df: DataFrame,
sep: str = ":",
replace_char: str = "_",
sanitized_columns: bool = False, # noqa: FBT001, FBT002
) -> DataFrame:
"""Flattens the complex columns in the DataFrame.
:param df: The input PySpark DataFrame.
:type df: DataFrame
@@ -354,7 +370,7 @@ def flatten_dataframe(df: DataFrame, sep: str = ":", replace_char: str = "_", sa
{"A#": 1500, "B@": 2500},
),
]
>>> df = spark.createDataFrame(data)
>>> flattened_df = flatten_dataframe(df)
>>> flattened_df.show()
@@ -364,16 +380,16 @@ def flatten_dataframe(df: DataFrame, sep: str = ":", replace_char: str = "_", sa
complex_fields = _get_complex_fields(df)

while len(complex_fields) != 0:
col_name = list(complex_fields.keys())[0]
col_name = next(iter(complex_fields.keys()))

if isinstance(complex_fields[col_name], StructType):
df = flatten_struct(df, col_name, sep)
df = flatten_struct(df, col_name, sep) # noqa: PD901

elif isinstance(complex_fields[col_name], ArrayType):
df = explode_array(df, col_name)
df = explode_array(df, col_name) # noqa: PD901

elif isinstance(complex_fields[col_name], MapType):
df = flatten_map(df, col_name, sep)
df = flatten_map(df, col_name, sep) # noqa: PD901

complex_fields = _get_complex_fields(df)

@@ -382,25 +398,12 @@ def flatten_dataframe(df: DataFrame, sep: str = ":", replace_char: str = "_", sa
sanitized_columns = [
sanitize_column_name(col_name, replace_char) for col_name in df.columns
]
df = df.toDF(*sanitized_columns)
df = df.toDF(*sanitized_columns) # noqa: PD901

return df


# def regexp_extract_all(s: str, regexp: str) -> list[re.Match] | None:
# """Function uses the Python `re` library to extract regular expressions from a string (`s`) using a regex pattern (`regexp`).
#
# It returns a list of all matches, or `None` if `s` is `None`.
#
# :param s: input string (`Column`)
# :type s: str
# :param regexp: string `re` pattern
# :return: List of matches
# """
# return None if s is None else re.findall(regexp, s)


def business_days_between(start_date: Column, end_date: Column) -> Column:

def business_days_between(start_date: Column, end_date: Column) -> Column: # noqa: ARG001
"""Function takes two Spark `Columns` and returns a `Column` with the number of business days between the start and the end date.
:param start_date: The column with the start dates
@@ -419,7 +422,9 @@ def business_days_between(start_date: Column, end_date: Column) -> Column:


def uuid5(
col: Column, namespace: uuid.UUID = uuid.NAMESPACE_DNS, extra_string: str = "",
col: Column,
namespace: uuid.UUID = uuid.NAMESPACE_DNS,
extra_string: str = "",
) -> Column:
"""Function generates UUIDv5 from ``col`` and ``namespace``, optionally prepending an extra string to ``col``.
@@ -443,7 +448,8 @@ def uuid5(
variant_part = F.conv(variant_part, 16, 2)
variant_part = F.lpad(variant_part, 16, "0")
variant_part = F.concat(
F.lit("10"), F.substring(variant_part, 3, 16),
F.lit("10"),
F.substring(variant_part, 3, 16),
) # RFC 4122 variant.
variant_part = F.lower(F.conv(variant_part, 2, 16))
return F.concat_ws(
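
To round out the functions.py changes, a sketch of the UDF-based regexp_extract_all and of uuid5; the DataFrame, column names, and pattern are invented:

import pyspark.sql.functions as F
from quinn.functions import regexp_extract_all, uuid5

# regexp_extract_all is registered as a UDF above, so both arguments are passed as columns
df = df.withColumn("numbers", regexp_extract_all(F.col("raw_text"), F.lit(r"\d+")))

# Deterministic UUIDv5 derived from the email column, DNS namespace by default
df = df.withColumn("row_uuid", uuid5(F.col("email")))
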
5 changes: 2 additions & 3 deletions quinn/schema_helpers.py
@@ -52,9 +52,7 @@ def print_schema_as_code(dtype: T.DataType) -> str:
def _repr_column(column: T.StructField) -> str:
res = []

if (
isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType))
):
if isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType)):
res.append(f'StructField(\n\t"{column.name}",')
for line in print_schema_as_code(column.dataType).split("\n"):
res.append("\n\t")
@@ -85,6 +83,7 @@ def schema_from_csv(spark: SparkSession, file_path: str) -> T.StructType:  # noq
:return: A StructType object representing the schema configuration
:rtype: pyspark.sql.types.StructType
"""

def _validate_json(metadata: str) -> dict:
if metadata is None:
return {}
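
A small sketch of print_schema_as_code, which the reformatted _repr_column above relies on; the sample schema is invented and the exact rendered string is not shown in this diff:

from pyspark.sql import types as T
from quinn.schema_helpers import print_schema_as_code

schema = T.StructType([
    T.StructField("id", T.LongType(), True),
    T.StructField("tags", T.ArrayType(T.StringType()), True),
])

# Returns Python source text that reconstructs the schema, handy for pinning in tests
print(print_schema_as_code(schema))
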
6 changes: 5 additions & 1 deletion quinn/spark.py
@@ -27,7 +27,11 @@
) -> None:
"""Initialize SparkSession."""
self.spark = self.set_up_spark(
app_name, self.master, conf, extra_dependencies, extra_files,
app_name,
self.master,
conf,
extra_dependencies,
extra_files,
)

@property