Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: validation task #983

Closed
wants to merge 33 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
33 commits
Select commit Hold shift + click to select a range
2c77828
refactor: remove task prototype when resample
be-marc Dec 7, 2023
2164d4d
refactor: add option to store prototype
be-marc Dec 7, 2023
21ed459
fix: braket
be-marc Dec 7, 2023
cef8117
refactor: null
be-marc Dec 7, 2023
48f40fa
fix: browser
be-marc Dec 7, 2023
1c027cd
keep prototypes in state when store_models is TRUE
sebffischer Dec 8, 2023
8788355
feat(Learner): uses_test_set active binding
sebffischer Dec 14, 2023
b29d558
...
sebffischer Dec 14, 2023
c2726c9
...
sebffischer Jan 23, 2024
73247c4
...
sebffischer Jan 23, 2024
ddfa67d
...
sebffischer Jan 24, 2024
dd65441
Update R/Learner.R
sebffischer Jan 25, 2024
1dd12cc
Update R/Learner.R
sebffischer Jan 25, 2024
3d06488
...
sebffischer Jan 27, 2024
c58da30
allow cbinding test rows to task
sebffischer Jan 29, 2024
02fa3dd
add test
sebffischer Jan 30, 2024
3aabaef
allow to cbind test rows
sebffischer Feb 9, 2024
5b983ee
Update NEWS.md
sebffischer Feb 12, 2024
9e8ae85
avoid unnecessary sort
sebffischer Feb 12, 2024
9a0e954
work on test and holdout tas
sebffischer Feb 16, 2024
379e2f2
Merge branch 'main' into feat/train-predict
sebffischer Feb 16, 2024
4cb8b8e
BREAKING_CHANGE: test/holdout task replace test/holdout roles
sebffischer Feb 16, 2024
448296d
fix some issues regarding test task
sebffischer Feb 20, 2024
343f475
better news
sebffischer Feb 20, 2024
944adfc
pipelines dependency
sebffischer Feb 20, 2024
bdd54fb
uber hack for revdepcheck
sebffischer Feb 21, 2024
c465b7f
remove remotes
sebffischer Feb 21, 2024
29e21bd
optimization
sebffischer Feb 22, 2024
1c06fb7
refactor: partition method is now called divide
sebffischer Feb 22, 2024
fe54dac
comment hack
sebffischer Feb 22, 2024
3bd7284
rename test -> validation
sebffischer Mar 5, 2024
4c0f4c7
some progress
sebffischer Mar 19, 2024
d47a4ba
...
sebffischer Apr 9, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -189,6 +189,7 @@ Collate:
'predict.R'
'reexports.R'
'resample.R'
'set_inner_tuning.R'
'set_threads.R'
'task_converters.R'
'worker.R'
Expand Down
2 changes: 2 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@ S3method(print,rr_score)
S3method(rd_info,Learner)
S3method(rd_info,Measure)
S3method(rd_info,Task)
S3method(set_inner_tuning,LearnerRegrDebug)
S3method(set_threads,R6)
S3method(set_threads,default)
S3method(set_threads,list)
Expand Down Expand Up @@ -222,6 +223,7 @@ export(partition)
export(resample)
export(rsmp)
export(rsmps)
export(set_inner_tuning)
export(set_threads)
export(tgen)
export(tgens)
Expand Down
6 changes: 6 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,11 @@
# mlr3 (development version)

* Fix: A task's hash now takes the validation task (previously row roles 'test') into
account, which ensures that hotstarting and the usage of test rows work
together
* BREAKING_CHANGE: removes row roles 'test' and 'holdout'.
* Feat(Learner): Better support for validation and learner internal tuning.
* TODOOOO
* feat: dictionary conversion of `mlr_learners` respects prototype arguments
recently added in mlr3misc
* perf: skip unnecessary clone of learner's state in `resample()`
Expand Down
12 changes: 8 additions & 4 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,14 @@ HotstartStack = R6Class("HotstartStack",
add = function(learners) {
learners = assert_learners(as_learners(learners))

# check for models
if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}
walk(learners, function(learner) {
if (is.null(learner$state$model)) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}
if (!is.null(learner$state$param_vals$validate)) {
stopf("Hotstart learners that did validation is currently not supported.")
}
})

if (!is.null(self$hotstart_threshold)) {
learners = keep(learners, function(learner) {
Expand Down
50 changes: 40 additions & 10 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,13 @@
#' * `loglik(...)`: Extracts the log-likelihood (c.f. [stats::logLik()]).
#' This can be used in measures like [mlr_measures_aic] or [mlr_measures_bic].
#'
#' * `inner_valid_scores(...)`: Returns the inner validation error(s) of the model as named `list()` with
#' `numeric(1)` values.
#' Learners that have the `"validation"` property must implement this.

#' * `inner_tuning_values(...)`: Returns the inner tuned hyperparameters of the model as named `list()`.
#' Learners that have the `"inner_tuning"` property must implement this.
#' In case no values were tuned, an empty list should be returned.
#'
#' @section Setting Hyperparameters:
#'
Expand All @@ -83,6 +90,23 @@
#' lrn$param_set$add(paradox::ParamFct$new("foo", levels = c("a", "b")))
#' ```
#'
#' @section Validation:
#' Learners that can make use of an additional validation set (e.g. for early stopping) must:
#' * be annotated with the `"validation"` property
#' * implement the `$inner_valid_scores()` extractors (see section *Optional Extractors*)
#' * Add the `validate` parameter, which can be either `NULL`, a ratio, `"test"`, or `"inner_valid_task"`:
#' * `NULL`: no validation
#' * `ratio`: only proportion `1 - ratio` of the task is used for training and `ratio` is used for validation.
#' * `"test"` means that the `"test"` task is used.
#' **Warning**: This might lead to biased performance estimation.
#' This option is only available if the learner is being trained via [resample()], [benchmark()] or functions that
#' internally use them, e.g. [`mlr3tuning::tune`] or [`mlr3batchmark::batchmark()`].
#' This is especially useful for hyperparameter tuning, where one might want to use the same data for early
#' stopping and the evaluation of the hyperparameter configurations.
#' * `"inner_valid_task"` means that the inner validation task is used.
#' See the [`Task`] documentation for this.
#'
#' @template seealso_learner
#' @export
Learner = R6Class("Learner",
Expand All @@ -103,6 +127,11 @@ Learner = R6Class("Learner",
#' @template field_task_type
task_type = NULL,

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,

#' @field predict_types (`character()`)\cr
#' Stores the possible predict types the learner is capable of.
#' A complete list of candidate predict types, grouped by task type, is stored in [`mlr_reflections$learner_predict_types`][mlr_reflections].
Expand All @@ -113,11 +142,6 @@ Learner = R6Class("Learner",
#' A complete list of candidate feature types, grouped by task type, is stored in [`mlr_reflections$task_feature_types`][mlr_reflections].
feature_types = NULL,

#' @field properties (`character()`)\cr
#' Stores a set of properties/capabilities the learner has.
#' A complete list of candidate properties, grouped by task type, is stored in [`mlr_reflections$learner_properties`][mlr_reflections].
properties = NULL,

#' @field data_formats (`character()`)\cr
#' Supported data format, e.g. `"data.table"` or `"Matrix"`.
data_formats = NULL,
Expand Down Expand Up @@ -240,21 +264,27 @@ Learner = R6Class("Learner",
}

train_row_ids = if (!is.null(row_ids)) row_ids else task$row_roles$use
test_row_ids = task$row_roles$test

learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)
train_result = learner_train(learner, task, train_row_ids = train_row_ids, mode = mode)

# store data prototype
proto = task$data(rows = integer())
self$state$data_prototype = proto
self$state$task_prototype = proto

# store the task w/o the data
# In the case where the validation task was specified manually we are duplicating some information here
# but this is in the interest of consistency
if (!is.null(self$param_set$values$validate)) {
self$state = insert_named(self$state, list(
inner_valid_task_ids = train_result$inner_valid_task_ids,
inner_valid_task_hash = train_result$inner_valid_task_hash
))
}

self$state$train_task = task_rm_backend(task$clone(deep = TRUE))

invisible(self)
},

#' @description
#' Uses the information stored during `$train()` in `$state` to create a new [Prediction]
#' for a set of observations of the provided `task`.
Expand Down Expand Up @@ -388,7 +418,6 @@ Learner = R6Class("Learner",
self$state$model
},


#' @field timings (named `numeric(2)`)\cr
#' Elapsed time in seconds for the steps `"train"` and `"predict"`.
#' Measured via [mlr3misc::encapsulate()].
Expand Down Expand Up @@ -518,6 +547,7 @@ Learner = R6Class("Learner",
),

private = list(
.properties = NULL,
.encapsulate = NULL,
.fallback = NULL,
.predict_type = NULL,
Expand Down
105 changes: 91 additions & 14 deletions R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
#' \item{save_tasks:}{Saves input task in `model` slot during training and prediction.}
#' \item{threads:}{Number of threads to use. Has no effect.}
#' \item{x:}{Numeric tuning parameter. Has no effect.}
#' \item{validate:}{Whether to evaluate the response on the validation set. This parameter can be either `NULL`,
#' a ratio, `"test"`, or `"inner_valid_task"`.}
#' \item{response:}{The constant response value to predict. Can be a numeric value, `"tune"`
#' (the response is then tuned internally using the validation data), or `NULL` (the mean of the
#' training targets is used).}
#' }
#'
#' @templateVar id regr.debug
Expand All @@ -21,7 +24,7 @@
#' @template seealso_learner
#' @export
#' @examples
#' task = tsk("mtcars")
#' task = tsk("mtcars")
#' learner = lrn("regr.debug", save_tasks = TRUE)
#' learner$train(task, row_ids = 1:20)
#' prediction = learner$predict(task, row_ids = 21:32)
Expand All @@ -33,6 +36,13 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
check_response = crate(function(x) {
if (is.null(x)) return(TRUE)
if (isTRUE(all.equal(x, "tune"))) return(TRUE)
if (test_numeric(x, len = 1L, any.missing = FALSE)) return(TRUE)

"Must either be 'tune', a numeric value, or NULL."
})
super$initialize(
id = "regr.debug",
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
Expand All @@ -42,13 +52,37 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
save_tasks = p_lgl(default = FALSE, tags = c("train", "predict")),
threads = p_int(1L, tags = c("train", "threads")),
x = p_dbl(0, 1, tags = "train")
x = p_dbl(0, 1, tags = "train"),
validate = p_uty(default = NULL, tags = "train", custom_check = check_validate),
response = p_uty(default = NULL, tags = c("train", "inner_tuning"), custom_check = check_response)
),
properties = "missings",
properties = c("missings", "validation", "inner_tuning"),
man = "mlr3::mlr_learners_regr.debug",
data_formats = c("data.table", "Matrix"),
label = "Debug Learner for Regression"
)
},
#' @description
#' Returns the validation score(s) that were computed on the inner validation
#' task during training.
#' @return named `list()`
inner_valid_scores = function() {
  model = self$model
  if (is.null(model)) {
    stopf("No model trained yet.")
  }
  scores = model$inner_valid_scores
  if (is.null(scores)) {
    stopf("No inner validation.")
  }
  scores
},
#' @description
#' Returns the hyperparameter values that were tuned internally during
#' training, as stored in the fitted model.
#' @return named `list()`
inner_tuning_values = function() {
  model = self$model
  if (is.null(model)) {
    stopf("No model trained yet.")
  }
  model$inner_tuning_values
}
),

Expand All @@ -57,30 +91,42 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
pv = self$param_set$get_values(tags = "train")
truth = task$truth()
model = list(
response = mean(truth),
se = sd(truth),
pid = Sys.getpid()
)

valid_truth = if (!is.null(pv$validate)) task$inner_valid_task$truth() else NULL
if (isTRUE(all.equal(pv$response, "tune"))) {
if (is.null(task$inner_valid_task)) {
stopf("Can only tune if a validation set is present.")
}
model$response = mean(c(truth, valid_truth))
} else {
model$response = mean(truth)
}

if (isTRUE(pv$save_tasks)) {
model$task_train = task$clone(deep = TRUE)
}
set_class(model, "regr.debug_model")
},

.predict = function(task) {
n = task$nrow
pv = self$param_set$get_values(tags = "predict")

if (isTRUE(pv$save_tasks)) {
self$state$model$task_predict = task$clone(deep = TRUE)
if (!is.null(pv$validate)) {
pred = private$.make_prediction(task$inner_valid_task, model, self$param_set$get_values(tags = "predict"))
model$inner_valid_scores = list(
mse = mean((pred$response - valid_truth)^2),
mae = mean(abs(pred$response - valid_truth))
)
}

set_class(model, "regr.debug_model")
},

.make_prediction = function(task, model, pv) {
prediction = named_list(mlr_reflections$learner_predict_types[["regr"]][[self$predict_type]])
missing_type = pv$predict_missing_type %??% "na"
n = task$nrow

for (pt in names(prediction)) {
value = rep.int(self$model[[pt]], n)
value = rep.int(model[[pt]], n)
if (!is.null(pv$predict_missing)) {
ii = sample.int(n, n * pv$predict_missing)
value = switch(missing_type,
Expand All @@ -92,11 +138,42 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
prediction[[pt]] = value
}


return(prediction)
},

# Predict on `task` by delegating to the shared .make_prediction() helper.
# Optionally stores a deep copy of the prediction task in the model state
# when the `save_tasks` flag is set.
.predict = function(task) {
  pv = self$param_set$get_values(tags = "predict")

  save_tasks = isTRUE(pv$save_tasks)
  if (save_tasks) {
    self$state$model$task_predict = task$clone(deep = TRUE)
  }

  private$.make_prediction(task, self$model, pv)
}
)
)

#' @export
set_inner_tuning.LearnerRegrDebug = function(learner, disable = FALSE, validate = NULL, response = NULL, ...) {
  # Merge the requested settings into the learner's current hyperparameter values.
  pv = insert_named(learner$param_set$values, list(response = response, validate = validate))

  if (disable) {
    # Inner tuning is active iff `response` is "tune", so it must not be "tune" here.
    if (isTRUE(all.equal(pv$response, "tune"))) {
      stopf("Parameter 'response' of learner %s must not be 'tune' to disable inner tuning.", learner$id)
    }
  } else {
    # Enabling inner tuning requires a validation set to tune on ...
    if (is.null(pv$validate)) {
      stopf("Parameter 'validate' must be provided to enable inner tuning.")
    }
    # ... and `response` set to "tune" so the learner actually tunes it.
    # (The original condition was inverted, which made enabling impossible.)
    if (!isTRUE(all.equal(pv$response, "tune"))) {
      stopf("Parameter 'response' of learner %s must be 'tune' to enable inner tuning.", learner$id)
    }
  }
  learner$param_set$values = pv
  invisible(learner)
}


#' @include mlr_learners.R
mlr_learners$add("regr.debug", function() LearnerRegrDebug$new())
Loading