Skip to content

Commit

Permalink
fixed optuna examples
Browse files Browse the repository at this point in the history
  • Loading branch information
screengreen committed Sep 5, 2024
1 parent d960ae7 commit 60e33cf
Show file tree
Hide file tree
Showing 5 changed files with 25 additions and 12 deletions.
1 change: 1 addition & 0 deletions examples/optimization/conditional_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
train_data, test_data = train_test_split(data, test_size=0.2, stratify=data["TARGET"], random_state=42)


# This function replaces the default `_sample` method of the OptunaTuner class.
def sample(optimization_search_space, trial, suggested_params):
trial_values = copy.copy(suggested_params)
trial_values["feature_fraction"] = trial.suggest_uniform("feature_fraction", low=0.5, high=1.0)
Expand Down
7 changes: 3 additions & 4 deletions examples/optimization/custom_search_space.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,8 @@
from sklearn.model_selection import train_test_split

from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.ml_algo.tuning.base import Distribution
from lightautoml.ml_algo.tuning.base import SearchSpace
from lightautoml.tasks import Task
from lightautoml.ml_algo.tuning.base import Uniform


# load and prepare data
Expand All @@ -22,8 +21,8 @@
task=Task("binary"),
lgb_params={
"optimization_search_space": {
"feature_fraction": SearchSpace(Distribution.UNIFORM, low=0.5, high=1.0),
"min_sum_hessian_in_leaf": SearchSpace(Distribution.LOGUNIFORM, low=1e-3, high=10.0),
"feature_fraction": Uniform(low=0.5, high=1.0),
"min_sum_hessian_in_leaf": Uniform(low=1e-3, high=10.0, log=True),
}
},
)
Expand Down
3 changes: 1 addition & 2 deletions examples/optimization/sequential_parameter_search.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,12 @@
from lightautoml.automl.presets.tabular_presets import TabularAutoML
from lightautoml.tasks import Task


# load and prepare data
data = pd.read_csv("./data/sampled_app_train.csv")
train_data, test_data = train_test_split(data, test_size=0.2, stratify=data["TARGET"], random_state=42)


def sample(optimization_search_space, trial, suggested_params):
def sample(trial, suggested_params):
trial_values = copy.copy(suggested_params)

for feature_fraction in range(10):
Expand Down
9 changes: 8 additions & 1 deletion lightautoml/ml_algo/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,7 @@
from typing import Any
from typing import Dict
from typing import List
from typing import Callable
from typing import Optional
from typing import Sequence
from typing import Tuple
Expand Down Expand Up @@ -48,7 +49,13 @@ class MLAlgo(ABC):
"""

_default_params: Dict = {}
optimization_search_space: Dict = {}

# A Dict is the default search-space representation and is used for simple cases.
# A Callable is used for complex cases, such as a conditional search space, as shown in
# LightAutoML/examples/optimization/conditional_parameters.py.
# The callable is invoked in the _get_objective function of the OptunaTuner class.
optimization_search_space: Union[Dict, Callable] = {}

# TODO: add checks here
_fit_checks: Tuple = ()
_transform_checks: Tuple = ()
Expand Down
17 changes: 12 additions & 5 deletions lightautoml/ml_algo/tuning/optuna.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,11 +273,18 @@ def objective(trial: optuna.trial.Trial) -> float:
estimated_n_trials=estimated_n_trials,
)

_ml_algo.params = self._sample(
trial=trial,
optimization_search_space=optimization_search_space,
suggested_params=_ml_algo.init_params_on_input(train_valid_iterator),
)
if callable(optimization_search_space):
_ml_algo.params = optimization_search_space(
trial=trial,
optimization_search_space=None,
suggested_params=_ml_algo.init_params_on_input(train_valid_iterator),
)
else:
_ml_algo.params = self._sample(
trial=trial,
optimization_search_space=optimization_search_space,
suggested_params=_ml_algo.init_params_on_input(train_valid_iterator),
)

output_dataset = _ml_algo.fit_predict(train_valid_iterator=train_valid_iterator)

Expand Down

0 comments on commit 60e33cf

Please sign in to comment.