Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add unit tests to cover save-path flag in all visualization modules #175

Merged
merged 1 commit into from
Dec 26, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[BUG]: "
labels: ["bug"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/feature_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[FEATURE]: "
labels: ["enhancement"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/improve_documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[DOCUMENTATION]: "
labels: ["documentation"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 5 additions & 0 deletions .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@ jobs:
#----------------------------------------------
runs-on: ${{ matrix.os }}
strategy:
# TODO(amir): currently this is `false` so that we can reuse the `poetry-cache` of `python v3.8`;
# once the `glmnet` dependency issue is resolved, change this to `fail-fast: true`
fail-fast: false
matrix:
# TODO(amir): enable `windows-latest`, `macos-latest` and fix possible `poetry` issues and glmnet
Expand Down Expand Up @@ -58,6 +60,9 @@ jobs:
# therefore, all the CI jobs for those python versions failed at first, then we re-run the
# jobs, the cached venv using `python v3.8` will be retrieved and the jobs will run successfully
# ideally, we should be able to add `python-versions` here to distinguish between caches
# key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

stupid glmnet still does not let us do this!

# NOTE: `glmnet` has not been updated since 2020; trying to build it on-the-fly
# https://github.com/civisanalytics/python-glmnet/issues/79
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
#----------------------------------------------
# ----- install dependencies -----
Expand Down
3,285 changes: 1,654 additions & 1,631 deletions poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ matplotlib = "^3.5,<3.6"
seaborn = "^0.12"


[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

following new poetry syntax


# --- package-management ---
pip = "^22.3"
Expand All @@ -87,12 +87,12 @@ pip = "^22.3"
poethepoet = "^0.16"

# --- testenv-management ---
tox = "^3.27"
tox = "^3.28"

# --- formatting ---
add-trailing-comma = "^2.2"
isort = "^5.10"
black = "^22.10"
add-trailing-comma = "^2.4"
isort = "^5.11"
black = "^22.12"
jupyter-black = "^0.3"

# --- linting ---
Expand Down Expand Up @@ -129,15 +129,17 @@ myst-parser = "^0.18"
furo = "^2022.9"

# --- jupyter ---
ipykernel = "^6.15"
ipykernel = "^6.20"
jupytext = "^1.14"

# --- monitoring ---
watchdog = "^2.1"

# --- image manipulation ---
pillow = "^9.3.0"

[build-system]
requires = ["poetry-core>=1.3.2"]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"


Expand Down
2 changes: 1 addition & 1 deletion src/slickml/base/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def fit(
-------
None
"""
...
... # pragma: no cover

# TODO(amir): check the `y_train` type; maybe we need to have `list_to_array()` in utils?
def _dtrain(
Expand Down
4 changes: 2 additions & 2 deletions src/slickml/base/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def plot(self) -> Figure:
-------
Figure
"""
...
... # pragma: no cover

def get_metrics(
self,
Expand All @@ -41,4 +41,4 @@ def get_metrics(
-------
Union[pd.DataFrame, Dict[str, Optional[float]]]
"""
...
... # pragma: no cover
2 changes: 1 addition & 1 deletion src/slickml/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display import display
from IPython.display import display
from matplotlib.figure import Figure
from sklearn.metrics import (
accuracy_score,
Expand Down
2 changes: 1 addition & 1 deletion src/slickml/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import scipy as scp
import seaborn as sns
from IPython.core.display import display
from IPython.display import display
from matplotlib.figure import Figure
from sklearn.metrics import (
explained_variance_score,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_glmnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def plot_glmnet_cv_results(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -447,7 +446,6 @@ def plot_glmnet_coeff_path(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def plot_binary_classification_metrics(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -403,7 +402,6 @@ def plot_regression_metrics(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def plot_xgb_feature_importance(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -337,7 +336,6 @@ def plot_xgb_cv_results(
var_name="test_std_color",
dtypes=str,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
62 changes: 61 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,41 @@
import importlib.resources as pkg_resources
import json
from pathlib import Path # noqa
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
from pytest import CaptureFixture, FixtureRequest
from assertpy import assert_that
from PIL import Image
from PIL.ImageFile import ImageFile
from PIL.PngImagePlugin import PngImageFile
from pytest import CaptureFixture, FixtureRequest, TempPathFactory
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split

from tests import resources


@pytest.fixture(scope="session")
def figure_path(tmp_path_factory: TempPathFactory) -> Path:
    """Provides an absolute temporary directory in which tests can save figures.

    Parameters
    ----------
    tmp_path_factory : TempPathFactory
        Pytest's factory for temporary directories

    Returns
    -------
    Path
    """
    # the directory is requested under the base name "results" and resolved to an absolute path
    return tmp_path_factory.mktemp("results").absolute()


@pytest.fixture(scope="session")
def clf_train_test_x_y(
request: FixtureRequest,
Expand Down Expand Up @@ -356,3 +380,39 @@ def _dummy_sparse_matrix() -> csr_matrix:
shape=(3, 3),
dtype=np.float64,
)


def _validate_figure_type_and_size(
    path: Path,
    expected_size: Tuple[int, int],
    expected_type: ImageFile = PngImageFile,
) -> None:
    """Checks that the figure saved at ``path`` has the expected image type and size.

    Parameters
    ----------
    path : Path
        File path of the saved figure to be opened and inspected

    expected_size : Tuple[int, int]
        Size the opened image is expected to report (compared with a relative tolerance)

    expected_type : ImageFile
        Image class the opened figure must be an instance of, by default `PngImageFile`
        which covers `.png` files

    Returns
    -------
    None
    """
    # TODO(amir): the saved figures are loaded via `PIL.Image` and the reported size differs
    # across operating systems. Therefore, for now a relative tolerance of ~ 5% is used to
    # absorb those differences. Ideally, we should be able to match the `exact` size
    size_rtol = 0.05
    with Image.open(path) as figure:
        assert_that(figure).is_instance_of(expected_type)
        actual_size = figure.size
        assert_that(actual_size).is_instance_of(tuple)
        npt.assert_allclose(actual_size, expected_size, rtol=size_rtol)
4 changes: 4 additions & 0 deletions tests/slickml/base/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def test_extended_enum_instantiation__passes__with_default_values(self) -> None:
assert_that(str(FooBarBazQux.BAR)).is_instance_of(str)
assert_that(str(FooBarBazQux.BAZ)).is_instance_of(str)
assert_that(str(FooBarBazQux.QUX)).is_instance_of(str)
assert_that(repr(FooBarBazQux.FOO)).is_instance_of(str)
assert_that(repr(FooBarBazQux.BAR)).is_instance_of(str)
assert_that(repr(FooBarBazQux.BAZ)).is_instance_of(str)
assert_that(repr(FooBarBazQux.QUX)).is_instance_of(str)
assert_that(FooBarBazQux.names()).is_instance_of(list)
assert_that(FooBarBazQux.names()).is_iterable()
assert_that(FooBarBazQux.values()).is_instance_of(list)
Expand Down
77 changes: 76 additions & 1 deletion tests/slickml/classification/test_glmnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path # noqa
from typing import Any, Dict, List, Tuple, Union

import numpy as np
Expand All @@ -9,7 +10,7 @@
from matplotlib.figure import Figure

from slickml.classification import GLMNetCVClassifier
from tests.conftest import _ids
from tests.conftest import _ids, _validate_figure_type_and_size


# TODO(amir): add lolipop plot for coeff + unit-test
Expand Down Expand Up @@ -434,3 +435,77 @@ def test_glmnetcvclassifier_shap_plots__passes__with_valid_inputs(
clf.plot_shap_summary(**summary_kwargs)

assert_that(shap_waterfall_fig).is_instance_of(Figure)

@pytest.mark.parametrize(
    ("clf_train_test_x_y"),
    [
        ("dataframe"),
    ],
    indirect=["clf_train_test_x_y"],
    ids=_ids,
)
def test_glmnetcvclassifier_plots__passes__with_valid_save_paths(
    self,
    clf_train_test_x_y: Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray],
    figure_path: Path,
) -> None:
    """Validates `GLMNetCVClassifier` saving plots passes with valid paths."""
    X_train, X_test, y_train, y_test = clf_train_test_x_y
    clf = GLMNetCVClassifier()
    clf.fit(X_train, y_train)
    _ = clf.predict_proba(X_test, y_test)

    # file name of each exported figure mapped to the size it is expected to have;
    # insertion order is the order in which the figures are validated below
    expected_sizes = {
        "cv_results_fig.png": (1385, 930),
        "coeff_path_fig.png": (1627, 930),
        "shap_waterfall_fig.png": (1375, 974),
        "shap_summary_fig.png": (1474, 760),
    }
    fig_paths = {
        file_name: figure_path / file_name  # type: ignore
        for file_name in expected_sizes
    }

    # labelling options shared by the cv-results and coeff-path plots
    label_kwargs = {
        "xlabel": "foo",
        "ylabel": "bar",
        "title": "baz",
        "legend": True,
        "legendloc": 2,
    }

    clf.plot_cv_results(
        **label_kwargs,
        save_path=str(fig_paths["cv_results_fig.png"]),
        display_plot=False,
        return_fig=False,
    )
    clf.plot_coeff_path(
        **label_kwargs,
        save_path=str(fig_paths["coeff_path_fig.png"]),
        display_plot=False,
        return_fig=False,
    )
    clf.plot_shap_waterfall(
        save_path=str(fig_paths["shap_waterfall_fig.png"]),
        display_plot=False,
        return_fig=False,
    )
    clf.plot_shap_summary(
        save_path=str(fig_paths["shap_summary_fig.png"]),
        display_plot=False,
    )

    for file_name, size in expected_sizes.items():
        saved_path = fig_paths[file_name]
        assert_that(saved_path.parts[-1]).is_equal_to(file_name)
        _validate_figure_type_and_size(
            path=saved_path,
            expected_size=size,
        )
Loading