Skip to content

Commit

Permalink
Fix visualization unit tests ...
Browse files Browse the repository at this point in the history
  • Loading branch information
Amirhessam Tahmassebi committed Dec 26, 2022
1 parent 011f4d6 commit c4b8b75
Show file tree
Hide file tree
Showing 24 changed files with 2,268 additions and 1,675 deletions.
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/bug_report.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[BUG]: "
labels: ["bug"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/feature_request.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[FEATURE]: "
labels: ["enhancement"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 0 additions & 5 deletions .github/ISSUE_TEMPLATE/improve_documentation.yml
Original file line number Diff line number Diff line change
Expand Up @@ -4,11 +4,6 @@ title: "[DOCUMENTATION]: "
labels: ["documentation"]
assignees:
- amirhessam88
- Tsmith5151
- richardNam
- mkohram
- b-mohebali
- nicholasjma
body:
- type: markdown
attributes:
Expand Down
5 changes: 4 additions & 1 deletion .github/workflows/ci.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ jobs:
#----------------------------------------------
runs-on: ${{ matrix.os }}
strategy:
fail-fast: false
fail-fast: true
matrix:
# TODO(amir): enable `windows-latest`, `macos-latest` and fix possible `poetry` issues and glmnet
os: ["ubuntu-latest"]
Expand Down Expand Up @@ -58,6 +58,9 @@ jobs:
# therefore, all the CI jobs for those python versions failed at first, then we re-run the
# jobs, the cached venv using `python v3.8` will be retrieved and the jobs will run successfully
# ideally, we should be able to add `python-versions` here to distinguish between caches
# key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
# NOTE: `glmnet` has not been updated since 2020; trying to build it on-the-fly
# https:/civisanalytics/python-glmnet/issues/79
key: venv-${{ runner.os }}-${{ hashFiles('**/poetry.lock') }}
#----------------------------------------------
# ----- install dependencies -----
Expand Down
3,285 changes: 1,654 additions & 1,631 deletions poetry.lock

Large diffs are not rendered by default.

16 changes: 9 additions & 7 deletions pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -78,7 +78,7 @@ matplotlib = "^3.5,<3.6"
seaborn = "^0.12"


[tool.poetry.dev-dependencies]
[tool.poetry.group.dev.dependencies]

# --- package-management ---
pip = "^22.3"
Expand All @@ -87,12 +87,12 @@ pip = "^22.3"
poethepoet = "^0.16"

# --- testenv-management ---
tox = "^3.27"
tox = "^3.28"

# --- formatting ---
add-trailing-comma = "^2.2"
isort = "^5.10"
black = "^22.10"
add-trailing-comma = "^2.4"
isort = "^5.11"
black = "^22.12"
jupyter-black = "^0.3"

# --- linting ---
Expand Down Expand Up @@ -129,15 +129,17 @@ myst-parser = "^0.18"
furo = "^2022.9"

# --- jupyter ---
ipykernel = "^6.15"
ipykernel = "^6.20"
jupytext = "^1.14"

# --- monitoring ---
watchdog = "^2.1"

# --- image manipulation ---
pillow = "^9.3.0"

[build-system]
requires = ["poetry-core>=1.3.2"]
requires = ["poetry-core>=1.4.0"]
build-backend = "poetry.core.masonry.api"


Expand Down
2 changes: 1 addition & 1 deletion src/slickml/base/_estimator.py
Original file line number Diff line number Diff line change
Expand Up @@ -133,7 +133,7 @@ def fit(
-------
None
"""
...
... # pragma: no cover

# TODO(amir): check the `y_train` type; maybe we need to have `list_to_array()` in utils?
def _dtrain(
Expand Down
4 changes: 2 additions & 2 deletions src/slickml/base/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def plot(self) -> Figure:
-------
Figure
"""
...
... # pragma: no cover

def get_metrics(
self,
Expand All @@ -41,4 +41,4 @@ def get_metrics(
-------
Union[pd.DataFrame, Dict[str, Optional[float]]]
"""
...
... # pragma: no cover
2 changes: 1 addition & 1 deletion src/slickml/metrics/_classification.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from IPython.core.display import display
from IPython.display import display
from matplotlib.figure import Figure
from sklearn.metrics import (
accuracy_score,
Expand Down
2 changes: 1 addition & 1 deletion src/slickml/metrics/_regression.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import pandas as pd
import scipy as scp
import seaborn as sns
from IPython.core.display import display
from IPython.display import display
from matplotlib.figure import Figure
from sklearn.metrics import (
explained_variance_score,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_glmnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,6 @@ def plot_glmnet_cv_results(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -447,7 +446,6 @@ def plot_glmnet_coeff_path(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,6 @@ def plot_binary_classification_metrics(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -403,7 +402,6 @@ def plot_regression_metrics(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
2 changes: 0 additions & 2 deletions src/slickml/visualization/_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -146,7 +146,6 @@ def plot_xgb_feature_importance(
var_name="return_fig",
dtypes=bool,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down Expand Up @@ -337,7 +336,6 @@ def plot_xgb_cv_results(
var_name="test_std_color",
dtypes=str,
)
# TODO(amir): double check this
if save_path:
check_var(
save_path,
Expand Down
62 changes: 61 additions & 1 deletion tests/conftest.py
Original file line number Diff line number Diff line change
@@ -1,17 +1,41 @@
import importlib.resources as pkg_resources
import json
from pathlib import Path # noqa
from typing import Any, Dict, List, Optional, Tuple, Union

import numpy as np
import numpy.testing as npt
import pandas as pd
import pytest
from pytest import CaptureFixture, FixtureRequest
from assertpy import assert_that
from PIL import Image
from PIL.ImageFile import ImageFile
from PIL.PngImagePlugin import PngImageFile
from pytest import CaptureFixture, FixtureRequest, TempPathFactory
from scipy.sparse import csr_matrix
from sklearn.model_selection import train_test_split

from tests import resources


@pytest.fixture(scope="session")
def figure_path(tmp_path_factory: TempPathFactory) -> Path:
"""Returns a temporary path to save figures.
Parameters
----------
tmp_path_factory : TempPathFactory
Pytest's temporary path factory
Returns
-------
Path
"""
path = tmp_path_factory.mktemp("results")

return path.absolute()


@pytest.fixture(scope="session")
def clf_train_test_x_y(
request: FixtureRequest,
Expand Down Expand Up @@ -356,3 +380,39 @@ def _dummy_sparse_matrix() -> csr_matrix:
shape=(3, 3),
dtype=np.float64,
)


def _validate_figure_type_and_size(
path: Path,
expected_size: Tuple[int, int],
expected_type: ImageFile = PngImageFile,
) -> None:
"""Validates exported figure path's type and size.
Parameters
----------
path : Path
Temporary file path to save the figure
expected_size : Tuple[int, int]
Figure's size
expected_type : ImageFile
Figure's type, by default `PngImageFile` which cover `.png` files
Returns
-------
None
"""
# TODO(amir): currently, the `PIL.Image()` is used to load the saved figures and this would end up
# different results per operating systems. Therefore, for now defining a `tolerance` range ~ 5%
# to get away with the errors. Ideally, we should be able to match the `exact` size
_IMAGE_SIZE_ERROR_TOLERANCE_ERROR = 0.05
with Image.open(path) as img:
assert_that(img).is_instance_of(expected_type)
assert_that(img.size).is_instance_of(tuple)
npt.assert_allclose(
actual=img.size,
desired=expected_size,
rtol=_IMAGE_SIZE_ERROR_TOLERANCE_ERROR,
)
4 changes: 4 additions & 0 deletions tests/slickml/base/test_enum.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,10 @@ def test_extended_enum_instantiation__passes__with_default_values(self) -> None:
assert_that(str(FooBarBazQux.BAR)).is_instance_of(str)
assert_that(str(FooBarBazQux.BAZ)).is_instance_of(str)
assert_that(str(FooBarBazQux.QUX)).is_instance_of(str)
assert_that(repr(FooBarBazQux.FOO)).is_instance_of(str)
assert_that(repr(FooBarBazQux.BAR)).is_instance_of(str)
assert_that(repr(FooBarBazQux.BAZ)).is_instance_of(str)
assert_that(repr(FooBarBazQux.QUX)).is_instance_of(str)
assert_that(FooBarBazQux.names()).is_instance_of(list)
assert_that(FooBarBazQux.names()).is_iterable()
assert_that(FooBarBazQux.values()).is_instance_of(list)
Expand Down
77 changes: 76 additions & 1 deletion tests/slickml/classification/test_glmnet.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
from pathlib import Path # noqa
from typing import Any, Dict, List, Tuple, Union

import numpy as np
Expand All @@ -9,7 +10,7 @@
from matplotlib.figure import Figure

from slickml.classification import GLMNetCVClassifier
from tests.conftest import _ids
from tests.conftest import _ids, _validate_figure_type_and_size


# TODO(amir): add lolipop plot for coeff + unit-test
Expand Down Expand Up @@ -434,3 +435,77 @@ def test_glmnetcvclassifier_shap_plots__passes__with_valid_inputs(
clf.plot_shap_summary(**summary_kwargs)

assert_that(shap_waterfall_fig).is_instance_of(Figure)

@pytest.mark.parametrize(
("clf_train_test_x_y"),
[
("dataframe"),
],
indirect=["clf_train_test_x_y"],
ids=_ids,
)
def test_glmnetcvclassifier_plots__passes__with_valid_save_paths(
self,
clf_train_test_x_y: Tuple[pd.DataFrame, pd.DataFrame, np.ndarray, np.ndarray],
figure_path: Path,
) -> None:
"""Validates `GLMNetCVClassifier` saving plots passes with valid paths."""
X_train, X_test, y_train, y_test = clf_train_test_x_y
clf = GLMNetCVClassifier()
clf.fit(X_train, y_train)
_ = clf.predict_proba(X_test, y_test)
cv_results_fig_path = figure_path / "cv_results_fig.png" # type: ignore
coeff_path_fig_path = figure_path / "coeff_path_fig.png" # type: ignore
shap_waterfall_fig_path = figure_path / "shap_waterfall_fig.png" # type: ignore
shap_summary_fig_path = figure_path / "shap_summary_fig.png" # type: ignore

clf.plot_cv_results(
xlabel="foo",
ylabel="bar",
title="baz",
legend=True,
legendloc=2,
save_path=str(cv_results_fig_path),
display_plot=False,
return_fig=False,
)
clf.plot_coeff_path(
xlabel="foo",
ylabel="bar",
title="baz",
legend=True,
legendloc=2,
save_path=str(coeff_path_fig_path),
display_plot=False,
return_fig=False,
)
clf.plot_shap_waterfall(
save_path=str(shap_waterfall_fig_path),
display_plot=False,
return_fig=False,
)
clf.plot_shap_summary(
save_path=str(shap_summary_fig_path),
display_plot=False,
)

assert_that(cv_results_fig_path.parts[-1]).is_equal_to("cv_results_fig.png")
_validate_figure_type_and_size(
path=cv_results_fig_path,
expected_size=(1385, 930),
)
assert_that(coeff_path_fig_path.parts[-1]).is_equal_to("coeff_path_fig.png")
_validate_figure_type_and_size(
path=coeff_path_fig_path,
expected_size=(1627, 930),
)
assert_that(shap_waterfall_fig_path.parts[-1]).is_equal_to("shap_waterfall_fig.png")
_validate_figure_type_and_size(
path=shap_waterfall_fig_path,
expected_size=(1375, 974),
)
assert_that(shap_summary_fig_path.parts[-1]).is_equal_to("shap_summary_fig.png")
_validate_figure_type_and_size(
path=shap_summary_fig_path,
expected_size=(1474, 760),
)
Loading

0 comments on commit c4b8b75

Please sign in to comment.