From 5838463e80c4b526270a26a85186f02df9fda791 Mon Sep 17 00:00:00 2001 From: Amirhessam Tahmassebi Date: Sat, 29 Oct 2022 23:25:00 -0500 Subject: [PATCH] Fix xgboost feature selector callbacks (#161) Fix xgboost feature selector callbacks ... Co-authored-by: Amirhessam Tahmassebi --- src/slickml/selection/_xgboost.py | 4 ++++ tests/slickml/selection/test_xgboost.py | 16 ++++++++-------- 2 files changed, 12 insertions(+), 8 deletions(-) diff --git a/src/slickml/selection/_xgboost.py b/src/slickml/selection/_xgboost.py index b17f541..cc54e9a 100644 --- a/src/slickml/selection/_xgboost.py +++ b/src/slickml/selection/_xgboost.py @@ -1045,6 +1045,10 @@ def _callbacks(self) -> None: None """ if self.callbacks: + # TODO(amir): ditch print with logger + print( + "Warning: The `cv` will break if the `early_stopping_rounds` criterion was not satisfied.", + ) self.callbacks = [ xgb.callback.EvaluationMonitor( rank=0, diff --git a/tests/slickml/selection/test_xgboost.py b/tests/slickml/selection/test_xgboost.py index f6bc40b..206c49a 100644 --- a/tests/slickml/selection/test_xgboost.py +++ b/tests/slickml/selection/test_xgboost.py @@ -237,7 +237,6 @@ def test_reg_xgboostfeatureselector__passes__with_defaults( assert_that(cv_results_fig).is_instance_of(Figure) assert_that(feature_frequency_fig).is_instance_of(Figure) - # TODO(amir): change the code for `callback=True` and add the unit-test @pytest.mark.parametrize( ("clf_x_y"), [ @@ -257,14 +256,13 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs( ) -> None: """Validates `XGBoostFeatureSelector` instanation passes with valid inputs for classification.""" X, y = clf_x_y - # TODO(amir): callbacks=True has not been tested yet. The look of `self._cv()` would change - # therefore the logic for `if _feature_gain["feature"].str.contains("noisy").sum() != 0:` - # should be changed accordingly xfs = XGBoostFeatureSelector( n_iter=1, sparse_matrix=True, scale_std=True, params={"eval_metric": "logloss"}, + early_stopping_rounds=1000, + callbacks=True, ) xfs.fit(X, y) params = xfs.get_params() @@ -291,7 +289,7 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs( assert_that(xfs.nth_noise_threshold).is_instance_of(int) assert_that(xfs.nth_noise_threshold).is_equal_to(1) assert_that(xfs.early_stopping_rounds).is_instance_of(int) - assert_that(xfs.early_stopping_rounds).is_equal_to(20) + assert_that(xfs.early_stopping_rounds).is_equal_to(1000) assert_that(xfs.random_state).is_instance_of(int) assert_that(xfs.random_state).is_equal_to(1367) assert_that(xfs.metrics).is_instance_of(str) @@ -304,7 +302,7 @@ def test_clf_xgboostfeatureselector__passes__with_valid_inputs( assert_that(xfs.shuffle).is_true() assert_that(xfs.verbose_eval).is_instance_of(bool) assert_that(xfs.verbose_eval).is_false() - assert_that(xfs.callbacks).is_none() + assert_that(xfs.callbacks).is_not_none() assert_that(xfs.scale_mean).is_instance_of(bool) assert_that(xfs.scale_mean).is_false() assert_that(xfs.scale_std).is_instance_of(bool) @@ -370,6 +368,8 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs( sparse_matrix=True, scale_std=True, params={"eval_metric": "mae"}, + callbacks=True, + early_stopping_rounds=1000, ) xfs.fit(X, y) params = xfs.get_params() @@ -396,7 +396,7 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs( assert_that(xfs.nth_noise_threshold).is_instance_of(int) assert_that(xfs.nth_noise_threshold).is_equal_to(1) assert_that(xfs.early_stopping_rounds).is_instance_of(int) - assert_that(xfs.early_stopping_rounds).is_equal_to(20) + assert_that(xfs.early_stopping_rounds).is_equal_to(1000) assert_that(xfs.random_state).is_instance_of(int) assert_that(xfs.random_state).is_equal_to(1367) assert_that(xfs.metrics).is_instance_of(str) @@ -407,7 +407,7 @@ def test_reg_xgboostfeatureselector__passes__with_valid_inputs( assert_that(xfs.shuffle).is_true() assert_that(xfs.verbose_eval).is_instance_of(bool) assert_that(xfs.verbose_eval).is_false() - assert_that(xfs.callbacks).is_none() + assert_that(xfs.callbacks).is_not_none() assert_that(xfs.scale_mean).is_instance_of(bool) assert_that(xfs.scale_mean).is_false() assert_that(xfs.scale_std).is_instance_of(bool)