From 1f3e06656a086b15ab8bc2dab7c8c73b62820e7b Mon Sep 17 00:00:00 2001 From: Martin Date: Mon, 25 Nov 2024 15:26:56 +0100 Subject: [PATCH 1/2] bugfix edge case where party returns an constparty object --- R/approach_ctree.R | 33 ++++++++++++++++++++++----------- 1 file changed, 22 insertions(+), 11 deletions(-) diff --git a/R/approach_ctree.R b/R/approach_ctree.R index 86e8b5e97..d4faedf6e 100644 --- a/R/approach_ctree.R +++ b/R/approach_ctree.R @@ -240,21 +240,32 @@ sample_ctree <- function(tree, dependent_ind <- tree$dependent_ind x_explain_given <- x_explain[, - given_ind, - drop = FALSE, - with = FALSE + given_ind, + drop = FALSE, + with = FALSE ] # xp <- x_explain_given colnames(xp) <- paste0("V", given_ind) # this is important for where() below if (using_partykit) { + + # xp here needs to contain the response variables as well, for some reason + x_explain_dependent <- x_explain[, + dependent_ind, + drop = FALSE, + with = FALSE + ] + + colnames(x_explain_dependent) <- paste0("Y", seq_along(dependent_ind)) + xp2 <- cbind(xp, x_explain_dependent) + fit.nodes <- predict( object = datact, type = "node" ) # newdata must be data.frame + have the same colnames as x pred.nodes <- predict( - object = datact, newdata = xp, + object = datact, newdata = xp2, type = "node" ) } else { @@ -271,20 +282,20 @@ sample_ctree <- function(tree, newrowno <- rowno[fit.nodes == pred.nodes] } else { newrowno <- sample(rowno[fit.nodes == pred.nodes], n_MC_samples, - replace = TRUE + replace = TRUE ) } depDT <- data.table::data.table(x_train[newrowno, - dependent_ind, - drop = FALSE, - with = FALSE + dependent_ind, + drop = FALSE, + with = FALSE ]) givenDT <- data.table::data.table(x_explain[1, - given_ind, - drop = FALSE, - with = FALSE + given_ind, + drop = FALSE, + with = FALSE ]) ret <- cbind(depDT, givenDT) data.table::setcolorder(ret, colnames(x_train)) From d7463ef18e19468ab78d4b9f25c3b1c5ff6c0303 Mon Sep 17 00:00:00 2001 From: Martin Date: Mon, 25 Nov 2024 17:17:07 +0100 Subject: [PATCH 2/2] style --- R/approach_ctree.R | 27 +++++++++++++-------------- 1 file changed, 13 insertions(+), 14 deletions(-) diff --git a/R/approach_ctree.R b/R/approach_ctree.R index d4faedf6e..a5c8de0a9 100644 --- a/R/approach_ctree.R +++ b/R/approach_ctree.R @@ -240,20 +240,19 @@ sample_ctree <- function(tree, dependent_ind <- tree$dependent_ind x_explain_given <- x_explain[, - given_ind, - drop = FALSE, - with = FALSE + given_ind, + drop = FALSE, + with = FALSE ] # xp <- x_explain_given colnames(xp) <- paste0("V", given_ind) # this is important for where() below if (using_partykit) { - # xp here needs to contain the response variables as well, for some reason x_explain_dependent <- x_explain[, - dependent_ind, - drop = FALSE, - with = FALSE + dependent_ind, + drop = FALSE, + with = FALSE ] colnames(x_explain_dependent) <- paste0("Y", seq_along(dependent_ind)) @@ -282,20 +281,20 @@ sample_ctree <- function(tree, newrowno <- rowno[fit.nodes == pred.nodes] } else { newrowno <- sample(rowno[fit.nodes == pred.nodes], n_MC_samples, - replace = TRUE + replace = TRUE ) } depDT <- data.table::data.table(x_train[newrowno, - dependent_ind, - drop = FALSE, - with = FALSE + dependent_ind, + drop = FALSE, + with = FALSE ]) givenDT <- data.table::data.table(x_explain[1, - given_ind, - drop = FALSE, - with = FALSE + given_ind, + drop = FALSE, + with = FALSE ]) ret <- cbind(depDT, givenDT) data.table::setcolorder(ret, colnames(x_train))