From 177bd50c629576d00057b08908768650875aa023 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 26 Sep 2023 21:04:30 -0500
Subject: [PATCH 1/5] [python-package] reorganize early stopping callback

---
 python-package/lightgbm/callback.py | 64 ++++++++++++++++++++---------
 1 file changed, 45 insertions(+), 19 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 2f77ee740c75..40e3dc47f8dc 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -229,7 +229,12 @@ def __call__(self, env: CallbackEnv) -> None:
             if new_param != env.params.get(key, None):
                 new_parameters[key] = new_param
         if new_parameters:
-            env.model.reset_parameter(new_parameters)
+            if isinstance(env.model, Booster):
+                env.model.reset_parameter(new_parameters)
+            else:
+                # CVBooster holds a list of Booster objects, each needs to be updated
+                for i in range(len(env.model.boosters)):
+                    env.model.boosters[i].reset_parameter(new_parameters)
             env.params.update(new_parameters)


@@ -291,32 +296,49 @@ def _gt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
     def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

-    def _is_train_set(self, ds_name: str, eval_name: str, train_name: str) -> bool:
-        return (ds_name == "cv_agg" and eval_name == "train") or ds_name == train_name
+    def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
+        """Check, by name, if a given Dataset is the training data"""
+        # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
+        # and those metrics are considered for early stopping
+        if ds_name == "cv_agg" and eval_name == "train":
+            return True
+
+        # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
+        if isinstance(env.model, Booster):
+            if ds_name == env.model._train_data_name:
+                return True
+
+        return False

     def _init(self, env: CallbackEnv) -> None:
         if env.evaluation_result_list is None or env.evaluation_result_list == []:
             raise ValueError(
                 "For early stopping, at least one dataset and eval metric is required for evaluation"
             )
+
+        if self.stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be greater than zero. got: {self.stopping_rounds}")
+
         is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
-        only_train_set = (
-            len(env.evaluation_result_list) == 1
-            and self._is_train_set(
-                ds_name=env.evaluation_result_list[0][0],
-                eval_name=env.evaluation_result_list[0][1].split(" ")[0],
-                train_name=env.model._train_data_name)
-        )
-        self.enabled = not is_dart and not only_train_set
-        if not self.enabled:
-            if is_dart:
-                _log_warning('Early stopping is not available in dart mode')
-            elif only_train_set:
-                _log_warning('Only training set found, disabling early stopping.')
+        if is_dart:
+            self.enabled = False
+            _log_warning('Early stopping is not available in dart mode')
             return
-        if self.stopping_rounds <= 0:
-            raise ValueError("stopping_rounds should be greater than zero.")
+
+        # validation sets are guaranteed to not be identical to the training data in cv()
+        if isinstance(env.model, Booster):
+            only_train_set = (
+                len(env.evaluation_result_list) == 1
+                and self._is_train_set(
+                    ds_name=env.evaluation_result_list[0][0],
+                    eval_name=env.evaluation_result_list[0][1].split(" ")[0],
+                    env=env
+                )
+            )
+            if only_train_set:
+                self.enabled = False
+                _log_warning('Only training set found, disabling early stopping.')
+                return

         if self.verbose:
             _log_info(f"Training until validation scores don't improve for {self.stopping_rounds} rounds")
@@ -395,7 +417,11 @@ def __call__(self, env: CallbackEnv) -> None:
             eval_name_splitted = env.evaluation_result_list[i][1].split(" ")
             if self.first_metric_only and self.first_metric != eval_name_splitted[-1]:
                 continue  # use only the first metric for early stopping
-            if self._is_train_set(env.evaluation_result_list[i][0], eval_name_splitted[0], env.model._train_data_name):
+            if self._is_train_set(
+                ds_name=env.evaluation_result_list[i][0],
+                eval_name=eval_name_splitted[0],
+                env=env
+            ):
                 continue  # train data for lgb.cv or sklearn wrapper (underlying lgb.train)
             elif env.iteration - self.best_iter[i] >= self.stopping_rounds:
                 if self.verbose:

From 4af2b9aded20598af9da508bc663e8d8e582cf66 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 26 Sep 2023 21:13:40 -0500
Subject: [PATCH 2/5] linting

---
 python-package/lightgbm/callback.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 40e3dc47f8dc..555822a40131 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -297,7 +297,7 @@ def _lt_delta(self, curr_score: float, best_score: float, delta: float) -> bool:
         return curr_score < best_score - delta

     def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
-        """Check, by name, if a given Dataset is the training data"""
+        """Check, by name, if a given Dataset is the training data."""
         # for lgb.cv() with eval_train_metric=True, evaluation is also done on the training set
         # and those metrics are considered for early stopping
         if ds_name == "cv_agg" and eval_name == "train":

From 56f73b419f11ee2ebeba5076688acab45d886260 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 3 Oct 2023 21:48:33 -0500
Subject: [PATCH 3/5] Apply suggestions from code review
MIME-Version: 1.0
Content-Type: text/plain; charset=UTF-8
Content-Transfer-Encoding: 8bit

Co-authored-by: José Morales
---
 python-package/lightgbm/callback.py | 9 ++++-----
 1 file changed, 4 insertions(+), 5 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 555822a40131..dcbb51cd2af2 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -233,8 +233,8 @@ def __call__(self, env: CallbackEnv) -> None:
                 env.model.reset_parameter(new_parameters)
             else:
                 # CVBooster holds a list of Booster objects, each needs to be updated
-                for i in range(len(env.model.boosters)):
-                    env.model.boosters[i].reset_parameter(new_parameters)
+                for booster in env.model.boosters:
+                    booster.reset_parameter(new_parameters)
             env.params.update(new_parameters)


@@ -304,9 +304,8 @@ def _is_train_set(self, ds_name: str, eval_name: str, env: CallbackEnv) -> bool:
             return True

         # for lgb.train(), it's possible to pass the training data via valid_sets with any eval_name
-        if isinstance(env.model, Booster):
-            if ds_name == env.model._train_data_name:
-                return True
+        if isinstance(env.model, Booster) and ds_name == env.model._train_data_name:
+            return True

         return False

From bd3366a6fa9335e3d935e6b91fd6d8db25dea304 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Tue, 3 Oct 2023 21:57:14 -0500
Subject: [PATCH 4/5] move validation up

---
 python-package/lightgbm/callback.py | 7 ++++---
 1 file changed, 4 insertions(+), 3 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index dcbb51cd2af2..0aa7d7fb8131 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -272,6 +272,10 @@ def __init__(
         verbose: bool = True,
         min_delta: Union[float, List[float]] = 0.0
     ) -> None:
+
+        if stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be greater than zero. got: {stopping_rounds}")
+
         self.order = 30
         self.before_iteration = False

@@ -315,9 +319,6 @@ def _init(self, env: CallbackEnv) -> None:
                 "For early stopping, at least one dataset and eval metric is required for evaluation"
             )

-        if self.stopping_rounds <= 0:
-            raise ValueError(f"stopping_rounds should be greater than zero. got: {self.stopping_rounds}")
-
         is_dart = any(env.params.get(alias, "") == 'dart' for alias in _ConfigAliases.get("boosting"))
         if is_dart:
             self.enabled = False

From 7a98d82a31c94cdbfcbd2a8cd0c4a557281305a3 Mon Sep 17 00:00:00 2001
From: James Lamb
Date: Wed, 4 Oct 2023 22:11:12 -0500
Subject: [PATCH 5/5] fix tests

---
 python-package/lightgbm/callback.py        |  4 ++--
 tests/python_package_test/test_callback.py | 11 +++++++++++
 tests/python_package_test/test_engine.py   |  4 ++--
 3 files changed, 15 insertions(+), 4 deletions(-)

diff --git a/python-package/lightgbm/callback.py b/python-package/lightgbm/callback.py
index 0aa7d7fb8131..b68bb63c7f41 100644
--- a/python-package/lightgbm/callback.py
+++ b/python-package/lightgbm/callback.py
@@ -273,8 +273,8 @@ def __init__(
         min_delta: Union[float, List[float]] = 0.0
     ) -> None:

-        if stopping_rounds <= 0:
-            raise ValueError(f"stopping_rounds should be greater than zero. got: {stopping_rounds}")
+        if not isinstance(stopping_rounds, int) or stopping_rounds <= 0:
+            raise ValueError(f"stopping_rounds should be an integer and greater than 0. got: {stopping_rounds}")

         self.order = 30
         self.before_iteration = False

diff --git a/tests/python_package_test/test_callback.py b/tests/python_package_test/test_callback.py
index cb5dc707bf43..f93ca837f8b9 100644
--- a/tests/python_package_test/test_callback.py
+++ b/tests/python_package_test/test_callback.py
@@ -21,6 +21,17 @@ def test_early_stopping_callback_is_picklable(serializer):
     assert callback.stopping_rounds == rounds


+def test_early_stopping_callback_rejects_invalid_stopping_rounds_with_informative_errors():
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: 0"):
+        lgb.early_stopping(stopping_rounds=0)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: -1"):
+        lgb.early_stopping(stopping_rounds=-1)
+
+    with pytest.raises(ValueError, match="stopping_rounds should be an integer and greater than 0. got: neverrrr"):
+        lgb.early_stopping(stopping_rounds="neverrrr")
+
+
 @pytest.mark.parametrize('serializer', SERIALIZERS)
 def test_log_evaluation_callback_is_picklable(serializer):
     periods = 42

diff --git a/tests/python_package_test/test_engine.py b/tests/python_package_test/test_engine.py
index b46526bcfaf6..2f592d43b243 100644
--- a/tests/python_package_test/test_engine.py
+++ b/tests/python_package_test/test_engine.py
@@ -4501,9 +4501,9 @@ def test_train_raises_informative_error_if_any_valid_sets_are_not_dataset_object

 def test_train_raises_informative_error_for_params_of_wrong_type():
     X, y = make_synthetic_regression()
-    params = {"early_stopping_round": "too-many"}
+    params = {"num_leaves": "too-many"}
     dtrain = lgb.Dataset(X, label=y)
-    with pytest.raises(lgb.basic.LightGBMError, match="Parameter early_stopping_round should be of type int, got \"too-many\""):
+    with pytest.raises(lgb.basic.LightGBMError, match="Parameter num_leaves should be of type int, got \"too-many\""):
         lgb.train(params, dtrain)