Skip to content

Commit

Permalink
Merge pull request #113 from mlr-org/add_param6
Browse files Browse the repository at this point in the history
fix distr6 learners
  • Loading branch information
RaphaelS1 authored Sep 12, 2021
2 parents b9aa81b + c5d16ac commit ce33c22
Show file tree
Hide file tree
Showing 6 changed files with 27 additions and 39 deletions.
1 change: 1 addition & 0 deletions .lintr
Original file line number Diff line number Diff line change
Expand Up @@ -5,5 +5,6 @@ linters: with_defaults(
object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names
cyclocomp_linter = NULL, # do not check function complexity
commented_code_linter = NULL, # allow code in comments
todo_comment_linter = NULL, # allow todo in comments
line_length_linter = line_length_linter(100)
)
5 changes: 3 additions & 2 deletions DESCRIPTION
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
Package: mlr3extralearners
Title: Extra Learners For mlr3
Version: 0.5.5
Version: 0.5.6
Authors@R:
c(person(given = "Raphael",
family = "Sonabend",
Expand Down Expand Up @@ -80,6 +80,7 @@ Suggests:
nnet,
np,
obliqueRSF,
param6,
partykit,
penalized,
pendensity,
Expand All @@ -96,7 +97,7 @@ Suggests:
sm,
stats,
survival,
survivalmodels (>= 0.1.4),
survivalmodels (>= 0.1.9),
survivalsvm,
tensorflow (>= 2.0.0),
testthat,
Expand Down
4 changes: 4 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
# mlr3extralearners 0.5.6

* Fix learners requiring distr6. distr6 1.6.0 now forced and param6 added to suggests

# mlr3extralearners 0.5.5

* Bugfix `regr.gausspr`
Expand Down
54 changes: 18 additions & 36 deletions R/learner_flexsurv_surv_flexible.R
Original file line number Diff line number Diff line change
Expand Up @@ -155,51 +155,35 @@ predict_flexsurvreg <- function(object, task, ...) {
# parameters above.
pdf = function(x) {} # nolint
body(pdf) = substitute({
fn = func
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(x = x), args))
do.call(func, c(list(x = x), self$parameters()$values))
}, list(func = object$dfns$d))

cdf = function(x) {} # nolint
body(cdf) = substitute({
fn = func
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(q = x), args))
do.call(func, c(list(q = x), self$parameters()$values))
}, list(func = object$dfns$p))

quantile = function(p) {} # nolint
body(quantile) = substitute({
fn = func
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(p = p), args))
do.call(func, c(list(p = p), self$parameters()$values))
}, list(func = object$dfns$q))

rand = function(n) {} # nolint
body(rand) = substitute({
fn = func
args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value
names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1]))
do.call(fn, c(list(n = n), args))
do.call(func, c(list(n = n), self$parameters()$values))
}, list(func = object$dfns$r))

# The parameter set combines the auxiliary parameters with the fitted gamma coefficients.
# Whilst the
# user can set these after fitting, this is generally ill-advised.
parameters = distr6::ParameterSet$new(
id = c(names(args), object$dlist$pars),
value = c(list(
numeric(length(object$knots)),
"hazard", "log"), rep(list(0), length(object$dlist$pars))),
settable = rep(TRUE, length(args) + length(object$dlist$pars)),
support = c(
list(set6::Reals$new()^length(object$knots)),
set6::Set$new("hazard", "odds", "normal"),
set6::Set$new("log", "identity"),
rep(list(set6::Reals$new()), length(object$dlist$pars)))
)
# Whilst the user can set these after fitting, this is generally ill-advised.
parameters = param6::ParameterSet$new(c(list(
param6::prm(
"knots", set6::Reals$new()^length(object$knots),
numeric(length(object$knots))
),
param6::prm("scale", set6::Set$new("hazard", "odds", "normal"), "hazard"),
param6::prm("timescale", set6::Set$new("log", "identity"), "log")),
lapply(object$dlist$pars, function(x) param6::prm(x, "reals", 0))
))

pars = data.table::data.table(t(pars))
pargs = data.table::data.table(matrix(args, ncol = ncol(pars), nrow = length(args)))
Expand All @@ -217,18 +201,16 @@ predict_flexsurvreg <- function(object, task, ...) {
pdf = pdf, cdf = cdf, quantile = quantile, rand = rand
)

## FIXME - This is bad and needs speeding up
distlist = lapply(pars, function(x) {
x = as.list(x)
names(x) = c(object$dlist$pars, names(args))
yparams = parameters$clone(deep = TRUE)
ind = match(yparams$.__enclos_env__$private$.parameters$id, names(x))
yparams$.__enclos_env__$private$.parameters$value = x[ind]
yparams$values = setNames(as.list(x), c(object$dlist$pars, names(args)))

do.call(distr6::Distribution$new, c(list(parameters = yparams), shared_params))
})

distr = distr6::VectorDistribution$new(distlist,
decorators = c("CoreStatistics", "ExoticStatistics"))
distr = distr6::VectorDistribution$new(
distlist, decorators = c("CoreStatistics", "ExoticStatistics"))

return(list(distr = distr, lp = lp))
}
Expand Down
2 changes: 1 addition & 1 deletion R/learner_survival_surv_parametric.R
Original file line number Diff line number Diff line change
Expand Up @@ -216,7 +216,7 @@ LearnerSurvParametric = R6Class("LearnerSurvParametric", inherit = mlr3proba::Le
},
cdf = function() {
},
parameters = distr6::ParameterSet$new()
parameters = param6::pset()
))

params = rep(params, length(lp))
Expand Down
Binary file modified R/sysdata.rda
Binary file not shown.

0 comments on commit ce33c22

Please sign in to comment.