diff --git a/R/check_dag.R b/R/check_dag.R index 9bc9bcd8c..3d0ae3c2f 100644 --- a/R/check_dag.R +++ b/R/check_dag.R @@ -312,15 +312,13 @@ check_dag <- function(..., adjustment_set <- unlist(dagitty::adjustmentSets(dag, effect = x), use.names = FALSE) adjustment_nodes <- unlist(dagitty::adjustedNodes(dag), use.names = FALSE) minimal_adjustments <- as.list(dagitty::adjustmentSets(dag, effect = x)) - collider <- adjustment_nodes[vapply(adjustment_nodes, ggdag::is_collider, logical(1), .dag = dag, downstream = FALSE)] - if (!length(collider)) { + collider <- adjustment_nodes[vapply(adjustment_nodes, ggdag::is_collider, logical(1), .dag = dag, downstream = FALSE)] # nolint + if (length(collider)) { + # if we *have* colliders, remove them from minimal adjustments + minimal_adjustments <- lapply(minimal_adjustments, setdiff, y = collider) + } else { # if we don't have colliders, set to NULL collider <- NULL - } else { - # if we *have* colliders, remove them from minimal adjustments - minimal_adjustments <- lapply(minimal_adjustments, function(ma) { - setdiff(ma, collider) - }) } list( # no adjustment needed when @@ -330,7 +328,7 @@ check_dag <- function(..., # incorrect adjustment when # - required is NULL and current adjustment not NULL # - OR we have a collider in current adjustments - incorrectly_adjusted = (is.null(adjustment_set) && !is.null(adjustment_nodes)) || (!is.null(collider) && collider %in% adjustment_nodes), + incorrectly_adjusted = (is.null(adjustment_set) && !is.null(adjustment_nodes)) || (!is.null(collider) && collider %in% adjustment_nodes), # nolint current_adjustments = adjustment_nodes, minimal_adjustments = minimal_adjustments, collider = collider