diff --git a/R/fortify_surv.R b/R/fortify_surv.R index 96e0182..ffea376 100644 --- a/R/fortify_surv.R +++ b/R/fortify_surv.R @@ -18,30 +18,29 @@ #' @export fortify.survfit <- function(model, data = NULL, surv.connect = FALSE, fun = NULL, ...) { - # survival package >= v3.6.1 - if (length(dim(model$n.censor)) == 2) { - model$n.censor <- rowSums(model$n.censor) + if (inherits(model, 'survfitms')) { + d <- data.frame(time = model$time, + n.risk = c(model$n.risk), + n.event = c(model$n.event), + n.censor = c(model$n.censor), + pstate = c(model$pstate), + std.err = c(model$std.err), + upper = c(model$upper), + lower = c(model$lower)) + } else { + d <- data.frame(time = model$time, + n.risk = model$n.risk, + n.event = model$n.event, + n.censor = model$n.censor, + std.err = model$std.err, + upper = model$upper, + lower = model$lower) } - d <- data.frame(time = model$time, - n.risk = model$n.risk, - n.event = model$n.event, - n.censor = model$n.censor, - std.err = model$std.err, - upper = model$upper, - lower = model$lower) - - if (is(model, 'survfit.cox')) { + + if (inherits(model, 'survfit.cox')) { d <- cbind_wraps(d, data.frame(surv = model$surv, cumhaz = model$cumhaz)) - } else if (is(model, 'survfit')) { - if (is(model, 'survfitms')) { - d <- cbind_wraps(d, data.frame(pstate = model$pstate)) - - varying.names <- c('n.risk', 'n.event', 'pstate', 'std.err', 'upper', 'lower') - varying.i <- lapply(varying.names, function(x) which(startsWith(colnames(d), x))) - d <- reshape(d, varying = varying.i, v.names = varying.names, timevar = NULL, direction = 'long') - d <- suppressWarnings(subset(d, select = -c(id))) - rownames(d) <- NULL - + } else if (inherits(model, 'survfit')) { + if (inherits(model, 'survfitms')) { if (length(model$states) > 1) { ev.names <- model$states ev.names[ev.names == ''] <- 'any' @@ -66,7 +65,12 @@ fortify.survfit <- function(model, data = NULL, surv.connect = FALSE, # connect to the origin for plotting if (surv.connect) { - base <- d[1, ] + if ('strata' %in% colnames(d)) { + base <- d[d$time == ave(d$time, d$strata, FUN = min), ] + } + if ('event' %in% colnames(d)) { + base <- d[d$time == ave(d$time, d$event, FUN = min), ] + } # cumhaz is for survfit.cox cases base[intersect(c('time', 'n.event', 'n.censor', 'std.err', 'cumhaz'), colnames(base))] <- 0 if ('pstate' %in% colnames(d)) { @@ -74,21 +78,10 @@ fortify.survfit <- function(model, data = NULL, surv.connect = FALSE, } else { base[c('surv', 'upper', 'lower')] <- 1.0 } - if ('strata' %in% colnames(d)) { - strata <- levels(d$strata) - base <- base[rep(seq_len(nrow(base)), length(strata)), ] - rownames(base) <- NULL - base$strata <- strata - base$strata <- factor(base$strata, levels = base$strata) - } if ('event' %in% colnames(d)) { - events <- levels(d$event) - base <- base[rep(seq_len(nrow(base)), length(events)), ] - rownames(base) <- NULL - base$event <- events - base$event <- factor(base$event, levels = events) - base[base$event == 'any', c('pstate', 'upper', 'lower')] <- 1.0 + base[base$event == 'any', c('pstate', 'upper', 'lower')] <- 1.0 } + rownames(base) <- NULL d <- rbind(base, d) } diff --git a/tests/testthat/test-surv.R b/tests/testthat/test-surv.R index 99c43af..0e53c3b 100644 --- a/tests/testthat/test-surv.R +++ b/tests/testthat/test-surv.R @@ -260,3 +260,21 @@ test_that('fortify.survfit regular expression for renaming strata works with mul 'std.err', 'upper', 'lower', 'strata') expect_equal(names(fortified), expected_names) }) + +test_that('n.risk at time == 0 is correct in fortify.survfit(*, surv.connect = TRUE) (#229)', { + skip_if_not_installed("survival") + library(survival) + + fit <- survfit(Surv(time, status) ~ x, data = aml) + fit_surv <- summary(fit) + fit_surv <- fit_surv$n.risk[fit_surv$time == ave(fit_surv$time, fit_surv$strata, FUN = min)] + fit_gg <- fortify.survfit(fit, surv.connect = TRUE) + fit_gg <- fit_gg[fit_gg$time == 0, "n.risk"] + expect_equal(fit_surv, fit_gg) + + fitMS <- survfit(Surv(start, stop, event) ~ 1, id = id, data = mgus1) + fitMS_surv <- unname(fitMS$n.risk[1, ]) + fitMS_gg <- fortify.survfit(fitMS, surv.connect = TRUE) + fitMS_gg <- fitMS_gg[fitMS_gg$time == 0, "n.risk"] + expect_equal(fitMS_surv, fitMS_gg) +})