Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

started to add support for test rows #754

Closed
wants to merge 4 commits into from
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 5 additions & 3 deletions DESCRIPTION
Original file line number Diff line number Diff line change
Expand Up @@ -87,20 +87,22 @@ Suggests:
methods,
vtreat,
future
Remotes:
mlr-org/mlr3@feat/train-predict
ByteCompile: true
Encoding: UTF-8
Config/testthat/edition: 3
Config/testthat/parallel: true
NeedsCompilation: no
Roxygen: list(markdown = TRUE, r6 = FALSE)
RoxygenNote: 7.2.3
RoxygenNote: 7.3.1
VignetteBuilder: knitr
Collate:
'Graph.R'
'GraphLearner.R'
'mlr_pipeops.R'
'multiplicity.R'
'utils.R'
'GraphLearner.R'
'mlr_pipeops.R'
'PipeOp.R'
'PipeOpEnsemble.R'
'LearnerAvg.R'
Expand Down
1 change: 1 addition & 0 deletions NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ import(mlr3)
import(mlr3misc)
import(paradox)
importFrom(R6,R6Class)
importFrom(data.table,as.data.table)
importFrom(digest,digest)
importFrom(stats,setNames)
importFrom(utils,bibentry)
Expand Down
15 changes: 14 additions & 1 deletion R/GraphLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,7 @@
#' so `TRUE` (default) is recommended. In particular, note that the `$state` of `$graph` is set to `NULL` by reference on
#' construction of `GraphLearner`, during `$train()`, and during `$predict()` when `clone_graph` is `FALSE`.
#'
#' @include utils.R
#' @section Fields:
#' Fields inherited from [`PipeOp`], as well as:
#' * `graph` :: [`Graph`]\cr
Expand Down Expand Up @@ -102,7 +103,7 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
feature_types = mlr_reflections$task_feature_types,
predict_types = names(mlr_reflections$learner_predict_types[[task_type]]),
packages = graph$packages,
properties = mlr_reflections$learner_properties[[task_type]],
properties = setdiff(mlr_reflections$learner_properties[[task_type]], "uses_test_rows"),
man = "mlr3pipelines::GraphLearner"
)

Expand Down Expand Up @@ -173,6 +174,9 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,
}
),
private = list(
.contingent_properties = function() {
  # The learner only gains the "uses_test_rows" property when at least one
  # PipeOp of the wrapped graph satisfies the `uses_test_rows` predicate.
  any_pipeop_uses_test = some(self$graph$pipeops, uses_test_rows)
  if (any_pipeop_uses_test) {
    return("uses_test_rows")
  }
  character(0)
},
.graph = NULL,
deep_clone = function(name, value) {
# FIXME this repairs the mlr3::Learner deep_clone() method which is broken.
Expand All @@ -187,6 +191,15 @@ GraphLearner = R6Class("GraphLearner", inherit = Learner,

.train = function(task) {
on.exit({self$graph$state = NULL})

if (!("uses_test_rows" %in% self$properties)) {
# PipeOps like PipeOpTaskPreproc will always make predictions for the test rows during train
# To avoid doing this unless we need it, we remove those row roles temporarily if no pipeop needs them
on.exit({task$row_roles$test = prev_test}, add = TRUE) # nolint
prev_test = task$row_roles$test
task$row_roles$test = integer(0)
}

self$graph$train(task)
state = self$graph$state
state
Expand Down
4 changes: 4 additions & 0 deletions R/PipeOp.R
Original file line number Diff line number Diff line change
Expand Up @@ -335,6 +335,10 @@ PipeOp = R6Class("PipeOp",
),

active = list(
properties = function(rhs) {
  # `rhs` is only supplied when a caller tries to assign to this active
  # binding; the assertion is presumably what rejects such writes.
  assert_ro_binding(rhs)
  # The base implementation declares no properties.
  no_properties = character(0)
  no_properties
},
id = function(val) {
if (!missing(val)) {
private$.id = val
Expand Down
22 changes: 22 additions & 0 deletions R/PipeOpImpute.R
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,7 @@ PipeOpImpute = R6Class("PipeOpImpute",

.train = function(inputs) {
intask = inputs[[1]]$clone(deep = TRUE)
task = intask$clone(deep = TRUE)

affected_cols = (self$param_set$values$affect_columns %??% selector_all())(intask)
affected_cols = intersect(affected_cols, private$.select_cols(intask))
Expand Down Expand Up @@ -189,8 +190,29 @@ PipeOpImpute = R6Class("PipeOpImpute",

intask$select(setdiff(intask$feature_names, colnames(imputanda)))$cbind(imputanda)

test_rows_exist = length(task$row_roles$test) > 0

self$state$outtasklayout = copy(intask$feature_types)

if (test_rows_exist) {
# FIXME: This is copy pasta from PipeOpTaskPreproc
predict_task = task$clone(deep = TRUE)
predict_task$row_roles$use = task$row_roles$test
predict_task = private$.predict(list(predict_task))

test_cols = unique(unlist(intask$col_roles))

test_data = predict_task[[1]]$data(cols = test_cols)
# this creates new row_ids for the test data
prev_use = intask$row_roles$use
intask$rbind(test_data)
intask$row_roles$test = setdiff(intask$row_roles$use, prev_use)
intask$row_roles$use = prev_use
}




list(intask)
},

Expand Down
4 changes: 4 additions & 0 deletions R/PipeOpLearner.R
Original file line number Diff line number Diff line change
Expand Up @@ -97,6 +97,10 @@ PipeOpLearner = R6Class("PipeOpLearner", inherit = PipeOp,
}
),
active = list(
properties = function(rhs) {
  # Active binding guarded by `assert_ro_binding()` (presumably read-only).
  assert_ro_binding(rhs)
  # Forward only the "uses_test_rows" property of the wrapped learner.
  learner_properties = private$.learner$properties
  if (is.element("uses_test_rows", learner_properties)) {
    return("uses_test_rows")
  }
  character(0)
},
id = function(val) {
if (!missing(val)) {
private$.id = val
Expand Down
4 changes: 4 additions & 0 deletions R/PipeOpLearnerCV.R
Original file line number Diff line number Diff line change
Expand Up @@ -142,6 +142,10 @@ PipeOpLearnerCV = R6Class("PipeOpLearnerCV",

),
active = list(
properties = function(rhs) {
  # Active binding guarded by `assert_ro_binding()` (presumably read-only).
  assert_ro_binding(rhs)
  # Report "uses_test_rows" exactly when the wrapped learner declares it.
  wrapped_properties = private$.learner$properties
  has_test_rows_property = is.element("uses_test_rows", wrapped_properties)
  if (has_test_rows_property) "uses_test_rows" else character(0)
},
learner = function(val) {
if (!missing(val)) {
if (!identical(val, private$.learner)) {
Expand Down
40 changes: 39 additions & 1 deletion R/PipeOpTaskPreproc.R
Original file line number Diff line number Diff line change
Expand Up @@ -199,8 +199,10 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",

.train = function(inputs) {
intask = inputs[[1]]$clone(deep = TRUE)

do_subset = !is.null(self$param_set$values$affect_columns)
affected_cols = intask$feature_names
remove_cols = NULL
if (do_subset) {
affected_cols = self$param_set$values$affect_columns(intask)
assert_subset(affected_cols, intask$feature_names, empty.ok = TRUE)
Expand All @@ -210,16 +212,52 @@ PipeOpTaskPreproc = R6Class("PipeOpTaskPreproc",
}
intasklayout = copy(intask$feature_types)

test_rows_exist = length(intask$row_roles$test) > 0
if (test_rows_exist) {
predict_task = intask$clone(deep = TRUE)
}

intask = private$.train_task(intask)

self$state$affected_cols = affected_cols
self$state$intasklayout = intasklayout
self$state$outtasklayout = copy(intask$feature_types)
self$state$outtaskshell = intask$data(rows = intask$row_ids[0])


if (test_rows_exist) {
predict_task$row_roles$use = predict_task$row_roles$test
predict_task = private$.predict_task(predict_task)

# FIXME: These are all the columns that a learner might use.
# To be on the safe side, we could also add all available columns
test_cols = unique(c(remove_cols, unlist(predict_task$col_roles)))

test_data = predict_task$data(cols = test_cols)

# in some cases (such as class weights), different columns are added during train and predict;
# we fill those values with NAs
missing_cols = setdiff(unlist(intask$col_roles), colnames(test_data))
if (length(missing_cols)) {
missing_data = intask$data(intask$row_roles$use[1L], missing_cols)
missing_data[1, ] = NA
missing_data = missing_data[1, lapply(get(".SD"), function(col) rep(col, nrow(test_data)))]
test_data = cbind(missing_data, test_data)
}

# this creates new row_ids for the test data
prev_use = intask$row_roles$use
intask$rbind(test_data)
intask$row_roles$test = setdiff(intask$row_roles$use, prev_use)
intask$row_roles$use = prev_use
}
if (do_subset) {
# FIXME: this fails if .train_task added a column with the same name
intask$col_roles$feature = union(intask$col_roles$feature, y = remove_cols)
new_features = union(intask$col_roles$feature, y = remove_cols)
intask$col_roles$feature = new_features
if (test_rows_exist) {
intask$col_roles$feature = new_features
}
}
list(intask)
},
Expand Down
4 changes: 4 additions & 0 deletions R/utils.R
Original file line number Diff line number Diff line change
Expand Up @@ -134,3 +134,7 @@ dictionary_sugar_inc_mget = function(dict, .keys, ...) {
names(objs) = map_chr(objs, "id")
objs
}

# Predicate: TRUE when `pipeop$properties` contains "uses_test_rows".
# Anything exposing a `$properties` element can be passed as `pipeop`.
uses_test_rows = function(pipeop) {
  declared = pipeop$properties
  is.element("uses_test_rows", declared)
}
8 changes: 4 additions & 4 deletions man/Graph.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

16 changes: 8 additions & 8 deletions man/PipeOp.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/PipeOpEnsemble.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

12 changes: 6 additions & 6 deletions man/PipeOpImpute.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading
Loading