Skip to content

Commit

Permalink
Fix xgboost feature selector callbacks ...
Browse files Browse the repository at this point in the history
  • Loading branch information
Amirhessam Tahmassebi committed Oct 30, 2022
1 parent cd594b8 commit 9935afe
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 8 deletions.
4 changes: 4 additions & 0 deletions src/slickml/selection/_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -1045,6 +1045,10 @@ def _callbacks(self) -> None:
None
"""
if self.callbacks:
# TODO(amir): replace print with logger
print(
"Warning: The `cv` will break if the `early_stopping_rounds` criterion was not satisfied."
)
self.callbacks = [
xgb.callback.EvaluationMonitor(
rank=0,
Expand Down
16 changes: 8 additions & 8 deletions tests/slickml/selection/test_xgboost.py
Original file line number Diff line number Diff line change
Expand Up @@ -237,7 +237,6 @@ def test_reg_xgboostfeatureselector__passes__with_defaults(
assert_that(cv_results_fig).is_instance_of(Figure)
assert_that(feature_frequency_fig).is_instance_of(Figure)

# TODO(amir): change the code for `callback=True` and add the unit-test
@pytest.mark.parametrize(
("clf_x_y"),
[
Expand All @@ -257,14 +256,13 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs(
) -> None:
"""Validates `XGBoostFeatureSelector` instanation passes with valid inputs for classification."""
X, y = clf_x_y
# TODO(amir): callbacks=True has not been tested yet. The look of `self._cv()` would change
# therefore the logic for `if _feature_gain["feature"].str.contains("noisy").sum() != 0:`
# should be changed accordingly
xfs = XGBoostFeatureSelector(
n_iter=1,
sparse_matrix=True,
scale_std=True,
params={"eval_metric": "logloss"},
early_stopping_rounds=1000,
callbacks=True,
)
xfs.fit(X, y)
params = xfs.get_params()
Expand All @@ -291,7 +289,7 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs(
assert_that(xfs.nth_noise_threshold).is_instance_of(int)
assert_that(xfs.nth_noise_threshold).is_equal_to(1)
assert_that(xfs.early_stopping_rounds).is_instance_of(int)
assert_that(xfs.early_stopping_rounds).is_equal_to(20)
assert_that(xfs.early_stopping_rounds).is_equal_to(1000)
assert_that(xfs.random_state).is_instance_of(int)
assert_that(xfs.random_state).is_equal_to(1367)
assert_that(xfs.metrics).is_instance_of(str)
Expand All @@ -304,7 +302,7 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs(
assert_that(xfs.shuffle).is_true()
assert_that(xfs.verbose_eval).is_instance_of(bool)
assert_that(xfs.verbose_eval).is_false()
assert_that(xfs.callbacks).is_none()
assert_that(xfs.callbacks).is_not_none()
assert_that(xfs.scale_mean).is_instance_of(bool)
assert_that(xfs.scale_mean).is_false()
assert_that(xfs.scale_std).is_instance_of(bool)
Expand Down Expand Up @@ -370,6 +368,8 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs(
sparse_matrix=True,
scale_std=True,
params={"eval_metric": "mae"},
callbacks=True,
early_stopping_rounds=1000,
)
xfs.fit(X, y)
params = xfs.get_params()
Expand All @@ -396,7 +396,7 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs(
assert_that(xfs.nth_noise_threshold).is_instance_of(int)
assert_that(xfs.nth_noise_threshold).is_equal_to(1)
assert_that(xfs.early_stopping_rounds).is_instance_of(int)
assert_that(xfs.early_stopping_rounds).is_equal_to(20)
assert_that(xfs.early_stopping_rounds).is_equal_to(1000)
assert_that(xfs.random_state).is_instance_of(int)
assert_that(xfs.random_state).is_equal_to(1367)
assert_that(xfs.metrics).is_instance_of(str)
Expand All @@ -407,7 +407,7 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs(
assert_that(xfs.shuffle).is_true()
assert_that(xfs.verbose_eval).is_instance_of(bool)
assert_that(xfs.verbose_eval).is_false()
assert_that(xfs.callbacks).is_none()
assert_that(xfs.callbacks).is_not_none()
assert_that(xfs.scale_mean).is_instance_of(bool)
assert_that(xfs.scale_mean).is_false()
assert_that(xfs.scale_std).is_instance_of(bool)
Expand Down

0 comments on commit 9935afe

Please sign in to comment.