Skip to content

Commit

Permalink
Merge pull request #408 from mlr-org/rmst_crankcompose_updates
Browse files Browse the repository at this point in the history
Updates in prediction type compositions
  • Loading branch information
bblodfon authored Aug 17, 2024
2 parents 5484214 + da0a858 commit 223f113
Show file tree
Hide file tree
Showing 47 changed files with 1,850 additions and 1,173 deletions.
8 changes: 3 additions & 5 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3proba
Title: Probabilistic Supervised Learning for 'mlr3'
Version: 0.6.6
Version: 0.6.7
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -64,11 +64,9 @@ Imports:
paradox (>= 1.0.0),
R6,
Rcpp (>= 1.0.4),
survival,
survivalmodels (>= 0.1.12)
survival
Suggests:
bujar,
cubature,
GGally,
knitr,
lgr,
Expand Down Expand Up @@ -145,6 +143,7 @@ Collate:
'PipeOpPredRegrSurv.R'
'PipeOpPredSurvRegr.R'
'PipeOpProbregrCompositor.R'
'PipeOpResponseCompositor.R'
'PipeOpSurvAvg.R'
'PipeOpTaskRegrSurv.R'
'PipeOpTaskSurvClassifDiscTime.R'
Expand Down Expand Up @@ -176,7 +175,6 @@ Collate:
'histogram.R'
'integrated_scores.R'
'mlr3proba-package.R'
'partition.R'
'pecs.R'
'pipelines.R'
'plot.R'
Expand Down
4 changes: 3 additions & 1 deletion NAMESPACE
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@ S3method(check_prediction_data,PredictionDataSurv)
S3method(filter_prediction_data,PredictionDataSurv)
S3method(is_missing_prediction_data,PredictionDataDens)
S3method(is_missing_prediction_data,PredictionDataSurv)
S3method(partition,TaskSurv)
S3method(pecs,PredictionSurv)
S3method(pecs,list)
S3method(plot,LearnerSurv)
Expand Down Expand Up @@ -77,6 +76,7 @@ export(PipeOpPredRegrSurv)
export(PipeOpPredSurvRegr)
export(PipeOpPredTransformer)
export(PipeOpProbregr)
export(PipeOpResponseCompositor)
export(PipeOpSurvAvg)
export(PipeOpTaskRegrSurv)
export(PipeOpTaskSurvClassifDiscTime)
Expand All @@ -95,7 +95,9 @@ export(as_prediction_surv)
export(as_task_dens)
export(as_task_surv)
export(assert_surv)
export(assert_surv_matrix)
export(breslow)
export(get_mortality)
export(pecs)
export(pipeline_survtoclassif_disctime)
export(pipeline_survtoregr)
Expand Down
8 changes: 8 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,11 @@
# mlr3proba 0.6.7

- Deprecate `crank` to `distr` composition in `distrcompose` pipeop (only from `lp` => `distr` works now)
- Add `get_mortality()` function (from `survivalmodels::surv_to_risk()`
- Add Rcpp function `assert_surv_matrix()`
- Update and simplify `crankcompose` pipeop and respective pipeline (no `response` is created anymore)
- Add `responsecompositor` pipeline with `rmst` and `median`

# mlr3proba 0.6.6

- Small fixes and refactoring to the discrete-time pipeops
Expand Down
2 changes: 1 addition & 1 deletion R/MeasureSurvCindex.R
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
#' library(mlr3)
#' task = tsk("rats")
#' learner = lrn("surv.coxph")
#' part = partition(task) # train/test split, stratified on `status` by default
#' part = partition(task) # train/test split
#' learner$train(task, part$train)
#' p = learner$predict(task, part$test)
#'
Expand Down
150 changes: 45 additions & 105 deletions R/PipeOpCrankCompositor.R
Original file line number Diff line number Diff line change
Expand Up @@ -16,57 +16,36 @@
#' ```
#'
#' @section Input and Output Channels:
#' [PipeOpCrankCompositor] has one input channel named "input", which takes
#' `NULL` during training and [PredictionSurv] during prediction.
#' [PipeOpCrankCompositor] has one input channel named `"input"`, which takes `NULL` during training and [PredictionSurv] during prediction.
#'
#' [PipeOpCrankCompositor] has one output channel named "output", producing `NULL` during training
#' and a [PredictionSurv] during prediction.
#' [PipeOpCrankCompositor] has one output channel named `"output"`, producing `NULL` during training and a [PredictionSurv] during prediction.
#'
#' The output during prediction is the [PredictionSurv] from the "pred" input but with the `crank`
#' predict type overwritten by the given estimation method.
#' The output during prediction is the [PredictionSurv] from the input but with the `crank` predict type overwritten by the given estimation method.
#'
#' @section State:
#' The `$state` is left empty (`list()`).
#'
#' @section Parameters:
#' * `method` :: `character(1)` \cr
#' Determines what method should be used to produce a continuous ranking from the distribution.
#' One of `sum_haz`, `median`, `mode`, or `mean` corresponding to the
#' respective functions in the predicted survival distribution. Note that
#' for models with a proportional hazards form, the ranking implied by
#' `mean` and `median` will be identical (but not the value of `crank`
#' itself). `sum_haz` (default) uses [survivalmodels::surv_to_risk()].
#' * `which` :: `numeric(1)`\cr
#' If `method = "mode"` then specifies which mode to use if multi-modal, default is the first.
#' * `response` :: `logical(1)`\cr
#' If `TRUE` then the `response` predict type is estimated with the same values as `crank`.
#' Currently only `mort` is supported, which is the sum of the cumulative hazard, also called *expected/ensemble mortality*, see Ishwaran et al. (2008).
#' For more details, see [get_mortality()].
#' * `overwrite` :: `logical(1)` \cr
#' If `FALSE` (default) then if the "pred" input already has a `crank`, the compositor only
#' composes a `response` type if `response = TRUE` and does not already exist. If `TRUE` then
#' both the `crank` and `response` are overwritten.
#'
#' @section Internals:
#' The `median`, `mode`, or `mean` will use analytical expressions if possible but if not they are
#' calculated using methods from [distr6]. `mean` requires \CRANpkg{cubature}.
#' If `FALSE` (default) and the prediction already has a `crank` prediction, then the compositor returns the input prediction unchanged.
#' If `TRUE`, then the `crank` will be overwritten.
#'
#' @seealso [pipeline_crankcompositor]
#' @family survival compositors
#' @examples
#' \dontrun{
#' if (requireNamespace("mlr3pipelines", quietly = TRUE)) {
#' library(mlr3)
#' library(mlr3pipelines)
#' task = tsk("rats")
#'
#' learn = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(method = "sum_haz"))
#' poc$predict(list(learn))[[1]]
#'
#' if (requireNamespace("cubature", quietly = TRUE)) {
#' learn = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(method = "sum_haz"))
#' poc$predict(list(learn))[[1]]
#' }
#' # change the crank prediction type of a Cox's model predictions
#' pred = lrn("surv.coxph")$train(task)$predict(task)
#' poc = po("crankcompose", param_vals = list(overwrite = TRUE))
#' poc$predict(list(pred))[[1L]]
#' }
#' }
#' @export
Expand All @@ -77,21 +56,18 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
#' Creates a new instance of this [R6][R6::R6Class] class.
initialize = function(id = "crankcompose", param_vals = list()) {
param_set = ps(
method = p_fct(default = "sum_haz", levels = c("sum_haz", "mean", "median", "mode"),
tags = "predict"),
which = p_int(1L, default = 1L, tags = "predict", depends = quote(method == "mode")),
response = p_lgl(default = FALSE, tags = "predict"),
method = p_fct(default = "mort", levels = c("mort"), tags = "predict"),
overwrite = p_lgl(default = FALSE, tags = "predict")
)
param_set$set_values(method = "sum_haz", response = FALSE, overwrite = FALSE)
param_set$set_values(method = "mort", overwrite = FALSE)

super$initialize(
id = id,
param_set = param_set,
param_vals = param_vals,
input = data.table(name = "input", train = "NULL", predict = "PredictionSurv"),
output = data.table(name = "output", train = "NULL", predict = "PredictionSurv"),
packages = c("mlr3proba", "distr6")
packages = c("mlr3proba")
)
}
),
Expand All @@ -103,83 +79,47 @@ PipeOpCrankCompositor = R6Class("PipeOpCrankCompositor",
},

.predict = function(inputs) {

inpred = inputs[[1L]]

response = self$param_set$values$response
b_response = !anyMissing(inpred$response)
if (!length(response)) response = FALSE

pred = inputs[[1L]]
overwrite = self$param_set$values$overwrite
if (!length(overwrite)) overwrite = FALSE
# it's impossible for a learner not to predict crank in mlr3proba,
# but let's check either way:
has_crank = !all(is.na(pred$crank))

# if crank and response already exist and not overwriting then return prediction
if (!overwrite && (!response || (response && b_response))) {
return(list(inpred))
if (!overwrite & has_crank) {
# return prediction as is
return(list(pred))
} else {
assert("distr" %in% inpred$predict_types)
method = self$param_set$values$method
if (length(method) == 0L) method = "sum_haz"
if (method == "sum_haz") {
if (inherits(inpred$data$distr, "matrix") ||
!requireNamespace("survivalmodels", quietly = TRUE)) {
comp = survivalmodels::surv_to_risk(inpred$data$distr)
} else {
comp = as.numeric(
colSums(inpred$distr$cumHazard(sort(unique(inpred$truth[, 1]))))
)
}
} else if (method == "mean") {
comp = try(inpred$distr$mean(), silent = TRUE)
if (inherits(comp, "try-error")) {
requireNamespace("cubature")
comp = try(inpred$distr$mean(cubature = TRUE), silent = TRUE)
}
if (inherits(comp, "try-error")) {
comp = numeric(length(inpred$crank))
}
} else {
comp = switch(method,
median = inpred$distr$median(),
mode = inpred$distr$mode(self$param_set$values$which))
}
# compose crank from distr prediction
assert("distr" %in% pred$predict_types)

comp = as.numeric(comp)

# if crank exists and not overwriting then return predicted crank, otherwise compose
if (!overwrite) {
crank = inpred$crank
# get survival matrix
if (inherits(pred$data$distr, "array")) {
surv = pred$data$distr
if (length(dim(surv)) == 3L) {
# survival 3d array, extract median
surv = .ext_surv_mat(arr = surv, which.curve = 0.5)
}
} else {
crank = -comp
# missing imputed with median
crank[is.na(crank)] = stats::median(crank[!is.na(crank)])
crank[crank == Inf] = 1e3
crank[crank == -Inf] = -1e3
stop("Distribution prediction does not have a survival matrix or array
in the $data$distr slot")
}

# i) not overwriting or requesting response, and already predicted
if (b_response && (!overwrite || !response)) {
response = inpred$response
# ii) not requesting response and doesn't exist
} else if (!response) {
response = NULL
# iii) requesting response and happy to overwrite
# iv) requesting response and doesn't exist
} else {
response = comp
response[is.na(response)] = 0
response[response == Inf | response == -Inf] = 0
method = self$param_set$values$method
if (method == "mort") {
crank = get_mortality(surv)
}

if (!anyMissing(inpred$lp)) {
lp = inpred$lp
} else {
lp = NULL
}
# update only `crank`
p = PredictionSurv$new(
row_ids = pred$row_ids,
truth = pred$truth,
crank = crank,
distr = pred$distr,
lp = pred$lp,
response = pred$response
)

return(list(PredictionSurv$new(
row_ids = inpred$row_ids, truth = inpred$truth, crank = crank,
distr = inpred$distr, lp = lp, response = response)))
return(list(p))
}
}
)
Expand Down
Loading

0 comments on commit 223f113

Please sign in to comment.