From 271668c346c372b16c3be7971c4deadc9656ca75 Mon Sep 17 00:00:00 2001 From: RaphaelS1 Date: Fri, 10 Sep 2021 16:08:56 +0100 Subject: [PATCH 1/4] fix distr6 learners --- DESCRIPTION | 3 +- NEWS.md | 4 +++ R/learner_flexsurv_surv_flexible.R | 54 ++++++++++------------------ R/learner_survival_surv_parametric.R | 2 +- 4 files changed, 25 insertions(+), 38 deletions(-) diff --git a/DESCRIPTION b/DESCRIPTION index 5df5beb3d..8945a6ab7 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -1,6 +1,6 @@ Package: mlr3extralearners Title: Extra Learners For mlr3 -Version: 0.5.5 +Version: 0.5.6 Authors@R: c(person(given = "Raphael", family = "Sonabend", @@ -80,6 +80,7 @@ Suggests: nnet, np, obliqueRSF, + param6, partykit, penalized, pendensity, diff --git a/NEWS.md b/NEWS.md index 5fdc0e019..70423c1f6 100644 --- a/NEWS.md +++ b/NEWS.md @@ -1,3 +1,7 @@ +# mlr3extralearners 0.5.6 + +* Fix learners requiring distr6. distr6 1.6.0 now forced and param6 added to suggests + # mlr3extralearners 0.5.5 * Bugfix `regr.gausspr` diff --git a/R/learner_flexsurv_surv_flexible.R b/R/learner_flexsurv_surv_flexible.R index 71ac8e081..3a3756c17 100644 --- a/R/learner_flexsurv_surv_flexible.R +++ b/R/learner_flexsurv_surv_flexible.R @@ -155,51 +155,35 @@ predict_flexsurvreg <- function(object, task, ...) { # parameters above. pdf = function(x) {} # nolint body(pdf) = substitute({ - fn = func - args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value - names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1])) - do.call(fn, c(list(x = x), args)) + do.call(func, c(list(x = x), self$parameters()$values)) }, list(func = object$dfns$d)) cdf = function(x) {} # nolint body(cdf) = substitute({ - fn = func - args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value - names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1])) - do.call(fn, c(list(q = x), args)) + do.call(func, c(list(q = x), self$parameters()$values)) }, list(func = object$dfns$p)) quantile = function(p) {} # nolint body(quantile) = substitute({ - fn = func - args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value - names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1])) - do.call(fn, c(list(p = p), args)) + do.call(func, c(list(p = p), self$parameters()$values)) }, list(func = object$dfns$q)) rand = function(n) {} # nolint body(rand) = substitute({ - fn = func - args = as.list(subset(data.table::as.data.table(self$parameters()), select = "value"))$value - names(args) = unname(unlist(data.table::as.data.table(self$parameters())[, 1])) - do.call(fn, c(list(n = n), args)) + do.call(func, c(list(n = n), self$parameters()$values)) }, list(func = object$dfns$r)) # The parameter set combines the auxiliary parameters with the fitted gamma coefficients. - # Whilst the - # user can set these after fitting, this is generally ill-advised. - parameters = distr6::ParameterSet$new( - id = c(names(args), object$dlist$pars), - value = c(list( - numeric(length(object$knots)), - "hazard", "log"), rep(list(0), length(object$dlist$pars))), - settable = rep(TRUE, length(args) + length(object$dlist$pars)), - support = c( - list(set6::Reals$new()^length(object$knots)), - set6::Set$new("hazard", "odds", "normal"), - set6::Set$new("log", "identity"), - rep(list(set6::Reals$new()), length(object$dlist$pars))) - ) + # Whilst the user can set these after fitting, this is generally ill-advised. + parameters = param6::ParameterSet$new(c(list( + param6::prm( + "knots", set6::Reals$new()^length(object$knots), + numeric(length(object$knots)) + ), + param6::prm("scale", set6::Set$new("hazard", "odds", "normal"), "hazard"), + param6::prm("timescale", set6::Set$new("log", "identity"), "log")), + lapply(object$dlist$pars, function(x) param6::prm(x, "reals", 0)) + )) pars = data.table::data.table(t(pars)) pargs = data.table::data.table(matrix(args, ncol = ncol(pars), nrow = length(args))) @@ -217,18 +201,16 @@ predict_flexsurvreg <- function(object, task, ...) { pdf = pdf, cdf = cdf, quantile = quantile, rand = rand ) + ## FIXME - This is bad and needs speeding up distlist = lapply(pars, function(x) { - x = as.list(x) - names(x) = c(object$dlist$pars, names(args)) yparams = parameters$clone(deep = TRUE) - ind = match(yparams$.__enclos_env__$private$.parameters$id, names(x)) - yparams$.__enclos_env__$private$.parameters$value = x[ind] + yparams$values = setNames(as.list(x), c(object$dlist$pars, names(args))) do.call(distr6::Distribution$new, c(list(parameters = yparams), shared_params)) }) - distr = distr6::VectorDistribution$new(distlist, - decorators = c("CoreStatistics", "ExoticStatistics")) + distr = distr6::VectorDistribution$new( + distlist, decorators = c("CoreStatistics", "ExoticStatistics")) return(list(distr = distr, lp = lp)) } diff --git a/R/learner_survival_surv_parametric.R b/R/learner_survival_surv_parametric.R index c9d7e6992..28e170da1 100644 --- a/R/learner_survival_surv_parametric.R +++ b/R/learner_survival_surv_parametric.R @@ -216,7 +216,7 @@ LearnerSurvParametric = R6Class("LearnerSurvParametric", inherit = mlr3proba::Le }, cdf = function() { }, - parameters = distr6::ParameterSet$new() + parameters = param6::pset() )) params = rep(params, length(lp)) From 47a2357cb0ee3392bef48cfc1f490e92391a91f8 Mon Sep 17 00:00:00 2001 From: RaphaelS1 Date: Fri, 10 Sep 2021 16:39:17 +0100 Subject: [PATCH 2/4] lint --- .lintr | 1 + 1 file changed, 1 insertion(+) diff --git a/.lintr b/.lintr index 04595504c..fb0f1d0e5 100644 --- a/.lintr +++ b/.lintr @@ -5,5 +5,6 @@ linters: with_defaults( object_name_linter = object_name_linter(c("snake_case", "CamelCase")), # only allow snake case and camel case object names cyclocomp_linter = NULL, # do not check function complexity commented_code_linter = NULL, # allow code in comments + todo_comment_linter = NULL, # allow todo in comments line_length_linter = line_length_linter(100) ) From 5f10babab40953fa643f54a4df3e5ddb416d83c7 Mon Sep 17 00:00:00 2001 From: Raphael Sonabend Date: Sun, 12 Sep 2021 12:04:51 +0100 Subject: [PATCH 3/4] Update DESCRIPTION --- DESCRIPTION | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/DESCRIPTION b/DESCRIPTION index 8945a6ab7..6f08a3851 100644 --- a/DESCRIPTION +++ b/DESCRIPTION @@ -97,7 +97,7 @@ Suggests: sm, stats, survival, - survivalmodels (>= 0.1.4), + survivalmodels (>= 0.1.9), survivalsvm, tensorflow (>= 2.0.0), testthat, From c5d16acf3dde4d71654a1b0ae5227c1c41cb752e Mon Sep 17 00:00:00 2001 From: github-actions <41898282+github-actions[bot]@users.noreply.github.com> Date: Sun, 12 Sep 2021 11:41:12 +0000 Subject: [PATCH 4/4] Update learner table --- R/sysdata.rda | Bin 2735 -> 2746 bytes 1 file changed, 0 insertions(+), 0 deletions(-) diff --git a/R/sysdata.rda b/R/sysdata.rda index 32179ebb8747829bca499d775bbc2a80f6adb6d8..d3d2a244273087f8f05d937dc7feb7068bdec870 100644 GIT binary patch literal 2746 zcmV;r3PtroT4*^jL0KkKS=A%(hX5ow|G@wM{Qv*~f51II+`zx@|M)@x00H0*KRw0; zEzsks*~gAPyTI?Uq1}V3_2_ibS{=6)9ewPo5}J((3V51r02))#s0K|90j5LL1Jn%! zl4uAeJv~h|JflX9JfIo?0MKO7pwJ<-l%N0r000000000GsT9IAnr#uOqcqYFNrGTT z0x>;6446eC0-BjWDs4?XrcWduro|qosMFNK8UO~Kpb;b_$*F+?42+chN$pHiLlK|= z05tVDiUpv^2(k$Xss_5M@3N@iR>+M2nL>YX@uSx$uSv6((J&dU%h5?jOv}nrgsRTR6U{jiT;}A79@;u?PD^0Gu`jv! zvB-71AQv8svp~$u48r9whMiST=%j(l_1 zfE$D)5<#A=cgca+;*3NujkYqbJW7zPv2K7(=^4IMu~hAaM%PCWj^b!XI<8||1{Ua ziW_^GU$JmfnH@cyJn@1jfkj+26galx0C*zIh1O(J2V}mmbCMDeTNs*TbOY0nKrAWf z&1H*9^v6+F0cOUOb+L!&)?( zmrD5&<}tVdszh`W(#;S|%|$z=g8(7Pw?`LW7@8>DoK;dWF^da;g++zrCc+!Otb3-W z@(YJfdV2LuToEichQXN>d=&LH>%X}xuYx44_b<2I|MO>;eP}$l`FPecI>*4Zodo+ zRs;j$$(bdI!A#hR8rOC(G1s0p!AN)$y7O>?yJou>b42cO6TKSj+py!$E<0yWuH#Sx zt{~JG#s#ISmbPSDd?7_SZ>@IKXS3d-q8`CT5r87DQ2|UWpr%r&qKEl@atH@dh^ZoC zS{ekT3Ir%9qKJYaD4;-S1tO{Jlro)Cje80#P$cY0fB~yjRcsE^x^v4S2WX^J_&ZLH6u$j_oxTNawKrqB5NcIr6gbHFN zXF3>js31fD^Dr=54LI%`xW^By!w^zY7zh9{sR1V-I8Eq~Z*YbUWc$s?H$#{aP~LDC z?%nkAH<%BKB^kfCxQO-_J0M6&EHr>o0LJ5j0G)>+|AmtQgv}GEG&%#kKZ3|0Vg!mF zXESm4pEkTNpQc64=vSnXo5 z5YH_Lgo742flLY#b<$DT@>jgBQNBPF<%vejY}lnoDma-N!sTXa>@yAp&9B)cVHym? zM?l_tQ9qfXg)~HdACRJ}l9m$z5TQka0g`}x#pzPOOfNz8^mdKFbn6VtPPCCEf)N#< zMu`Hx$C%Kl@gf19NzSIh5xAAtG#rjx^GVr01qonKa%l)A1YnY&kzoo<&K=+veW#iN zLy^f1U~3NA0VF{^h#~3A%i?&pAv)CbSA*`U)tO@BIb3=TUcNpghtO>_HDmY)cb8!# zDxuhY{O7rubu20+U+`0nOtAMFKX?d2T)R~t(E3P<(gFrR-WXE|lz^S|d4{Ll z?;3l>bPU6&b9X{2H~5iJ;%+t%`XFAg#=sh|YRwF~Jut`BQA$9CD&qSdzA6~(>$<44|6j@6Z zl^DKOjfE&HDzXS&$;INC@PXZO@DWs$RWu6mslm4+TByeW47Zl(_e|#H5X?QJZU7gA zAdDaZ7>J%Ix`nz+QA$;+AWIPlL8`T+sK_$XrKzxlEcGRTZm4@jef)1FpJLah3uef= zMrC6ZAt=z4(+Fq6I-1Da{1TxEq5-g_t^D1c-?gCn=>6xeyG3az!{$kkupbhl-Nw A-~a#s literal 2735 zcmV;g3Q+YzT4*^jL0KkKSrFtzk^m$9|G@wM{Qv*~f51II+`zx@|MWxv00H0*AA9dT z;aXnn((}~zc=+QTeUGi}*It3{!;mxvoz23M*r-Vn8dGVYWK-Iwff+}n8UsKy(De-u zNIgM{X{LxOdYe-b`lC$)^rJ!QXaF?OWC7&?sqI9h4FCg1fB*mh0000&l1)uh$*G>E zrqn%64H^Jw02%-Q07RmmiRzjfYI#E;h-fq!4FJ#sKmZXWM8XqIG!VdsrqGZ*O&Vwb z01Y@I1)#{FvIz*P2D+;6vMAt-q(=ZwYz0FQB~^qaP$nv>>7hYT1DXnudA@Xg@oDPj zsmgNx6?sVYyw9X(&uNSTR_EIfH%R$JKzgP_TTrl=f5_Z$q?-h%pG&UECjVRv+w7aJfZrmN@XsUSKx} zNF;+jdhe3~*x^PI3*n`VYli})D;zDgLwZJUl`K^{LZfS=gpRXABT~rjw6+(^FwTOr zrKU?zShy6}jJ<~nZ-rA>P+%A4`||luXL_jY*?P0n^zsf+l`)g?+5iZdfDOzc0f~^3 zIMqaBYbt`F@q!>1>yKsoy@2*xFd?KNd^!Uu#f61S3BysFT=T1j2cmeDxIWZ8kY%gA zru@lqBAA^WT^w-0iQ-XL48;x&xBwm~vk`TelmXQ*tQ@3-gcimVJWZhVVh{|+9$V~j zPkBSCYalGKh+~=!w;_uWaU%2rNhAV-grbUyN`i(h$6({;Z?(&{)9&y`7_LegHIKgc zL>})VluILINQOWW;E-DztXNofqV0i6QUpN+L=aDQ4(L%3k&p=q;)H_hCcgF43`Jt% zyomKmZ~)^1JCf4P5KPTQJm%&AhnL;otG}5vQNnVnk%^44a3dv#NePJ!t~+(+V`PZg zzMJ&(CT7G+j!0s_$}e?*>C;|qHMR8s=7?tP6*9pOuEDWc5*^$*l*fNF`;u z80lqp@Xs7!8LObk8Uk1hr-i9Zr0}LP8-VE>;BVC?6Eh)|6-$_*tu(MhfS_ajea8r~%aw zY76B8($z~_F)h9jqMUcux>aSf;E19f>LMcm1zcbPl(9igLZFHt&Gtfo9pWOOiHT@w z0HCB2iYSO85{gg;g%T>FB^263T#)QJK=FjtrbIEa6Qno>$YLBs0Dy>#z<_OlfMxGU zQpIY(ZMNW&35k}q>Pc3H;;p~{8nZJuAr?ESo|k%DiybU0ic*Os*mNm~0?-@I1WvE3 zpMc_4c-M5Ej+N${=hmcbK*a> zlcY;ow4AkemMIorn^QGIiEHiUFP!J0Bj!^FLGhYCgT>r99v`e3i8qg&&Uk{{jn2d0 zO^+C1pabAR4sh@F^d5$1FByzTQ5X`doKO;U2P2{=cGx3<(O==xgLF876%FSwzi#i3 zIlRC=Xq04r(&8i5Y<0m!S{wi{fZ@`1&;;u_5A@_puIleF$K^&9~5F=-x@g#p6Kc5sS`+-VITPhl}c9k8r#e+v{B2Iz`fEtKK5ECj)JpH01Ea2)zWiHY7yN`xR*ff^tR@SX#8 zN1zZ3**>KnIDX)49||BP1Aw$ovT(_v#35X1}8y9F$PNs-Omt=|LD0U^$Fp{xyI z08X+Xo#-Lz$KUfjn-HBUc2|q-smmh8#&g_un!P+ci4QYFGgdz09mUv53aEJ@>oW8; zH6$Yt7Bp&ZK?n#Fnw@mq$D|H|N|1&|f&ro~PFu`CXK|Qxk)&YonoZS6nsf*=n9icG zv`7NSAp`@>+5V5r01d z9O{R^((!7i6F>ub^`jN=G&a5nA87&i$x?O;-X88$dyczaN$2P4PmmJ+VGkMEl&CyQ z>SADBQ`g-B>Rq;gm#r#__QvNdS<; zDiIb6;GJ)Y2u;p@h9<*X`EWUe#zi8Oq7+lO*r>kHKq$|03BEDON=`<-r&IGmHZ2N; zfXvMbK-B}tirn`ohp|s2Pk{v?z4QSHB=6#={^Qm{5)d}%gUM9kpVV*zn&9xAbv*-r zRwrF(ha?Xbe2~6*16UJg^Wq-^w8-`-a!(BRSwt>_1Hn(kVhVhdaeWSBPEDTJ@Qoow zI2n}%XSdjQJwyXwcU0YB5`_XqZsAE_sDWby1JtMnAl?`$f+Zj)bsmAK_PfTO;auE< zkmPQJRB!VcBrF9)g}`9FsK|gcVbN>l&FXxg3JFQal^a_5D_e2ltXD3OEJZ4!8kMq6 zCvoT~E|`HnJVkc|Mkd;W*;8uccEI5hKtYJ;g9eBqZWRjf+7v)4HQTnp-ZWamDT0!S zmZ+g79lA3eCbjD+wvx{U4b9|c++wJ>l!i*?Va+)pRW8Pzod(=6(wZ}TIBmYCLYNu8 zX#_TRcx@@SAU&sEA}W%Krh#5cb8)ensKc3Dcb41jnN7+eka!KbfHsJPVgM#YPe~V` zHPmYiQ#FtUqCz3Hm6JminCSvE`AMZqsvZKr4jbh4_|E>?o<26bhpKoc!mzf(3}Rhx z*W2d9u#=;*2XQD4kA18XHG3TQFYCd>91$?qIRi#owXH6*Y&C6dmIhQ6Wh)5;Y;GJE z!3P5bvN)h(rO_8flE{!k-eZvOP&OqUWsLv}02v^LN`?xFL!fLj2OBgEL<)j1Q;G?Z za5o3XBQ1XC+eDOIK8MK=;0OnrCV^1JT;UL1PE%z7AQ{DIjtY4jk3|6dAZmhKe0kN( p#q5kMdDbPN6%!L{&dmiufp;U=f80N0WeR`A+>uTcBm_ATB!Iu=;LHF3