diff --git a/R/conceptCohort.R b/R/conceptCohort.R index a3037e0..b0ba082 100644 --- a/R/conceptCohort.R +++ b/R/conceptCohort.R @@ -74,6 +74,11 @@ conceptCohort <- function(cdm, omopgenerics::assertChoice(exit, c("event_start_date", "event_end_date")) omopgenerics::assertChoice(overlap, c("merge", "extend"), length = 1) omopgenerics::assertLogical(useSourceFields, length = 1) + omopgenerics::assertCharacter(subsetCohort, length = 1, null = TRUE) + if (!is.null(subsetCohort)) { + subsetCohort <- omopgenerics::validateCohortArgument(cdm[[subsetCohort]]) + subsetCohortId <- omopgenerics::validateCohortIdArgument({{subsetCohortId}}, subsetCohort) + } useIndexes <- getOption("CohortConstructor.use_indexes") @@ -108,19 +113,13 @@ conceptCohort <- function(cdm, # subsetCohort if (!is.null(subsetCohort)) { - subsetCohort <- omopgenerics::validateCohortArgument(subsetCohort) - subsetCohortId <- omopgenerics::validateCohortIdArgument(subsetCohortId, cohort = subsetCohort) subsetName <- omopgenerics::uniqueTableName(prefix = tmpPref) - if (!all(settings(subsetCohort)$cohort_definition_id %in% subsetCohortId)) { - subsetCohort <- subsetCohort |> - dplyr::filter(.data$cohort_definition_id %in% .env$subsetCohortId) |> - dplyr::compute(name = subsetName, temporary = FALSE) - } subsetIndividuals <- subsetCohort |> + dplyr::filter(.data$cohort_definition_id %in% .env$subsetCohortId) |> dplyr::distinct(.data$subject_id) |> dplyr::compute(name = subsetName, temporary = FALSE) - if (subsetIndividuals |> dplyr::tally() |> dplyr::pull("n") == 0) { - omopgenerics::dropTable(cdm = cdm, name = dplyr::starts_with(tmpPref)) + if (omopgenerics::isTableEmpty(subsetIndividuals)) { + omopgenerics::dropTable(cdm = cdm, name = subsetName) cli::cli_abort("There are no individuals in the `subsetCohort` and `subsetCohortId` provided.") } if (!isFALSE(useIndexes)) { diff --git a/tests/testthat/test-conceptCohort.R b/tests/testthat/test-conceptCohort.R index fd618db..2464f23 100644 --- a/tests/testthat/test-conceptCohort.R +++ b/tests/testthat/test-conceptCohort.R @@ -929,4 +929,83 @@ test_that("test indexes - postgres", { CDMConnector::cdm_disconnect(cdm = cdm) }) +test_that("test subsetCohort arguments", { + cdm <- omock::mockCdmFromTables( + tables = list( + condition_occurrence = dplyr::tibble( + condition_occurrence_id = 1:3L, + person_id = c(1L, 2L, 3L), + condition_concept_id = 194152L, + condition_start_date = as.Date("2020-01-01"), + condition_end_date = as.Date("2020-01-01"), + condition_type_concept_id = 0L + ), + cohort = dplyr::tibble( + subject_id = c(1L, 2L), + cohort_definition_id = c(1L, 2L), + cohort_start_date = as.Date("2010-01-01"), + cohort_end_date = as.Date("2010-01-01") + ) + ) + ) + + cdm <- CDMConnector::copyCdmTo(con = duckdb::dbConnect(duckdb::duckdb()), cdm = cdm, schema = "main") + expect_no_error( + x <- conceptCohort( + cdm = cdm, + conceptSet = list(test = 194152L), + name = "test" + ) + ) + expect_true(all(c(1L, 2L, 3L) %in% dplyr::pull(x, "subject_id"))) + + expect_no_error( + x <- conceptCohort( + cdm = cdm, + conceptSet = list(test = 194152L), + name = "test", + subsetCohort = "cohort" + ) + ) + expect_true(all(c(1L, 2L) %in% dplyr::pull(x, "subject_id"))) + expect_true(all(!c(3L) %in% dplyr::pull(x, "subject_id"))) + + expect_no_error( + x <- conceptCohort( + cdm = cdm, + conceptSet = list(test = 194152L), + name = "test", + subsetCohort = "cohort", + subsetCohortId = 1L + ) + ) + expect_true(all(c(1L) %in% dplyr::pull(x, "subject_id"))) + expect_true(all(!c(2L, 3L) %in% dplyr::pull(x, "subject_id"))) + + expect_no_error( + x <- conceptCohort( + cdm = cdm, + conceptSet = list(test = 194152L), + name = "test", + subsetCohort = "cohort", + subsetCohortId = "cohort_1" + ) + ) + expect_true(all(c(1L) %in% dplyr::pull(x, "subject_id"))) + expect_true(all(!c(2L, 3L) %in% dplyr::pull(x, "subject_id"))) + + expect_no_error( + x <- conceptCohort( + cdm = cdm, + conceptSet = list(test = 194152L), + name = "test", + subsetCohort = "cohort", + subsetCohortId = dplyr::starts_with("cohort") + ) + ) + expect_true(all(c(1L, 2L) %in% dplyr::pull(x, "subject_id"))) + expect_true(all(!c(3L) %in% dplyr::pull(x, "subject_id"))) + + CDMConnector::cdmDisconnect(cdm = cdm) +})