Skip to content

Commit

Permalink
Fix some tests
Browse files Browse the repository at this point in the history
 On branch feature/fix-tests
 Changes to be committed:
	modified:   pyproject.toml
	modified:   quinn/schema_helpers.py
	modified:   quinn/transformations.py
	modified:   tests/test_split_columns.py
  • Loading branch information
SemyonSinchenko committed Nov 18, 2023
1 parent f33e319 commit 09c3fd4
Show file tree
Hide file tree
Showing 4 changed files with 5 additions and 6 deletions.
1 change: 1 addition & 0 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ ignore = [
"TCH003", # I have no idea what is it about
"PLC1901", # Strange thing
"UP007", # Not supported in py3.6
"UP038", # Not supported in all py versions
]
extend-exclude = ["tests", "docs"]

Expand Down
4 changes: 2 additions & 2 deletions quinn/schema_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,7 @@ def print_schema_as_code(dtype: T.DataType) -> str:
def _repr_column(column: T.StructField) -> str:
res = []

if isinstance(column.dataType, (T.ArrayType | T.MapType | T.StructType)):
if isinstance(column.dataType, (T.ArrayType, T.MapType, T.StructType)):
res.append(f'StructField(\n\t"{column.name}",')
for line in print_schema_as_code(column.dataType).split("\n"):
res.append("\n\t")
Expand Down Expand Up @@ -164,6 +164,6 @@ def complex_fields(schema: T.StructType) -> dict[str, object]:
return {
field.name: field.dataType
for field in schema.fields
if isinstance(field.dataType, (T.ArrayType | T.StructType | T.MapType))
if isinstance(field.dataType, (T.ArrayType, T.StructType, T.MapType))
}

2 changes: 1 addition & 1 deletion quinn/transformations.py
Original file line number Diff line number Diff line change
Expand Up @@ -220,7 +220,7 @@ def fix_nullability(field: StructField, result_dict: dict) -> None:
return top_level_sorted_df

is_nested: bool = any(
isinstance(i.dataType, StructType | ArrayType)
isinstance(i.dataType, (StructType, ArrayType))
for i in top_level_sorted_df.schema
)

Expand Down
4 changes: 1 addition & 3 deletions tests/test_split_columns.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,6 @@
import chispa
import pytest

from pyspark.errors.exceptions.captured import PythonException


@auto_inject_fixtures("spark")
def test_split_columns(spark):
Expand Down Expand Up @@ -52,5 +50,5 @@ def test_split_columns_strict(spark):
delimiter="XX",
new_col_names=["student_first_name", "student_middle_name", "student_last_name"],
mode="strict", default="hi")
with pytest.raises(PythonException):
with pytest.raises(Exception): # there is no way to make it work for all the versions
df2.show()

0 comments on commit 09c3fd4

Please sign in to comment.