Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add predicates for finding good and bad controls #104

Draft
wants to merge 18 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
11 changes: 11 additions & 0 deletions src/y0/algorithm/identify/id_std.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,17 @@ def identify(identification: Identification) -> Expression:
:returns: the expression corresponding to the identification
:raises Unidentifiable: If no appropriate identification can be found

If you have an instance of a :class:`y0.graph.NxMixedGraph` and a
query as an instance of a :class:`y0.dsl.Probability`, use the following:

.. code-block:: python

from y0.algorithm.identify import identify, Identification
graph = ...
query = ...
identification = Identification.from_expression(graph=graph, query=query)
estimand = identify(identification)

See also :func:`identify_outcomes` for a more idiomatic way of running
the ID algorithm given a graph, treatments, and outcomes.
"""
Expand Down
158 changes: 158 additions & 0 deletions src/y0/controls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,158 @@
"""Predicates for good, bad, and neutral controls from [cinelli2022]_.

.. [cinelli2022] `A Crash Course in Good and Bad Controls <https://ftp.cs.ucla.edu/pub/stat_ser/r493.pdf>`_
"""

from .algorithm.conditional_independencies import are_d_separated
from .algorithm.identify import identify_outcomes
from .dsl import Expression, Fraction, One, Probability, Product, Sum, Variable, Zero
from .graph import NxMixedGraph

# Names exported by ``from y0.controls import *``: the public control predicates
# defined in this module (underscore-prefixed helpers are intentionally omitted).
__all__ = [
"is_bad_control",
"is_good_control",
"is_outcome_ancestor",
"is_middle_mediator",
]


def _control_precondition(
    graph: NxMixedGraph, cause: Variable, effect: Variable, variable: Variable
) -> None:
    """Validate that all variables of a control query are nodes of the graph.

    :param graph: An ADMG
    :param cause: The intervention in the causal query
    :param effect: The outcome of the causal query
    :param variable: The candidate control variable to check
    :raises ValueError: If the cause, the effect, or the candidate control
        variable is not a node of the graph. The message names the variable
        that is actually missing.
    """
    nodes = graph.nodes()
    # Each message interpolates the variable that failed its own check, so the
    # error points at the missing node rather than at the candidate control.
    if cause not in nodes:
        raise ValueError(f"Cause variable missing: {cause}")
    if effect not in nodes:
        raise ValueError(f"Effect variable missing: {effect}")
    if variable not in nodes:
        raise ValueError(f"Test variable missing: {variable}")
    # TODO does this need to be extended to check that the
    #  query and variable aren't counterfactual?


def is_bad_control(
    graph: NxMixedGraph, cause: Variable, effect: Variable, variable: Variable
) -> bool:
    """Decide whether the variable is a bad control for the given query.

    The variable is reported as a bad control when
    :func:`y0.algorithm.identify.identify_outcomes` returns no estimand
    (``None``) for the cause/effect pair, or when the variable does not
    occur anywhere in the estimand it returns.

    :param graph: An ADMG
    :param cause: The intervention in the causal query
    :param effect: The outcome of the causal query
    :param variable: The variable to check
    :return: If the variable is a bad control
    :raises ValueError: If the cause, effect, or variable is not in the graph
    """
    _control_precondition(graph, cause, effect, variable)
    estimand = identify_outcomes(graph, cause, effect)
    if estimand is None:
        # No estimand was produced, so the variable cannot appear in one.
        return True
    return not _in_expression(variable, estimand)


def _in_expression(var: Variable, expr: Expression) -> bool:
if isinstance(expr, Variable):
return expr == var
elif isinstance(expr, Probability):
return var in expr.children or var in expr.parents
elif isinstance(expr, Fraction):
return _in_expression(var, expr.numerator) or _in_expression(var, expr.denominator)
elif isinstance(expr, Zero | One):
return False
elif isinstance(expr, Sum):
return _in_expression(var, expr.expression)
elif isinstance(expr, Product):
return any(_in_expression(var, subexpr) for subexpr in expr.expressions)
raise NotImplementedError


def is_good_control(
    graph: NxMixedGraph, cause: Variable, effect: Variable, variable: Variable
) -> bool:
    """Decide whether the variable is a good control for the given query.

    Only the input validation is in place so far; the predicate itself
    has not been implemented.

    :param graph: An ADMG
    :param cause: The intervention in the causal query
    :param effect: The outcome of the causal query
    :param variable: The variable to check
    :return: If the variable is a good control
    :raises ValueError: If the cause, effect, or variable is not in the graph
    :raises NotImplementedError: Whenever the inputs pass validation, since
        the check itself is not implemented yet
    """
    # Validate first so that bad input is reported as a ValueError.
    _control_precondition(graph, cause, effect, variable)
    raise NotImplementedError


def is_outcome_ancestor(
    graph: NxMixedGraph, cause: Variable, effect: Variable, variable: Variable
) -> bool:
    """Check if the variable is an outcome ancestor given a causal query and graph.

    The variable qualifies when it differs from the cause, is d-separated
    from the cause (with no conditioning), and belongs to the inclusive
    ancestors of the effect. This corresponds to Model 8 of [cinelli2022]_:

        In Model 8, Z is not a confounder nor does it block any back-door paths.
        Likewise, controlling for Z does not open any back-door paths from X to Y.
        Thus, in terms of asymptotic bias, Z is a "neutral control." Analysis
        shows, however, that controlling for Z reduces the variation of the
        outcome variable Y, and helps to improve the precision of the ACE
        estimate in finite samples (Hahn, 2004; White and Lu, 2011; Henckel
        et al., 2019; Rotnitzky and Smucler, 2019).

    :param graph: An ADMG
    :param cause: The intervention in the causal query
    :param effect: The outcome of the causal query
    :param variable: The variable to check
    :return: If the variable is an outcome ancestor
    """
    if variable == cause:
        return False
    if not are_d_separated(graph, cause, variable).separated:
        return False
    return variable in graph.ancestors_inclusive(effect)


def is_middle_mediator(graph: NxMixedGraph, x: Variable, y: Variable, z: Variable) -> bool:
    """Check if the variable Z is a middle mediator.

    This corresponds to Model 13 of [cinelli2022]_:

        At first look, Model 13 might seem similar to Model 12, and one may think
        that adjusting for Z would bias the effect estimate, by restricting
        variations of the mediator M. However, the key difference here is that Z
        is a cause, not an effect, of the mediator (and, consequently, also a
        cause of Y). Thus, Model 13 is analogous to Model 8, and so controlling
        for Z will be neutral in terms of bias and may increase the precision of
        the ACE estimate in finite samples. Readers can find further discussion
        of this case in Pearl (2013).

    Criterion (from Jeremy): Z is a middle mediator if there exists an M such
    that X is an ancestor of M, M is an ancestor of Y, Z is an ancestor of M,
    Z _|_ Y | M, and Z _|_ X.

    The implementation is deliberately naive: every node of the graph is tried
    as the candidate mediator M, and the criterion is evaluated for each one
    using :func:`y0.algorithm.conditional_independencies.are_d_separated`.

    :param graph: An ADMG
    :param x: The cause/intervention
    :param y: The effect/outcome
    :param z: The variable to check
    :return: If some node of the graph satisfies the criterion as mediator
    :raises ValueError: If x, y, or z is not in the graph
    """
    _control_precondition(graph, x, y, z)
    for candidate in graph:
        if _middle_mediator_helper(graph, x, y, z, candidate):
            return True
    return False


def _middle_mediator_helper(
    graph: NxMixedGraph, x: Variable, y: Variable, z: Variable, m: Variable
) -> bool:
    """Check whether ``m`` witnesses that ``z`` is a middle mediator control.

    All of the following must hold: X is an ancestor of M, M is an ancestor
    of Y, Z is an ancestor of M, Z _|_ X, and Z _|_ Y | M.

    :param graph: An ADMG
    :param x: The cause/intervention
    :param y: The effect/outcome
    :param z: The variable we're checking if it's a good control
    :param m: The candidate mediator
    :return: If every condition above holds for this candidate mediator
    """
    # ``is_ancestor_of`` is inclusive, so the query variables themselves would
    # pass the ancestor checks. A mediator must be a distinct variable, and
    # conditioning on an endpoint of the Y/Z query is not a well-formed test.
    if m in (x, y, z):
        return False
    # Cheap structural checks first: most candidates are rejected here without
    # running any d-separation query.
    if not (
        graph.is_ancestor_of(x, m)
        and graph.is_ancestor_of(m, y)
        and graph.is_ancestor_of(z, m)
    ):
        return False
    # Read ``.separated`` explicitly so a real bool is returned instead of
    # relying on the truthiness of the judgement object.
    if not are_d_separated(graph, x, z).separated:
        return False
    return are_d_separated(graph, y, z, conditions=[m]).separated
4 changes: 4 additions & 0 deletions src/y0/graph.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,6 +560,10 @@ def ancestors_inclusive(self, sources: Variable | Iterable[Variable]) -> set[Var
sources = _ensure_set(sources)
return _ancestors_inclusive(self.directed, sources)

def is_ancestor_of(self, ancestor: Variable, descendant: Variable) -> bool:
    """Check if one variable is an ancestor of another.

    Membership is tested against :meth:`ancestors_inclusive` of the
    descendant, so the check is inclusive (presumably a variable counts
    as an ancestor of itself, as the name of that method suggests).

    :param ancestor: The variable that may be an ancestor
    :param descendant: The variable whose ancestors are looked up
    :return: If the first variable is among the inclusive ancestors of the second
    """
    inclusive_ancestors = self.ancestors_inclusive(descendant)
    return ancestor in inclusive_ancestors

def descendants_inclusive(self, sources: Variable | Iterable[Variable]) -> set[Variable]:
"""Descendants of a set include the set itself."""
sources = _ensure_set(sources)
Expand Down
115 changes: 115 additions & 0 deletions tests/test_controls.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,115 @@
"""Tests for good, bad, and neutral controls."""

import unittest

from y0.controls import (
is_bad_control,
is_good_control,
is_middle_mediator,
is_outcome_ancestor,
)
from y0.dsl import U1, U2, M, U, W, X, Y, Z, Variable
from y0.graph import NxMixedGraph

# Graphs in which Z is expected to be a good control for the effect of X on Y:
# the tests assert is_good_control(...) is true and is_bad_control(...) is
# false for every one of them. Each graph is built from directed edges only
# (presumably (source, target) pairs -- confirm against NxMixedGraph.from_edges).
model_1 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, Y), (X, Y)])
model_2 = NxMixedGraph.from_edges(directed=[(U, Z), (Z, X), (X, Y), (U, Y)])
model_3 = NxMixedGraph.from_edges(directed=[(U, X), (U, Z), (Z, Y), (X, Y)])
model_4 = NxMixedGraph.from_edges(directed=[(Z, X), (Z, M), (X, M), (M, Y)])
model_5 = NxMixedGraph.from_edges(directed=[(U, Z), (Z, X), (U, M), (X, M), (M, Y)])
model_6 = NxMixedGraph.from_edges(directed=[(U, X), (U, Z), (Z, M), (X, M), (M, Y)])

# (label, graph) pairs; the label is used as the subTest name in the tests.
good_test_models = [
("good model 1", model_1),
("good model 2", model_2),
("good model 3", model_3),
("good model 4", model_4),
("good model 5", model_5),
("good model 6", model_6),
]

# Graphs in which Z is expected to be a bad control for the effect of X on Y:
# the tests assert is_bad_control(...) is true and is_good_control(...) is
# false for every one of them.

# bad control, M-bias
model_7 = NxMixedGraph.from_edges(directed=[(U1, Z), (U2, Z), (U1, X), (U2, Y), (X, Y)])
# bad control, Bias amplification
model_10 = NxMixedGraph.from_edges(directed=[(Z, X), (U, X), (U, Y), (X, Y)])
# bad control: in models 11 and 11 (variation) Z sits on the directed path
# X -> Z -> Y, and in model 12 Z is a child of the mediator M
model_11 = NxMixedGraph.from_edges(directed=[(X, Z), (Z, Y)])
model_11_variation = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (Z, Y), (U, Y)])
model_12 = NxMixedGraph.from_edges(directed=[(X, M), (M, Y), (M, Z)])
# bad control, Selection bias
model_16 = NxMixedGraph.from_edges(directed=[(X, Z), (U, Z), (U, Y), (X, Y)])
model_17 = NxMixedGraph.from_edges(directed=[(X, Z), (Y, Z), (X, Y)])
# bad control, case-control bias
model_18 = NxMixedGraph.from_edges(directed=[(X, Y), (Y, Z)])

# (label, graph) pairs; the label is used as the subTest name in the tests.
bad_test_models = [
("bad model 7", model_7),
("bad model 10", model_10),
("bad model 11", model_11),
("bad model 11 (variation)", model_11_variation),
("bad model 12", model_12),
("bad model 16", model_16),
("bad model 17", model_17),
("bad model 18", model_18),
]

# Graphs in which Z is expected to be a neutral control for the effect of X on
# Y: the tests assert that both is_good_control(...) and is_bad_control(...)
# are false for every one of them.

# neutral control, possibly good for precision
model_8 = NxMixedGraph.from_edges(directed=[(X, Y), (Z, Y)])
# neutral control, possibly bad for precision
model_9 = NxMixedGraph.from_edges(directed=[(Z, X), (X, Y)])
# neutral control, possibly good for precision
# (here W is the mediator between X and Y, and Z is a second parent of W)
model_13 = NxMixedGraph.from_edges(directed=[(X, W), (Z, W), (W, Y)])
# neutral control, possibly helpful in the case of selection bias
model_14 = NxMixedGraph.from_edges(directed=[(X, Y), (X, Z)])
model_15 = NxMixedGraph.from_edges(directed=[(X, Z), (Z, W), (X, Y), (U, W), (U, Y)])

# (label, graph) pairs; the label is used as the subTest name in the tests.
neutral_test_models = [
("neutral model 8", model_8),
("neutral model 9", model_9),
("neutral model 13", model_13),
("neutral model 14", model_14),
("neutral model 15", model_15),
]


class TestControls(unittest.TestCase):
    """Test case for good, bad, and neutral controls."""

    def test_preconditions(self):
        """Test that a variable missing from the graph is rejected."""
        missing = Variable("nopenopennope")
        for predicate in (is_good_control, is_bad_control):
            with self.subTest(name=predicate.__name__), self.assertRaises(ValueError):
                predicate(model_1, X, Y, missing)

    def test_good_controls(self):
        """Test that only the good models are reported as good controls."""
        # Pair every model with the assertion it must satisfy, keeping the
        # positive cases first.
        cases = [(label, graph, self.assertTrue) for label, graph in good_test_models]
        cases += [
            (label, graph, self.assertFalse)
            for label, graph in bad_test_models + neutral_test_models
        ]
        for label, graph, check in cases:
            with self.subTest(name=label):
                check(is_good_control(graph, X, Y, Z))

    def test_bad_controls(self):
        """Test that only the bad models are reported as bad controls."""
        cases = [
            (label, graph, self.assertFalse)
            for label, graph in good_test_models + neutral_test_models
        ]
        cases += [(label, graph, self.assertTrue) for label, graph in bad_test_models]
        for label, graph, check in cases:
            with self.subTest(name=label):
                check(is_bad_control(graph, X, Y, Z))

    def test_outcome_ancestors(self):
        """Test that model 8 is an outcome ancestor and good/bad models are not."""
        self.assertTrue(is_outcome_ancestor(model_8, X, Y, Z))
        negatives = good_test_models + bad_test_models
        for label, graph in negatives:
            with self.subTest(name=label):
                self.assertFalse(is_outcome_ancestor(graph, X, Y, Z))

    def test_middle_mediators(self):
        """Test that model 13 is a middle mediator and good/bad models are not."""
        self.assertTrue(is_middle_mediator(model_13, X, Y, Z))
        negatives = good_test_models + bad_test_models
        for label, graph in negatives:
            with self.subTest(name=label):
                self.assertFalse(is_middle_mediator(graph, X, Y, Z))
Loading