Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Learner): support marshal property #993

Merged
merged 49 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8fb3818
feat(Learner): support bundling property
sebffischer Jan 25, 2024
851e64c
fix(bundle): always (un)bundle for callr encapsulation
sebffischer Jan 25, 2024
1bcc0c7
fix: bundle cannot be present twice in properties
sebffischer Jan 26, 2024
bfb568f
typo
sebffischer Jan 30, 2024
606a5a9
refactor bundling
sebffischer Jan 31, 2024
ea651e2
bundle property must be manually set
sebffischer Jan 31, 2024
7d767de
fix tests
sebffischer Jan 31, 2024
291d450
better docs
sebffischer Jan 31, 2024
ff31f66
fix one more test
sebffischer Jan 31, 2024
2152b75
really fix test
sebffischer Jan 31, 2024
9bd2efe
public methods
sebffischer Jan 31, 2024
cfbab13
refactor
sebffischer Feb 1, 2024
f98424b
Update R/Measure.R
sebffischer Feb 1, 2024
50919d2
Update man-roxygen/param_learner_properties.R
sebffischer Feb 14, 2024
77af5ee
better marshal behavior
sebffischer Feb 20, 2024
c9c9c6a
Update R/Measure.R
sebffischer Feb 21, 2024
86e9163
Update R/Measure.R
sebffischer Feb 21, 2024
e2b1b54
docs
sebffischer Feb 21, 2024
c73f892
better approach
sebffischer Feb 22, 2024
e0c53ea
docs
sebffischer Feb 22, 2024
ffffc30
add clone argument and optimize worker
sebffischer Feb 22, 2024
432971f
optimization
sebffischer Feb 22, 2024
f9b33ea
add marshal property to regr.debug and remove lily
sebffischer Feb 26, 2024
d6ceb1c
inplace
sebffischer Mar 4, 2024
cdf603f
some more fixes
sebffischer Mar 5, 2024
ea2d75c
rename
sebffischer Apr 9, 2024
afbb6b5
...
sebffischer Apr 9, 2024
26c6c81
marshal is property of classif.debug
sebffischer Apr 9, 2024
2f5e685
fix printer and autotest
sebffischer Apr 9, 2024
f8943e7
Merge branch 'main' into bundle
sebffischer Apr 9, 2024
51e5c5f
...
sebffischer Apr 9, 2024
ad8cfb5
typo
sebffischer Apr 9, 2024
9581425
docs
sebffischer Apr 9, 2024
72733a9
refactor
sebffischer Apr 10, 2024
dfc345e
inplace marshal for ResultData
sebffischer Apr 10, 2024
c979ba8
fix class of marshaled classif debug
sebffischer Apr 17, 2024
7f637c8
add class to learner state for marshaling
sebffischer Apr 17, 2024
688f34c
...
sebffischer Apr 17, 2024
3a89b38
Merge branch 'main' into bundle
sebffischer Apr 17, 2024
626a3df
...
sebffischer Apr 17, 2024
5ada07c
typo
sebffischer Apr 17, 2024
3d64b1a
cleanup marshaling
sebffischer Apr 22, 2024
af67642
more cleanup
sebffischer Apr 22, 2024
0e82181
add test
sebffischer Apr 22, 2024
6e87384
more cleanup
sebffischer Apr 22, 2024
ef59a0b
better docs
sebffischer Apr 22, 2024
3707f8c
fix tests for at and glrn
sebffischer Apr 22, 2024
cbf2700
skip some tests until new versions are released
sebffischer Apr 22, 2024
c497a7d
fix test helpers
sebffischer Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -183,6 +183,7 @@ Collate:
'helper_hashes.R'
'helper_print.R'
'install_pkgs.R'
'marshal.R'
'mlr_sugar.R'
'mlr_test_helpers.R'
'partition.R'
Expand Down
13 changes: 13 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,9 @@ S3method(fix_factor_levels,data.table)
S3method(head,Task)
S3method(is_missing_prediction_data,PredictionDataClassif)
S3method(is_missing_prediction_data,PredictionDataRegr)
S3method(marshal_model,classif.debug_model)
S3method(marshal_model,default)
S3method(marshal_model,learner_state)
S3method(partition,Task)
S3method(partition,TaskClassif)
S3method(partition,TaskRegr)
Expand All @@ -95,6 +98,7 @@ S3method(print,PredictionData)
S3method(print,benchmark_grid)
S3method(print,bmr_aggregate)
S3method(print,bmr_score)
S3method(print,marshaled)
S3method(print,rr_score)
S3method(rd_info,Learner)
S3method(rd_info,Measure)
Expand All @@ -104,6 +108,9 @@ S3method(set_threads,default)
S3method(set_threads,list)
S3method(summary,Task)
S3method(tail,Task)
S3method(unmarshal_model,classif.debug_model_marshaled)
S3method(unmarshal_model,default)
S3method(unmarshal_model,learner_state_marshaled)
export(BenchmarkResult)
export(DataBackend)
export(DataBackendDataTable)
Expand Down Expand Up @@ -207,9 +214,14 @@ export(default_measures)
export(extract_pkgs)
export(filter_prediction_data)
export(install_pkgs)
export(is_marshaled_model)
export(is_missing_prediction_data)
export(learner_marshal)
export(learner_marshaled)
export(learner_unmarshal)
export(lrn)
export(lrns)
export(marshal_model)
export(mlr_learners)
export(mlr_measures)
export(mlr_reflections)
Expand All @@ -227,6 +239,7 @@ export(tgen)
export(tgens)
export(tsk)
export(tsks)
export(unmarshal_model)
import(checkmate)
import(data.table)
import(mlr3misc)
Expand Down
3 changes: 3 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,8 @@
# mlr3 (development version)

* Feat: added support for `"marshal"` property, which allows learners to process
models so they can be serialized. This happens automatically during `resample()`
and `benchmark()`.
* Log encapsulated errors and warnings with the `lgr` package.

# mlr3 0.18.0
Expand Down
14 changes: 14 additions & 0 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,20 @@ BenchmarkResult = R6Class("BenchmarkResult",
invisible(self)
},

#' @description
#' Marshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
private$.data$marshal(...)
},
#' @description
#' Unmarshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
private$.data$unmarshal(...)
},

#' @description
#' Returns a table with one row for each resampling iteration, including
Expand Down
11 changes: 7 additions & 4 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ HotstartStack = R6Class("HotstartStack",
add = function(learners) {
learners = assert_learners(as_learners(learners))

# check for models
if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}
walk(learners, function(learner) {
if (is.null(learner$model)) {
stopf("Learners must be trained before adding them to the hotstart stack.")
} else if (is_marshaled_model(learner$model)) {
stopf("Learners must be unmarshaled before adding them to the hotstart stack.")
}
})

if (!is.null(self$hotstart_threshold)) {
learners = keep(learners, function(learner) {
Expand Down
33 changes: 30 additions & 3 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Learner = R6Class("Learner",
#' @param ... (ignored).
print = function(...) {
catn(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else class(self$model)[1L]))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else if (is_marshaled_model(self$model)) "<marshaled>" else paste0(class(self$model)[1L])))
catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L)))
catn(str_indent("* Packages:", self$packages))
catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))))
Expand Down Expand Up @@ -243,6 +243,7 @@ Learner = R6Class("Learner",
test_row_ids = task$row_roles$test

learner_train(learner, task, train_row_ids = train_row_ids, test_row_ids = test_row_ids, mode = mode)
self$model = unmarshal_model(model = self$state$model, inplace = TRUE)

# store data prototype
proto = task$data(rows = integer())
Expand Down Expand Up @@ -279,6 +280,10 @@ Learner = R6Class("Learner",
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
}

if (is_marshaled_model(self$model)) {
stopf("Cannot predict, Learner '%s' has not been unmarshaled yet", self$id)
}

if (isTRUE(self$parallel_predict) && nbrOfWorkers() > 1L) {
row_ids = row_ids %??% task$row_ids
chunked = chunk_vector(row_ids, n_chunks = nbrOfWorkers(), shuffle = FALSE)
Expand Down Expand Up @@ -388,7 +393,6 @@ Learner = R6Class("Learner",
self$state$model
},


#' @field timings (named `numeric(2)`)\cr
#' Elapsed time in seconds for the steps `"train"` and `"predict"`.
#' Measured via [mlr3misc::encapsulate()].
Expand Down Expand Up @@ -541,7 +545,6 @@ Learner = R6Class("Learner",
)
)


#' @export
rd_info.Learner = function(obj, ...) {
x = c("",
Expand Down Expand Up @@ -576,3 +579,27 @@ default_values.Learner = function(x, search_space, task, ...) { # nolint
# format_list_item.Learner = function(x, ...) { # nolint
# sprintf("<lrn:%s>", x$id)
# }


#' @export
marshal_model.learner_state = function(model, inplace = FALSE, ...) {
if (is.null(model$model)) {
return(model)
}
mm = marshal_model(model$model, inplace = inplace, ...)
if (!is_marshaled_model(mm)) {
return(model)
}
model$model = mm
structure(list(
marshaled = model,
packages = "mlr3"
), class = c("learner_state_marshaled", "list_marshaled", "marshaled"))
}

#' @export
unmarshal_model.learner_state_marshaled = function(model, inplace = FALSE, ...) {
mm = model$marshaled
mm$model = unmarshal_model(mm$model, inplace = inplace, ...)
return(mm)
}
86 changes: 66 additions & 20 deletions R/LearnerClassifDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#' \item{warning_train:}{Probability to signal a warning during train.}
#' \item{x:}{Numeric tuning parameter. Has no effect.}
#' \item{iter:}{Integer parameter for testing hotstarting.}
#' \item{count_marshaling:}{If `TRUE`, `marshal_model` will increase the `marshal_count` by 1 each time it is called. The default is `FALSE`.}
#' }
#' Note that segfaults may not be triggered reliably on your operating system.
#' Also note that if they work as intended, they will tear down your R session immediately!
Expand All @@ -49,39 +50,62 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
param_set = ps(
error_predict = p_dbl(0, 1, default = 0, tags = "predict"),
error_train = p_dbl(0, 1, default = 0, tags = "train"),
message_predict = p_dbl(0, 1, default = 0, tags = "predict"),
message_train = p_dbl(0, 1, default = 0, tags = "train"),
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
save_tasks = p_lgl(default = FALSE, tags = c("train", "predict")),
segfault_predict = p_dbl(0, 1, default = 0, tags = "predict"),
segfault_train = p_dbl(0, 1, default = 0, tags = "train"),
sleep_train = p_uty(tags = "train"),
sleep_predict = p_uty(tags = "predict"),
threads = p_int(1L, tags = c("train", "threads")),
warning_predict = p_dbl(0, 1, default = 0, tags = "predict"),
warning_train = p_dbl(0, 1, default = 0, tags = "train"),
x = p_dbl(0, 1, tags = "train"),
iter = p_int(1, default = 1, tags = c("train", "hotstart")),
count_marshaling = p_lgl(default = FALSE, tags = "train")
)
super$initialize(
id = "classif.debug",
param_set = param_set,
feature_types = c("logical", "integer", "numeric", "character", "factor", "ordered"),
predict_types = c("response", "prob"),
param_set = ps(
error_predict = p_dbl(0, 1, default = 0, tags = "predict"),
error_train = p_dbl(0, 1, default = 0, tags = "train"),
message_predict = p_dbl(0, 1, default = 0, tags = "predict"),
message_train = p_dbl(0, 1, default = 0, tags = "train"),
predict_missing = p_dbl(0, 1, default = 0, tags = "predict"),
predict_missing_type = p_fct(c("na", "omit"), default = "na", tags = "predict"),
save_tasks = p_lgl(default = FALSE, tags = c("train", "predict")),
segfault_predict = p_dbl(0, 1, default = 0, tags = "predict"),
segfault_train = p_dbl(0, 1, default = 0, tags = "train"),
sleep_train = p_uty(tags = "train"),
sleep_predict = p_uty(tags = "predict"),
threads = p_int(1L, tags = c("train", "threads")),
warning_predict = p_dbl(0, 1, default = 0, tags = "predict"),
warning_train = p_dbl(0, 1, default = 0, tags = "train"),
x = p_dbl(0, 1, tags = "train"),
iter = p_int(1, default = 1, tags = c("train", "hotstart"))
),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward"),
properties = c("twoclass", "multiclass", "missings", "hotstart_forward", "marshal"),
man = "mlr3::mlr_learners_classif.debug",
data_formats = c("data.table", "Matrix"),
label = "Debug Learner for Classification"
)
},
#' @description
#' Marshal the learner's model.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
learner_marshal(.learner = self, ...)
},
#' @description
#' Unmarshal the learner's model.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_unmarshal(.learner = self, ...)
}
),
active = list(
#' @field marshaled (logical(1))\cr
#' Whether the learner has been marshaled.
marshaled = function() {
learner_marshaled(self)
}
),

private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
pv$count_marshaling = pv$count_marshaling %??% FALSE
roll = function(name) {
name %in% names(pv) && pv[[name]] > runif(1L)
}
Expand Down Expand Up @@ -110,6 +134,10 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,
model$task_train = task$clone(deep = TRUE)
}

if (isTRUE(pv$count_marshaling)) {
model$marshal_count = 0L
}

set_class(model, "classif.debug_model")
},

Expand Down Expand Up @@ -193,3 +221,21 @@ LearnerClassifDebug = R6Class("LearnerClassifDebug", inherit = LearnerClassif,

#' @include mlr_learners.R
mlr_learners$add("classif.debug", function() LearnerClassifDebug$new())

#' @export
#' @method marshal_model classif.debug_model
marshal_model.classif.debug_model = function(model, inplace = FALSE, ...) {
if (!is.null(model$marshal_count)) {
model$marshal_count = model$marshal_count + 1
}
structure(list(
marshaled = model, packages = "mlr3"),
class = c("classif.debug_model_marshaled", "marshaled")
)
}

#' @export
#' @method unmarshal_model classif.debug_model_marshaled
unmarshal_model.classif.debug_model_marshaled = function(model, inplace = FALSE, ...) {
model$marshaled
}
1 change: 0 additions & 1 deletion R/LearnerRegrDebug.R
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@ LearnerRegrDebug = R6Class("LearnerRegrDebug", inherit = LearnerRegr,
)
}
),

private = list(
.train = function(task) {
pv = self$param_set$get_values(tags = "train")
Expand Down
4 changes: 4 additions & 0 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ Measure = R6Class("Measure",
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction)


if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}
Expand All @@ -184,6 +185,9 @@ Measure = R6Class("Measure",
if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && is_marshaled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is in marshaled form", self$id)
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
Expand Down
15 changes: 15 additions & 0 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,21 @@ ResampleResult = R6Class("ResampleResult",
#' the object in its previous state.
discard = function(backends = FALSE, models = FALSE) {
private$.data$discard(backends = backends, models = models)
},

#' @description
#' Marshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
private$.data$marshal(...)
},
#' @description
#' Unmarshals all stored models.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
private$.data$unmarshal(...)
}
),

Expand Down
21 changes: 21 additions & 0 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,27 @@ ResultData = R6Class("ResultData",
invisible(self)
},

#' @description
#' Marshals all stored learner models.
#' This will do nothing to models that are already marshaled.
#' @param ... (any)\cr
#' Additional arguments passed to [`marshal_model()`].
marshal = function(...) {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, function(x) marshal_state_if_model(.state = x, inplace = TRUE, ...))]
invisible(self)
},
#' @description
#' Unmarshals all stored learner models.
#' This will do nothing to models which are not marshaled.
#' @param ... (any)\cr
#' Additional arguments passed to [`unmarshal_model()`].
unmarshal = function(...) {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, function(x) unmarshal_state_if_model(.state = x, inplace = TRUE, ...))]
invisible(self)
},

#' @description
#' Shrinks the object by discarding parts of the stored data.
#'
Expand Down
Loading
Loading