diff --git a/R/stratifyCohorts.R b/R/stratifyCohorts.R index db5a0b26..5a98e425 100644 --- a/R/stratifyCohorts.R +++ b/R/stratifyCohorts.R @@ -49,36 +49,31 @@ stratifyCohorts <- function(cohort, cdm <- omopgenerics::cdmReference(cohort) if (length(strata) == 0 | sum(cohortCount(cohort)$number_records) == 0) { - if (identical(name, tableName(cohort))) { - return(cohort) - } else { - return( - cohort |> - dplyr::compute(name = name, temporary = FALSE) |> - omopgenerics::newCohortTable(.softValidation = TRUE) - ) - } + return( + subsetCohorts(cohort = cohort, cohortId = cohortId, name = name) + ) } strataCols <- unique(unlist(strata)) set <- settings(cohort) |> - dplyr::filter(.data$cohort_definition_id %in% .env$cohortId) |> - dplyr::mutate("target_cohort_table_name" = tableName(cohort)) |> - dplyr::rename( - "target_cohort_id" = "cohort_definition_id", - "target_cohort_name" = "cohort_name" - ) - + dplyr::filter(.data$cohort_definition_id %in% .env$cohortId) # drop columns from set - dropCols <- colnames(set)[colnames(set) %in% strataCols] + dropCols <- colnames(set)[colnames(set) %in% c( + strataCols, "target_cohort_id", "target_cohort_name", "target_cohort_table_name", "strata_columns")] if (length(dropCols) > 0) { cli::cli_inform(c( - "!" = "{dropCols} {?is/are} present in settings and strata. Settings - column will be not considered." + "!" = "{dropCols} {?is/are} will be overwritten in settings." )) set <- set |> dplyr::select(!dplyr::all_of(dropCols)) } + set <- set |> + dplyr::mutate("target_cohort_table_name" = tableName(cohort)) |> + dplyr::rename( + "target_cohort_id" = "cohort_definition_id", + "target_cohort_name" = "cohort_name" + ) + # get counts for attrition counts <- cohort |> diff --git a/R/subsetCohorts.R b/R/subsetCohorts.R index 2c8efac3..0c0cdd6f 100644 --- a/R/subsetCohorts.R +++ b/R/subsetCohorts.R @@ -29,7 +29,7 @@ subsetCohorts <- function(cohort, minCohortCount = 0, name = tableName(cohort)) { # checks - cohort <- validateCohortTable(cohort, TRUE) + cohort <- validateCohortTable(cohort) cohortId <- validateCohortId(cohortId, settings(cohort)$cohort_definition_id) name <- validateName(name) minCohortCount <- validateN(minCohortCount) diff --git a/tests/testthat/test-stratifyCohorts.R b/tests/testthat/test-stratifyCohorts.R index 1d76f40d..623cc172 100644 --- a/tests/testthat/test-stratifyCohorts.R +++ b/tests/testthat/test-stratifyCohorts.R @@ -49,6 +49,7 @@ test_that("simple stratification", { cdm$new_cohort <- cdm$cohort1 |> stratifyCohorts( strata = list(c("blood_type", "age_group"), "sex"), + removeStrata = FALSE, name = "new_cohort" ) ) @@ -73,6 +74,32 @@ test_that("simple stratification", { expect_true(all(attritionCdi$number_subjects == c(2, 2, 1, 2, 1, 0, 1, 0, 0))) expect_true(all(attritionCdi$excluded_records == c(0, 1, 1, 0, 1, 1, 0, 1, 0))) expect_true(all(attritionCdi$excluded_subjects == c(0, 0, 1, 0, 1, 1, 0, 1, 0))) + expect_equal( + colnames(cdm$new_cohort), + c('cohort_definition_id', 'subject_id', 'cohort_start_date', 'cohort_end_date', + 'extra_column', 'blood_type', 'sex', 'age_group') + ) + + # test settings drop columns + expect_message( + cdm$new_cohort2 <- cdm$new_cohort |> + stratifyCohorts( + strata = list(c("blood_type", "age_group"), "sex"), + name = "new_cohort2" + ) + ) + expect_equal( + colnames(cdm$new_cohort2), + c('cohort_definition_id', 'subject_id', 'cohort_start_date', 'cohort_end_date', 'extra_column') + ) + + cdm$new_cohort3 <- cdm$new_cohort |> + stratifyCohorts( + cohortId = 1, + strata = list(), + name = "new_cohort3" + ) + expect_equal(collectCohort(cdm$new_cohort2, 1), collectCohort(cdm$new_cohort3, 1)) # empty cohort cdm <- omopgenerics::emptyCohortTable(cdm, "empty_cohort") diff --git a/tests/testthat/test-subsetCohorts.R b/tests/testthat/test-subsetCohorts.R index e672815f..2ffeb67c 100644 --- a/tests/testthat/test-subsetCohorts.R +++ b/tests/testthat/test-subsetCohorts.R @@ -187,8 +187,11 @@ test_that("Testing minCohortCount argument", { cdm$sub2 <- cdm$cohort1 |> subsetCohorts(cohortId = 4, name = "sub2") expect_equal(settings(cdm$sub2), dplyr::tibble(cohort_definition_id = 4, cohort_name = "cohort_4")) - cdm$sub3 <- cdm$cohort1 |> subsetCohorts(cohortId = 4, minCohortCount = 1, name = "sub3") + cdm$sub3 <- cdm$cohort1 |> + dplyr::mutate(extra_col = 1) |> + subsetCohorts(cohortId = 4, minCohortCount = 1, name = "sub3") expect_true(nrow(settings(cdm$sub3)) == 0) + expect_true("extra_col" %in% colnames(cdm$sub3)) PatientProfiles::mockDisconnect(cdm) })