Skip to content

Commit

Permalink
new param no_collapse_above_absolute
Browse files Browse the repository at this point in the history
  • Loading branch information
advieser committed Oct 16, 2024
1 parent c823f18 commit 8238a06
Showing 1 changed file with 6 additions and 2 deletions.
8 changes: 6 additions & 2 deletions R/PipeOpCollapseFactors.R
Original file line number Diff line number Diff line change
Expand Up @@ -61,9 +61,10 @@ PipeOpCollapseFactors = R6Class("PipeOpCollapseFactors",
# Construct the PipeOp: declare its hyperparameters, set their defaults and
# hand everything to the parent constructor.
#
# id:         identifier of this PipeOp, defaults to "collapsefactors".
# param_vals: list of hyperparameter values that is passed on to the parent
#             constructor together with the parameter set.
initialize = function(id = "collapsefactors", param_vals = list()) {
  param_set = ps(
    # numeric threshold in [0, 1]
    no_collapse_above_prevalence = p_dbl(0, 1, tags = c("train", "predict")),
    # integer threshold >= 1; `Inf` is allowed as a special value
    no_collapse_above_absolute = p_int(1, special_vals = list(Inf), tags = c("train", "predict")),
    # integer >= 2
    target_level_count = p_int(2, tags = c("train", "predict"))
  )
  # Defaults. Every name used here must be the id of a parameter declared
  # above; the absolute threshold is therefore set through
  # `no_collapse_above_absolute` (the id that is declared and later read),
  # not through an undeclared alias.
  param_set$values = list(
    no_collapse_above_prevalence = 1,
    no_collapse_above_absolute = Inf,
    target_level_count = 2
  )
  super$initialize(id, param_set = param_set, param_vals = param_vals, feature_types = c("factor", "ordered"))
}
),
Expand All @@ -74,6 +75,7 @@ PipeOpCollapseFactors = R6Class("PipeOpCollapseFactors",
dt = task$data(cols = private$.select_cols(task))

keep_fraction = self$param_set$values$no_collapse_above_prevalence
keep_absolute = self$param_set$values$no_collapse_above_absolute
target_count = self$param_set$values$target_level_count

collapse_map = sapply(dt, function(d) {
Expand All @@ -86,8 +88,10 @@ PipeOpCollapseFactors = R6Class("PipeOpCollapseFactors",
dtable = table(d)
fractions = sort(dtable, decreasing = TRUE) / sum(!is.na(d))
keep_fraction = names(fractions)[fractions >= keep_fraction]
# TODO: test this
keep_absolute = names(fractions)[fractions >= keep_absolute]
keep_count = names(fractions)[seq_len(target_count)] # at this point we know there are more levels than target_count
keep = union(keep_fraction, keep_count)
keep = union(keep_fraction, setdiff(keep_count, keep_absolute))
dont_keep = setdiff(levels(d), keep)
if (is.ordered(d)) {
cmap = stats::setNames(as.list(levels(d)), levels(d))
Expand Down

0 comments on commit 8238a06

Please sign in to comment.