Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

support marshaling in GraphLearner #759

Merged
merged 22 commits into from
Jun 30, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -93,7 +93,7 @@ Config/testthat/edition: 3
Config/testthat/parallel: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.2.3
RoxygenNote: 7.2.3.9000
VignetteBuilder: knitr
Collate:
'Graph.R'
Expand Down
3 changes: 3 additions & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ S3method(as_pipeop,Filter)
S3method(as_pipeop,Learner)
S3method(as_pipeop,PipeOp)
S3method(as_pipeop,default)
S3method(marshal_model,graph_learner_model)
S3method(po,"NULL")
S3method(po,Filter)
S3method(po,Learner)
Expand All @@ -23,6 +24,7 @@ S3method(pos,list)
S3method(predict,Graph)
S3method(print,Multiplicity)
S3method(print,Selector)
S3method(unmarshal_model,graph_learner_model_marshalled)
export("%>>!%")
export("%>>%")
export(Graph)
Expand Down Expand Up @@ -148,6 +150,7 @@ import(mlr3)
import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(data.table,as.data.table)
importFrom(digest,digest)
importFrom(stats,setNames)
importFrom(utils,bibentry)
Expand Down
6 changes: 6 additions & 0 deletions R/Graph.R
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,8 @@
#' Whether to store intermediate results in the [`PipeOp`]'s `$.result` slot, mostly for debugging purposes. Default `FALSE`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties of the `Graph`, i.e. the union of the properties of all its [`PipeOp`]s.
#'
#' @section Methods:
#' * `ids(sorted = FALSE)` \cr
Expand Down Expand Up @@ -504,6 +506,10 @@ Graph = R6Class("Graph",
} else {
map(self$pipeops, "state")
}
},
properties = function(rhs) {
  # Read-only field: assert_ro_binding() guards against assignment.
  assert_ro_binding(rhs)
  # Gather the `properties` of every PipeOp in the graph into one flat vector,
  # then reduce it to the sorted set of distinct values (the union).
  all_props = unlist(map(self$pipeops, "properties"))
  sort(unique(all_props))
}
),

Expand Down
32 changes: 31 additions & 1 deletion R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -98,11 +98,17 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
assert_subset(task_type, mlr_reflections$task_types$type)

properties = mlr_reflections$learner_properties[[task_type]]

if ("marshal" %nin% graph$properties) {
  # Only advertise "marshal" when the wrapped graph reports it. The result
  # must be assigned back to `properties` (the variable passed on to
  # super$initialize()); assigning to a differently named variable would
  # silently leave "marshal" in place.
  properties = setdiff(properties, "marshal")
}

super$initialize(id = id, task_type = task_type,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = mlr_reflections$learner_properties[[task_type]],
properties = properties,
man = "mlr3pipelines::GraphLearner"
)

Expand Down Expand Up @@ -132,9 +138,18 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
if (length(last_pipeop_id) == 0) stop("No Learner PipeOp found.")
}
learner_model$base_learner(recursive - 1)
},
marshal = function() {
# Delegates to learner_marshal(), passing this learner.
# NOTE(review): presumably marshals the stored model via marshal_model() -- confirm against mlr3.
learner_marshal(self)
},
unmarshal = function() {
# Delegates to learner_unmarshal(), passing this learner.
# NOTE(review): presumably restores the stored model via unmarshal_model() -- confirm against mlr3.
learner_unmarshal(self)
}
),
active = list(
marshalled = function(rhs) {
  # Read-only field: validate `rhs` like the other active bindings that take
  # `rhs`, so that an assignment is checked by assert_ro_binding() instead of
  # being silently ignored.
  assert_ro_binding(rhs)
  # Delegates to learner_marshalled(), passing this learner.
  learner_marshalled(self)
},
hash = function() {
digest(list(class(self), self$id, self$graph$hash, private$.predict_type,
self$fallback$hash, self$parallel_predict), algo = "xxhash64")
Expand Down Expand Up @@ -189,6 +204,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
on.exit({self$graph$state = NULL})
self$graph$train(task)
state = self$graph$state
class(state) = c("graph_learner_model", class(state))
state
},
.predict = function(task) {
Expand Down Expand Up @@ -233,6 +249,20 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
)
)

#' @export
marshal_model.graph_learner_model = function(model, ...) {
  # Marshal every element of the graph learner's model individually, then tag
  # the resulting list with the marshalled classes so that unmarshal_model()
  # dispatches to the graph_learner_model_marshalled method.
  structure(
    map(model, marshal_model),
    class = c("graph_learner_model_marshalled", "list_marshalled", "marshalled")
  )
}

#' @export
unmarshal_model.graph_learner_model_marshalled = function(model, ...) {
  # Inverse of marshal_model.graph_learner_model(): every element has to be
  # *un*marshalled here. Calling marshal_model() again would leave the
  # elements in their marshalled form.
  model = map(model, unmarshal_model)
  # Restore the class vector that GraphLearner's .train() gives to the state.
  class(model) = c("graph_learner_model", "list")
  model
}

#' @export
as_learner.Graph = function(x, clone = FALSE, ...) {
GraphLearner$new(x, clone_graph = clone)
Expand Down
10 changes: 9 additions & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -135,6 +135,8 @@
#' and done, if requested, by the [`Graph`] backend itself; it should *not* be done explicitly by `private$.train()` or `private$.predict()`.
#' * `man` :: `character(1)`\cr
#' Identifying string of the help page that shows with `help()`.
#' * `properties` :: `character()`\cr
#' The properties that this PipeOp has. See `mlr_reflections$pipeops$properties` for available values.
sebffischer marked this conversation as resolved.
Show resolved Hide resolved
#'
#' @section Methods:
#' * `train(input)`\cr
Expand Down Expand Up @@ -236,7 +238,7 @@ PipeOp = R6Class("PipeOp",
.result = NULL,
tags = NULL,

initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
Expand All @@ -246,6 +248,7 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

private$.properties = sort(assert_subset(properties, mlr_reflections$pipeops$properties))
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
Expand Down Expand Up @@ -411,6 +414,10 @@ PipeOp = R6Class("PipeOp",
}
}
private$.label
},
properties = function(rhs) {
# Read-only field: assert_ro_binding() guards against assignment.
assert_ro_binding(rhs)
# Return the value stored in private$.properties.
private$.properties
}
),

Expand All @@ -429,6 +436,7 @@ PipeOp = R6Class("PipeOp",
}
value
},
.properties = NULL,
.train = function(input) stop("abstract"),
.predict = function(input) stop("abstract"),
.additional_phash_input = function() {
Expand Down
3 changes: 2 additions & 1 deletion R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -89,10 +89,11 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task
out_type = mlr_reflections$task_types[type, mult = "first"]$prediction
properties = if ("marshal" %in% learner$properties) "marshal" else character(0)
super$initialize(id, param_set = alist(private$.learner$param_set), param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = "NULL", predict = out_type),
tags = "learner", packages = learner$packages
tags = "learner", packages = learner$packages, properties = properties
)
}
),
Expand Down
4 changes: 3 additions & 1 deletion R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,8 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
type = private$.learner$task_type
task_type = mlr_reflections$task_types[type, mult = "first"]$task

properties =if ("marshal" %in% learner$properties) "marshal" else character(0)

private$.crossval_param_set = ParamSet$new(params = list(
ParamFct$new("method", levels = c("cv", "insample"), tags = c("train", "required")),
ParamInt$new("folds", lower = 2L, upper = Inf, tags = c("train", "required")),
Expand All @@ -137,7 +139,7 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
# in PipeOp ParamSets.
# private$.crossval_param_set$add_dep("folds", "method", CondEqual$new("cv")) # don't do this.

super$initialize(id, alist(private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"))
super$initialize(id, alist(private$.crossval_param_set, private$.learner$param_set), param_vals = param_vals, can_subset_cols = TRUE, task_type = task_type, tags = c("learner", "ensemble"), properties = properties)
}

),
Expand Down
5 changes: 3 additions & 2 deletions R/PipeOpTaskPreproc.R
Original file line number Diff line number Diff line change
Expand Up @@ -169,7 +169,7 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",

public = list(
initialize = function(id, param_set = ParamSet$new(), param_vals = list(), can_subset_cols = TRUE,
packages = character(0), task_type = "Task", tags = NULL, feature_types = mlr_reflections$task_feature_types) {
packages = character(0), task_type = "Task", tags = NULL, feature_types = mlr_reflections$task_feature_types, properties = character(0)) {
if (can_subset_cols) {
acp = ParamUty$new("affect_columns", custom_check = check_function_or_null, default = selector_all(), tags = "train")
if (inherits(param_set, "ParamSet")) {
Expand All @@ -183,7 +183,8 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
super$initialize(id = id, param_set = param_set, param_vals = param_vals,
input = data.table(name = "input", train = task_type, predict = task_type),
output = data.table(name = "output", train = task_type, predict = task_type),
packages = packages, tags = c(tags, "data transform")
packages = packages, tags = c(tags, "data transform"),
properties = properties
)
}
),
Expand Down
2 changes: 2 additions & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,8 @@ register_mlr3 = function() {
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
"data transform", "target transform", "ensemble", "robustify", "learner", "encode",
"multiplicity")))

x$pipeops$properties = "marshal"
}

.onLoad = function(libname, pkgname) { # nocov start
Expand Down
2 changes: 2 additions & 0 deletions man/Graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

2 changes: 2 additions & 0 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

11 changes: 11 additions & 0 deletions tests/testthat/test_GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -569,4 +569,15 @@ test_that("GraphLearner hashes", {

})

# The "marshal" property of a Graph / GraphLearner should follow the wrapped learner.
test_that("marshal", {
task = tsk("iris")
# NOTE(review): presumably "classif.lily" is a learner that has the "marshal" property -- confirm
po_lily = as_pipeop(lrn("classif.lily"))
graph = as_graph(po_lily)
glrn = as_learner(graph)
expect_true("marshal" %in% glrn$properties)

# also checks that it is marshallable
expect_learner(glrn, task)

# a graph built from "regr.featureless" is expected not to report "marshal"
expect_false("marshal" %in% as_graph(lrn("regr.featureless"))$properties)
})
Loading