Skip to content

Commit

Permalink
feat: preprocess test_task in graph
Browse files Browse the repository at this point in the history
  • Loading branch information
sebffischer committed Feb 16, 2024
1 parent 044762e commit 544ae50
Show file tree
Hide file tree
Showing 88 changed files with 564 additions and 455 deletions.
4 changes: 3 additions & 1 deletion DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -87,13 +87,15 @@ Suggests:
methods,
vtreat,
future
Remotes:
mlr-org/mlr3@feat/train-predict
ByteCompile: true
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
VignetteBuilder: knitr
Collate:
'Graph.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ import(mlr3)
import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(data.table,as.data.table)
importFrom(digest,digest)
importFrom(stats,setNames)
importFrom(utils,bibentry)
Expand Down
20 changes: 18 additions & 2 deletions R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = mlr_reflections$learner_properties[[task_type]],
properties = setdiff(mlr_reflections$learner_properties[[task_type]], "uses_test_task"),
man = "mlr3pipelines::GraphLearner"
)

Expand Down Expand Up @@ -173,6 +173,13 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
),
private = list(
# Properties of the GraphLearner that depend on the current content of its graph.
# Returns "uses_test_task" when at least one PipeOp in `self$graph$pipeops`
# lists that property, and an empty character vector otherwise.
.contingent_properties = function() {
  any_uses_test_task = some(self$graph$pipeops, function(po) {
    "uses_test_task" %in% po$properties
  })
  if (any_uses_test_task) "uses_test_task" else character(0)
},
.graph = NULL,
deep_clone = function(name, value) {
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
Expand All @@ -186,7 +193,16 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
},

.train = function(task) {
on.exit({self$graph$state = NULL})
if (!"uses_test_task" %in% self$properties) {
  # Nothing in the graph declares "uses_test_task": detach the test task for the
  # duration of training so it is not preprocessed unnecessarily.
  # The old value must be captured and cleared *before* the graph is trained;
  # only the restore is deferred to function exit. (Previously both the capture
  # and the restore were registered inside on.exit() and the test task was never
  # cleared, so the graph still saw it during training.)
  prev_test_task = task$test_task
  if (!is.null(prev_test_task)) {
    task$test_task = NULL
    on.exit({
      task$test_task = prev_test_task
    }, add = TRUE)
  }
}
on.exit({self$graph$state = NULL}, add = TRUE)
self$graph$train(task)
state = self$graph$state
state
Expand Down
17 changes: 16 additions & 1 deletion R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -236,7 +236,7 @@ PipeOp = R6Class("PipeOp",
.result = NULL,
tags = NULL,

initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract") {
initialize = function(id, param_set = ParamSet$new(), param_vals = list(), input, output, packages = character(0), tags = "abstract", properties = character(0)) {
if (inherits(param_set, "ParamSet")) {
private$.param_set = assert_param_set(param_set)
private$.param_set_source = NULL
Expand All @@ -246,6 +246,7 @@ PipeOp = R6Class("PipeOp",
}
self$id = assert_string(id)

private$.properties = sort(assert_subset(properties, mlr_reflections$pipeops$properties))
self$param_set$values = insert_named(self$param_set$values, param_vals)
self$input = assert_connection_table(input)
self$output = assert_connection_table(output)
Expand Down Expand Up @@ -335,6 +336,16 @@ PipeOp = R6Class("PipeOp",
),

active = list(
# Active binding for the PipeOp's properties.
#
# On assignment, `rhs` is validated against `mlr_reflections$pipeops$properties`
# and stored in sorted order in `private$.properties`.
# On read (and after assignment), returns the stored properties combined with
# whatever `private$.contingent_properties()` currently reports, sorted.
properties = function(rhs) {
  if (!missing(rhs)) {
    private$.properties = sort(assert_subset(rhs, mlr_reflections$pipeops$properties))
  }
  contingent = private$.contingent_properties()
  if (!length(contingent)) {
    # nothing to merge: hand back the stored value untouched
    return(private$.properties)
  }
  # Only append contingent entries that are not already declared statically, so a
  # property that is both stored and contingent is not reported twice.
  sort(c(private$.properties, setdiff(contingent, private$.properties)))
},
id = function(val) {
if (!missing(val)) {
private$.id = val
Expand Down Expand Up @@ -415,6 +426,10 @@ PipeOp = R6Class("PipeOp",
),

private = list(
# Default hook for properties that depend on the PipeOp's current state.
# This base implementation reports none; `rhs` is not used by the body.
.contingent_properties = function(rhs) character(0),
.properties = NULL,
deep_clone = function(name, value) {
if (!is.null(private$.param_set_source)) {
private$.param_set = NULL # required to keep clone identical to original, otherwise tests get really ugly
Expand Down
5 changes: 4 additions & 1 deletion R/PipeOpImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,7 +148,6 @@ PipeOpImpute = R6Class("PipeOpImpute",

.train = function(inputs) {
intask = inputs[[1]]$clone(deep = TRUE)

affected_cols = (self$param_set$values$affect_columns %??% selector_all())(intask)
affected_cols = intersect(affected_cols, private$.select_cols(intask))

Expand Down Expand Up @@ -191,6 +190,10 @@ PipeOpImpute = R6Class("PipeOpImpute",

self$state$outtasklayout = copy(intask$feature_types)

if (!is.null(intask$test_task)) {
intask$test_task = private$.predict(list(intask$test_task))[[1L]]
}

list(intask)
},

Expand Down
8 changes: 7 additions & 1 deletion R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,13 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
),
private = list(
.learner = NULL,

# Reports "uses_test_task" exactly when the wrapped learner (`private$.learner`)
# lists that property; otherwise returns an empty character vector.
.contingent_properties = function() {
  learner_uses_test_task = "uses_test_task" %in% private$.learner$properties
  if (!learner_uses_test_task) {
    return(character(0))
  }
  "uses_test_task"
},
.train = function(inputs) {
on.exit({private$.learner$state = NULL})
task = inputs[[1L]]
Expand Down
7 changes: 7 additions & 0 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -171,6 +171,13 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",
}
),
private = list(
# Mirrors the properties of the wrapped learner: yields "uses_test_task" when
# `private$.learner$properties` contains it, and character(0) when it does not.
.contingent_properties = function() {
  inherited = intersect("uses_test_task", private$.learner$properties)
  if (length(inherited)) "uses_test_task" else character(0)
},
.train_task = function(task) {
on.exit({private$.learner$state = NULL})

Expand Down
6 changes: 6 additions & 0 deletions R/PipeOpTaskPreproc.R
Original file line number Diff line number Diff line change
Expand Up @@ -217,6 +217,12 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
self$state$outtasklayout = copy(intask$feature_types)
self$state$outtaskshell = intask$data(rows = intask$row_ids[0])

if (!is.null(intask$test_task)) {
# we call into .predict() and not .predict_task() to not put the burden
# of subsetting the features etc. on the PipeOp overwriting .predict_task
intask$test_task = private$.predict(list(intask$test_task))[[1L]]
}

if (do_subset) {
# FIXME: this fails if .train_task added a column with the same name
intask$col_roles$feature = union(intask$col_roles$feature, y = remove_cols)
Expand Down
1 change: 1 addition & 0 deletions R/zzz.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ register_mlr3 = function() {
c("abstract", "meta", "missings", "feature selection", "imbalanced data",
"data transform", "target transform", "ensemble", "robustify", "learner", "encode",
"multiplicity")))
x$pipeops$properties = c("uses_test_task")
}

.onLoad = function(libname, pkgname) { # nocov start
Expand Down
8 changes: 4 additions & 4 deletions man/Graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/PipeOpEnsemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/PipeOpImpute.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 544ae50

Please sign in to comment.