Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat(Learner): support marshal property #993

Merged
merged 49 commits into from
Apr 23, 2024
Merged
Show file tree
Hide file tree
Changes from 12 commits
Commits
Show all changes
49 commits
Select commit Hold shift + click to select a range
8fb3818
feat(Learner): support bundling property
sebffischer Jan 25, 2024
851e64c
fix(bundle): always (un)bundle for callr encapsulation
sebffischer Jan 25, 2024
1bcc0c7
fix: bundle cannot be present twice in properties
sebffischer Jan 26, 2024
bfb568f
typo
sebffischer Jan 30, 2024
606a5a9
refactor bundling
sebffischer Jan 31, 2024
ea651e2
bundle property must be manually set
sebffischer Jan 31, 2024
7d767de
fix tests
sebffischer Jan 31, 2024
291d450
better docs
sebffischer Jan 31, 2024
ff31f66
fix one more test
sebffischer Jan 31, 2024
2152b75
really fix test
sebffischer Jan 31, 2024
9bd2efe
public methods
sebffischer Jan 31, 2024
cfbab13
refactor
sebffischer Feb 1, 2024
f98424b
Update R/Measure.R
sebffischer Feb 1, 2024
50919d2
Update man-roxygen/param_learner_properties.R
sebffischer Feb 14, 2024
77af5ee
better marshal behavior
sebffischer Feb 20, 2024
c9c9c6a
Update R/Measure.R
sebffischer Feb 21, 2024
86e9163
Update R/Measure.R
sebffischer Feb 21, 2024
e2b1b54
docs
sebffischer Feb 21, 2024
c73f892
better approach
sebffischer Feb 22, 2024
e0c53ea
docs
sebffischer Feb 22, 2024
ffffc30
add clone argument and optimize worker
sebffischer Feb 22, 2024
432971f
optimization
sebffischer Feb 22, 2024
f9b33ea
add marshal property to regr.debug and remove lily
sebffischer Feb 26, 2024
d6ceb1c
inplace
sebffischer Mar 4, 2024
cdf603f
some more fixes
sebffischer Mar 5, 2024
ea2d75c
rename
sebffischer Apr 9, 2024
afbb6b5
...
sebffischer Apr 9, 2024
26c6c81
marshal is property of classif.debug
sebffischer Apr 9, 2024
2f5e685
fix printer and autotest
sebffischer Apr 9, 2024
f8943e7
Merge branch 'main' into bundle
sebffischer Apr 9, 2024
51e5c5f
...
sebffischer Apr 9, 2024
ad8cfb5
typo
sebffischer Apr 9, 2024
9581425
docs
sebffischer Apr 9, 2024
72733a9
refactor
sebffischer Apr 10, 2024
dfc345e
inplace marshal for ResultData
sebffischer Apr 10, 2024
c979ba8
fix class of marshaled classif debug
sebffischer Apr 17, 2024
7f637c8
add class to learner state for marshaling
sebffischer Apr 17, 2024
688f34c
...
sebffischer Apr 17, 2024
3a89b38
Merge branch 'main' into bundle
sebffischer Apr 17, 2024
626a3df
...
sebffischer Apr 17, 2024
5ada07c
typo
sebffischer Apr 17, 2024
3d64b1a
cleanup marshaling
sebffischer Apr 22, 2024
af67642
more cleanup
sebffischer Apr 22, 2024
0e82181
add test
sebffischer Apr 22, 2024
6e87384
more cleanup
sebffischer Apr 22, 2024
ef59a0b
better docs
sebffischer Apr 22, 2024
3707f8c
fix tests for at and glrn
sebffischer Apr 22, 2024
cbf2700
skip some tests until new versions are released
sebffischer Apr 22, 2024
c497a7d
fix test helpers
sebffischer Apr 23, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -74,7 +74,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: false
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = TRUE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
Collate:
'mlr_reflections.R'
'BenchmarkResult.R'
Expand All @@ -90,6 +90,7 @@ Collate:
'mlr_learners.R'
'LearnerClassifDebug.R'
'LearnerClassifFeatureless.R'
'LearnerClassifLily.R'
'LearnerClassifRpart.R'
'LearnerRegr.R'
'LearnerRegrDebug.R'
Expand Down Expand Up @@ -183,6 +184,7 @@ Collate:
'helper_hashes.R'
'helper_print.R'
'install_pkgs.R'
'marshal.R'
'mlr_sugar.R'
'mlr_test_helpers.R'
'partition.R'
Expand Down
11 changes: 11 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -87,6 +87,8 @@ S3method(fix_factor_levels,data.table)
S3method(head,Task)
S3method(is_missing_prediction_data,PredictionDataClassif)
S3method(is_missing_prediction_data,PredictionDataRegr)
S3method(marshal_model,classif_lily_model)
S3method(marshal_model,default)
S3method(partition,Task)
S3method(partition,TaskClassif)
S3method(partition,TaskRegr)
Expand All @@ -104,6 +106,8 @@ S3method(set_threads,default)
S3method(set_threads,list)
S3method(summary,Task)
S3method(tail,Task)
S3method(unmarshal_model,classif_lily_model_marshalled)
S3method(unmarshal_model,default)
export(BenchmarkResult)
export(DataBackend)
export(DataBackendDataTable)
Expand All @@ -113,6 +117,7 @@ export(Learner)
export(LearnerClassif)
export(LearnerClassifDebug)
export(LearnerClassifFeatureless)
export(LearnerClassifLily)
export(LearnerClassifRpart)
export(LearnerRegr)
export(LearnerRegrDebug)
Expand Down Expand Up @@ -208,8 +213,13 @@ export(extract_pkgs)
export(filter_prediction_data)
export(install_pkgs)
export(is_missing_prediction_data)
export(learner_marshal)
export(learner_marshalled)
export(learner_unmarshal)
export(lrn)
export(lrns)
export(marshal_model)
export(marshalled_model)
export(mlr_learners)
export(mlr_measures)
export(mlr_reflections)
Expand All @@ -227,6 +237,7 @@ export(tgen)
export(tgens)
export(tsk)
export(tsks)
export(unmarshal_model)
import(checkmate)
import(data.table)
import(mlr3misc)
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
# mlr3 (development version)

* Feat: added support for `"marshal"` property, which allows learners to process
models so they can be serialized. This happens automatically during `resample()`
and `benchmark()`. The naming was inspired by the {marshal} package.

# mlr3 0.17.2

* Skip new `data.table` tests on mac.
Expand Down
10 changes: 10 additions & 0 deletions R/BenchmarkResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,16 @@ BenchmarkResult = R6Class("BenchmarkResult",
invisible(self)
},

#' @description
#' marshals all stored models.
marshal = function() {
private$.data$marshal()
},
#' @description
#' Unmarshals all stored models.
unmarshal = function() {
private$.data$unmarshal()
},

#' @description
#' Returns a table with one row for each resampling iteration, including
Expand Down
11 changes: 7 additions & 4 deletions R/HotstartStack.R
Original file line number Diff line number Diff line change
Expand Up @@ -85,10 +85,13 @@ HotstartStack = R6Class("HotstartStack",
add = function(learners) {
learners = assert_learners(as_learners(learners))

# check for models
if (any(map_lgl(learners, function(learner) is.null(learner$state$model)))) {
stopf("Learners must be trained before adding them to the hotstart stack.")
}
walk(learners, function(learner) {
if (is.null(learner$model)) {
stopf("Learners must be trained before adding them to the hotstart stack.")
} else if (marshalled_model(learner$model)) {
stopf("Learners must be unmarshalled before adding them to the hotstart stack.")
}
})

if (!is.null(self$hotstart_threshold)) {
learners = keep(learners, function(learner) {
Expand Down
8 changes: 5 additions & 3 deletions R/Learner.R
Original file line number Diff line number Diff line change
Expand Up @@ -184,7 +184,7 @@ Learner = R6Class("Learner",
#' @param ... (ignored).
print = function(...) {
catn(format(self), if (is.null(self$label) || is.na(self$label)) "" else paste0(": ", self$label))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else class(self$model)[1L]))
catn(str_indent("* Model:", if (is.null(self$model)) "-" else if (marshalled_model(self$model)) "<marshalled>" else paste0(class(self$model)[1L])))
catn(str_indent("* Parameters:", as_short_string(self$param_set$values, 1000L)))
catn(str_indent("* Packages:", self$packages))
catn(str_indent("* Predict Types: ", replace(self$predict_types, self$predict_types == self$predict_type, paste0("[", self$predict_type, "]"))))
Expand Down Expand Up @@ -279,6 +279,10 @@ Learner = R6Class("Learner",
stopf("Cannot predict, Learner '%s' has not been trained yet", self$id)
}

if (marshalled_model(self$model)) {
stopf("Cannot predict, Learner '%s' has not been unmarshalled yet", self$id)
}

if (isTRUE(self$parallel_predict) && nbrOfWorkers() > 1L) {
row_ids = row_ids %??% task$row_ids
chunked = chunk_vector(row_ids, n_chunks = nbrOfWorkers(), shuffle = FALSE)
Expand Down Expand Up @@ -388,7 +392,6 @@ Learner = R6Class("Learner",
self$state$model
},


#' @field timings (named `numeric(2)`)\cr
#' Elapsed time in seconds for the steps `"train"` and `"predict"`.
#' Measured via [mlr3misc::encapsulate()].
Expand Down Expand Up @@ -541,7 +544,6 @@ Learner = R6Class("Learner",
)
)


#' @export
rd_info.Learner = function(obj, ...) {
x = c("",
Expand Down
74 changes: 74 additions & 0 deletions R/LearnerClassifLily.R
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
#' @title Lily and Marshall
#'
#' @name mlr_learners_classif.lily
#' @include LearnerClassifDebug.R
#'
#' @description
#' This learner is just like [`LearnerClassifDebug`], but can be marshalled.
#' When the `count_marshalling` parameter is `TRUE`, the model contains a `marshal_count` that will be increased
#' by 1, each time `marshal_model` is called.
#'
#' @templateVar id classif.lily
#' @template learner
#'
#' @export
LearnerClassifLily = R6Class("LearnerClassifLily",
inherit = LearnerClassifDebug,
public = list(
#' @description
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function() {
super$initialize()
self$param_set$add(ps(count_marshalling = p_lgl(tags = c("train", "required"))))
self$param_set$values$count_marshalling = FALSE
self$properties = sort(c("marshal", self$properties))
self$man = "mlr3::mlr_learners_classif.lily"
self$label = "Lily Learner"
self$id = "classif.lily"
},
#' @description
#' Marshals the learner.
marshal = function() {
learner_marshal(self)
},
#' @description
#' Unmarshal the learner.
unmarshal = function() {
learner_unmarshal(self)
}
),
active = list(
sebffischer marked this conversation as resolved.
Show resolved Hide resolved
#' @field marshalled (logical(1))\cr
#' Whether the learner has been marshalled.
marshalled = function() {
learner_marshalled(self)
}
),
private = list(
.train = function(task) {
model = super$.train(task)
if (self$param_set$values$count_marshalling) {
model$marshal_count = 0L
}
class(model) = "classif_lily_model"
return(model)
}
)
)

#' @include mlr_learners.R
mlr_learners$add("classif.lily", function() LearnerClassifLily$new())

#' @export
marshal_model.classif_lily_model = function(model, ...) {
if (!is.null(model$marshal_count)) {
model$marshal_count = model$marshal_count + 1
}
newclass = c("classif_lily_model_marshalled", "marshalled")
structure(list(model), class = newclass)
}

#' @export
unmarshal_model.classif_lily_model_marshalled = function(model, ...) {
model[[1L]]
}
5 changes: 5 additions & 0 deletions R/Measure.R
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,8 @@ Measure = R6Class("Measure",
assert_measure(self, task = task, learner = learner)
assert_prediction(prediction)

# FIXME: if self has property model check that not marshalled
sebffischer marked this conversation as resolved.
Show resolved Hide resolved

if ("requires_task" %in% self$properties && is.null(task)) {
stopf("Measure '%s' requires a task", self$id)
}
Expand All @@ -184,6 +186,9 @@ Measure = R6Class("Measure",
if ("requires_model" %in% self$properties && (is.null(learner) || is.null(learner$model))) {
stopf("Measure '%s' requires the trained model", self$id)
}
if ("requires_model" %in% self$properties && marshalled_model(learner$model)) {
stopf("Measure '%s' requires the trained model, but model is un marshalled form", self$id)
sebffischer marked this conversation as resolved.
Show resolved Hide resolved
}

if ("requires_train_set" %in% self$properties && is.null(train_set)) {
stopf("Measure '%s' requires the train_set", self$id)
Expand Down
11 changes: 11 additions & 0 deletions R/ResampleResult.R
Original file line number Diff line number Diff line change
Expand Up @@ -222,6 +222,17 @@ ResampleResult = R6Class("ResampleResult",
#' the object in its previous state.
discard = function(backends = FALSE, models = FALSE) {
private$.data$discard(backends = backends, models = models)
},

#' @description
#' marshals all stored learner models.
marshal = function() {
private$.data$marshal()
},
#' @description
#' Unmarshals all stored learner models.
unmarshal = function() {
private$.data$unmarshal()
}
),

Expand Down
15 changes: 15 additions & 0 deletions R/ResultData.R
Original file line number Diff line number Diff line change
Expand Up @@ -242,6 +242,21 @@ ResultData = R6Class("ResultData",
invisible(self)
},

#' @description
#' Marshals all stored learner models.
marshal = function() {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, marshal_state)]
invisible(self)
},
#' @description
#' Unmarshals all stored learner models.
unmarshal = function() {
learner_state = NULL
self$data$fact[, learner_state := lapply(learner_state, unmarshal_state)]
invisible(self)
},

#' @description
#' Shrinks the object by discarding parts of the stored data.
#'
Expand Down
12 changes: 10 additions & 2 deletions R/benchmark.R
Original file line number Diff line number Diff line change
Expand Up @@ -14,6 +14,7 @@
#' @template param_encapsulate
#' @template param_allow_hotstart
#' @template param_clone
#' @template param_unmarshal
#'
#' @return [BenchmarkResult].
#'
Expand Down Expand Up @@ -77,7 +78,7 @@
#' ## Get the training set of the 2nd iteration of the featureless learner on penguins
#' rr = bmr$aggregate()[learner_id == "classif.featureless"]$resample_result[[1]]
#' rr$resampling$train_set(2)
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling")) {
benchmark = function(design, store_models = FALSE, store_backends = TRUE, encapsulate = NA_character_, allow_hotstart = FALSE, clone = c("task", "learner", "resampling"), unmarshal = TRUE) {
assert_subset(clone, c("task", "learner", "resampling"))
assert_data_frame(design, min.rows = 1L)
assert_names(names(design), must.include = c("task", "learner", "resampling"))
Expand Down Expand Up @@ -196,5 +197,12 @@ benchmark = function(design, store_models = FALSE, store_backends = TRUE, encaps
lg$info("Finished benchmark")

set(grid, j = "mode", value = NULL)
BenchmarkResult$new(ResultData$new(grid, store_backends = store_backends))

result_data = ResultData$new(grid, store_backends = store_backends)

if (unmarshal && store_models) {
result_data$unmarshal()
}

BenchmarkResult$new(result_data)
}
3 changes: 2 additions & 1 deletion R/helper_exec.R
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,7 @@ future_map = function(n, FUN, ..., MoreArgs = list()) {
future.apply::future_mapply(
FUN, ..., MoreArgs = MoreArgs, SIMPLIFY = FALSE, USE.NAMES = FALSE,
future.globals = FALSE, future.packages = "mlr3", future.seed = TRUE,
future.scheduling = scheduling, future.chunk.size = chunk_size, future.stdout = stdout)
future.scheduling = scheduling, future.chunk.size = chunk_size, future.stdout = stdout
)
}
}
Loading
Loading