-
Notifications
You must be signed in to change notification settings - Fork 10
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Closes #120 This PR adds a high-level interface and implements tests for sigma-separation, a generalization of d-separation that works not only for directed acyclic graphs, but also for directed graphs containing cycles. It was originally introduced in > Constraint-based Causal Discovery for Non-Linear Structural Causal Models with Cycles and Latent Confounders > Forré and Mooij. 2019. [arXiv:1807.03024](https://arxiv.org/abs/1807.03024) and is an integral part of cyclic ID algorithm (see #71) and the gID algorithm (see #72) ## References/Notes - https://stats.stackexchange.com/questions/586810/sigma-separation-question-in-cyclic-causal-graph-understanding-sigma-separatio - Author's implementation: https://github.com/caus-am/sigmasep
- Loading branch information
Showing
2 changed files
with
394 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,271 @@ | ||
"""Implementation of sigma-separation.""" | ||
|
||
from typing import Iterable, Optional, Sequence | ||
|
||
import networkx as nx | ||
from more_itertools import triplewise | ||
|
||
from y0.dsl import Variable | ||
from y0.graph import NxMixedGraph | ||
|
||
__all__ = [ | ||
"are_sigma_separated", | ||
"is_z_sigma_open", | ||
"get_equivalence_classes", | ||
] | ||
|
||
|
||
def are_sigma_separated( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
right: Variable, | ||
*, | ||
conditions: Optional[Iterable[Variable]] = None, | ||
cutoff: Optional[int] = None, | ||
) -> bool: | ||
"""Test if two variables are sigma-separated. | ||
Sigma separation is a generalization of d-separation that | ||
works not only for directed acyclic graphs, but also for | ||
directed graphs containing cycles. It was originally introduced | ||
in https://arxiv.org/abs/1807.03024. | ||
We say that X and Y are σ-connected by Z or not | ||
σ-separated by Z if there exists a path π (with some | ||
n ≥ 1 nodes) in G with one endnode in X and | ||
one endnode in Y that is Z-σ-open. σ-separated is the | ||
opposite of σ-connected (logical not). | ||
:param graph: Graph to test | ||
:param left: A node in the graph | ||
:param right: A node in the graph | ||
:param conditions: A collection of graph nodes | ||
:param cutoff: The maximum path length to check. By default, is unbounded. | ||
:return: If a and b are sigma-separated. | ||
""" | ||
if conditions is None: | ||
conditions = set() | ||
else: | ||
conditions = set(conditions) | ||
|
||
sigma = get_equivalence_classes(graph) | ||
return not any( | ||
is_z_sigma_open(graph, path, conditions=conditions, sigma=sigma) | ||
# Technically, this algorithm should generate all paths, which could include | ||
# repeat visits to nodes and edges, but this is computationally intractable, | ||
# so the is_z_sigma_open() subroutine contains a novel path augmentation | ||
# algorithm. This might not be officially complete. | ||
for path in nx.all_simple_paths(graph.disorient(), left, right, cutoff=cutoff) | ||
) | ||
|
||
|
||
def is_z_sigma_open( | ||
graph: NxMixedGraph, | ||
path: Sequence[Variable], | ||
*, | ||
sigma: dict[Variable, set[Variable]], | ||
conditions: Optional[set[Variable]] = None, | ||
) -> bool: | ||
r"""Check if a path is Z-sigma-open. | ||
:param graph: A mixed graph | ||
:param path: A path in the graph. Denoted as $\pi$ in the paper. The | ||
node in position $i$ in the path is denoted with $v_i$. | ||
:param conditions : A set of nodes chosen as conditions, denoted by $Z$ in the paper | ||
:param sigma: The set of equivalence classes. Can be calculated with | ||
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. | ||
:returns: If the path is Z-sigma-open | ||
A path is $Z-\sigma-\text{open}$ if: | ||
1. The end nodes $v_1, v_n \notin Z$ | ||
2. Every triple of adjacent nodes in the path is of the form: | ||
1. Collider (:func:`is_collider`) | ||
2. (non-collider) left chain (:func:`is_non_collider_left_chain`) | ||
3. (non-collider) right chain (:func:`is_non_collider_left_chain`) | ||
4. (non-collider) fork (:func:`is_non_collider_fork`) | ||
5. (non-collider) with undirected edge (:func:`is_non_collider_undirected`, not implemented) | ||
""" | ||
if conditions is None: | ||
conditions = set() | ||
if path[0] in conditions or path[-1] in conditions: | ||
return False | ||
return all( | ||
_triple_has_correct_form(graph, left, middle, right, conditions, sigma) | ||
for left, middle, right in triplewise(path) | ||
) | ||
|
||
|
||
def _triple_has_correct_form( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
sigma: dict[Variable, set[Variable]], | ||
) -> bool: | ||
if _triple_helper(graph, left, middle, right, conditions, sigma): | ||
return True | ||
# augment with backtracks, since you're allowed to go back (just like Season 5 of Lost). | ||
# this is a better solution than generating infinite paths, but might still be mathematically | ||
# incomplete. In this setup, 𝑣3→𝑣4↔𝑣6 becomes 𝑣3→𝑣4→𝑣5←𝑣4↔𝑣6 to get some sweet backtrack paths | ||
# through the middle node to a neighbor and then back before going to the right node. | ||
neighbors = {n for n in graph.disorient().neighbors(middle) if n != middle} | ||
for neighbor in neighbors: | ||
if ( | ||
_triple_helper(graph, left, middle, neighbor, conditions, sigma) | ||
and _triple_helper(graph, middle, neighbor, middle, conditions, sigma) | ||
and _triple_helper(graph, neighbor, middle, right, conditions, sigma) | ||
): | ||
return True | ||
return False | ||
|
||
|
||
def _triple_helper( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
sigma: dict[Variable, set[Variable]], | ||
) -> bool: | ||
return ( | ||
is_collider(graph, left, middle, right, conditions) | ||
or is_non_collider_left_chain(graph, left, middle, right, conditions, sigma) | ||
or is_non_collider_right_chain(graph, left, middle, right, conditions, sigma) | ||
or is_non_collider_fork(graph, left, middle, right, conditions, sigma) | ||
) | ||
|
||
|
||
def _has_either_edge(graph: NxMixedGraph, u, v) -> bool: | ||
return graph.directed.has_edge(u, v) or graph.undirected.has_edge(u, v) | ||
|
||
|
||
def _only_directed_edge(graph, u, v) -> bool: | ||
return graph.directed.has_edge(u, v) and not graph.undirected.has_edge(u, v) | ||
|
||
|
||
def is_collider( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
) -> bool: | ||
"""Check if three nodes form a collider under the given conditions. | ||
:param graph: A mixed graph | ||
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper | ||
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper | ||
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper | ||
:param conditions: The conditional variables, denoted as $Z$ in the paper | ||
:return: If the three nodes form a collider | ||
""" | ||
return ( | ||
_has_either_edge(graph, left, middle) | ||
and _has_either_edge(graph, right, middle) | ||
and middle in conditions | ||
) | ||
|
||
|
||
def is_non_collider_left_chain( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
sigma: dict[Variable, set[Variable]], | ||
) -> bool: | ||
r"""Check if three nodes form a non-collider (left chain) given the conditions. | ||
:param graph: A mixed graph | ||
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper | ||
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper | ||
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper | ||
:param conditions: The conditional variables, denoted as $Z$ in the paper | ||
:param sigma: The set of equivalence classes. Can be calculated with | ||
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. | ||
:return: If the three nodes form a non-collider (left chain) given the conditions. | ||
""" | ||
return ( | ||
_only_directed_edge(graph, middle, left) | ||
and _has_either_edge(graph, right, middle) | ||
and (middle not in conditions or middle in conditions.intersection(sigma[left])) | ||
) | ||
|
||
|
||
def is_non_collider_right_chain( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
sigma: dict[Variable, set[Variable]], | ||
) -> bool: | ||
r"""Check if three nodes form a non-collider (right chain) given the conditions. | ||
:param graph: A mixed graph | ||
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper | ||
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper | ||
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper | ||
:param conditions: The conditional variables, denoted as $Z$ in the paper | ||
:param sigma: The set of equivalence classes. Can be calculated with | ||
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. | ||
:return: If the three nodes form a non-collider (right chain) given the conditions. | ||
""" | ||
return ( | ||
_has_either_edge(graph, left, middle) | ||
and _only_directed_edge(graph, middle, right) | ||
and (middle not in conditions or middle in conditions.intersection(sigma[right])) | ||
) | ||
|
||
|
||
def is_non_collider_fork( | ||
graph: NxMixedGraph, | ||
left: Variable, | ||
middle: Variable, | ||
right: Variable, | ||
conditions: set[Variable], | ||
sigma: dict[Variable, set[Variable]], | ||
) -> bool: | ||
r"""Check if three nodes form a non-collider (fork) given the conditions. | ||
:param graph: A mixed graph | ||
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper | ||
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper | ||
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper | ||
:param conditions: The conditional variables, denoted as $Z$ in the paper | ||
:param sigma: The set of equivalence classes. Can be calculated with | ||
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. | ||
:return: If the three nodes form a non-collider (fork) given the conditions. | ||
""" | ||
a = _only_directed_edge(graph, middle, left) | ||
b = _only_directed_edge(graph, middle, right) | ||
c = middle not in conditions | ||
d = middle in conditions.intersection(sigma[left]).intersection(sigma[right]) | ||
return a and b and (c or d) | ||
|
||
|
||
def get_equivalence_classes(graph: NxMixedGraph) -> dict[Variable, set[Variable]]: | ||
"""Get equivalence classes. | ||
:param graph: A mixed graph | ||
:returns: A mapping from variables to their equivalence class, | ||
defined as the second option from the paper (see below) | ||
1. The finest/trivial σ-CG structure of | ||
a mixed graph G is given by σ(v) := {v} for all | ||
v ∈ V . In this way σ-separation in G coincides with | ||
the usual notion of d-separation in a d-connection | ||
graph (d-CG) G (see [19]). We will take this as the | ||
definition of d-separation and d-CG in the following. | ||
2. The coarsest σ-CG structure of a mixed graph G is | ||
given by σ(v) := ScG(v) := AncG(v) ∩ DescG(v) | ||
w.r.t. the underlying directed graph. Note that the | ||
definition of strongly connected component totally | ||
ignores the bi- and undirected edges of the σ-CG. | ||
""" | ||
return { | ||
node: graph.ancestors_inclusive(node).intersection(graph.descendants_inclusive(node)) | ||
for node in graph.nodes() | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,123 @@ | ||
"""Test sigma separation.""" | ||
|
||
import unittest | ||
|
||
from y0.algorithm.conditional_independencies import are_d_separated | ||
from y0.algorithm.sigma_separation import ( | ||
are_sigma_separated, | ||
get_equivalence_classes, | ||
is_collider, | ||
is_non_collider_fork, | ||
is_non_collider_left_chain, | ||
is_non_collider_right_chain, | ||
is_z_sigma_open, | ||
) | ||
from y0.dsl import V1, V2, V3, V4, V5, V6, Variable | ||
from y0.graph import NxMixedGraph | ||
|
||
V7, V8 = map(Variable, ["V7", "V8"]) | ||
|
||
#: Figure 3 from https://arxiv.org/abs/1807.03024 | ||
graph = NxMixedGraph.from_edges( | ||
directed=[ | ||
(V1, V2), | ||
(V2, V3), | ||
(V3, V4), | ||
(V4, V5), | ||
(V4, V8), | ||
(V5, V2), | ||
(V6, V7), | ||
(V7, V6), | ||
], | ||
undirected=[ | ||
(V1, V2), | ||
(V4, V6), | ||
(V4, V7), | ||
(V6, V7), | ||
], | ||
) | ||
|
||
|
||
class TestSigmaSeparation(unittest.TestCase): | ||
"""Test sigma separation. | ||
These tests come from Table 1 in https://arxiv.org/abs/1807.03024. | ||
The sigma equivalence classes in Figure 3 are {v1}, {v2, v3, v4, v5}, | ||
{v6, v7}, and {v8}. | ||
""" | ||
|
||
def setUp(self) -> None: | ||
"""Set up the test case.""" | ||
self.sigma = get_equivalence_classes(graph) | ||
|
||
def test_equivalence_classes(self): | ||
"""Test getting equivalence classes.""" | ||
equivalent_classes = { | ||
frozenset([V1]), | ||
frozenset([V2, V3, V4, V5]), | ||
frozenset([V6, V7]), | ||
frozenset([V8]), | ||
} | ||
expected_equivalent_classes = {n: c for c in equivalent_classes for n in c} | ||
self.assertEqual(expected_equivalent_classes, self.sigma) | ||
|
||
def test_collider(self): | ||
"""Test checking colliders.""" | ||
self.assertTrue(is_collider(graph, left=V4, middle=V5, right=V4, conditions={V3, V5})) | ||
|
||
def test_left_chain(self): | ||
"""Test checking non-colliders (left chain).""" | ||
self.assertTrue( | ||
is_non_collider_left_chain( | ||
graph, left=V5, middle=V4, right=V6, conditions={V3, V5}, sigma=self.sigma | ||
) | ||
) | ||
|
||
def test_right_chain(self): | ||
"""Test checking non-colliders (right chain).""" | ||
self.assertTrue( | ||
is_non_collider_right_chain( | ||
graph, left=V1, middle=V2, right=V3, conditions={V3, V5}, sigma=self.sigma | ||
) | ||
) | ||
self.assertTrue( | ||
is_non_collider_right_chain( | ||
graph, left=V2, middle=V3, right=V4, conditions={V3, V5}, sigma=self.sigma | ||
) | ||
) | ||
self.assertTrue( | ||
is_non_collider_right_chain( | ||
graph, left=V3, middle=V4, right=V5, conditions={V3, V5}, sigma=self.sigma | ||
) | ||
) | ||
|
||
def test_fork(self): | ||
"""Test checking non-colliders (fork).""" | ||
self.assertTrue( | ||
is_non_collider_fork( | ||
graph, left=V5, middle=V4, right=V8, conditions={V3, V5}, sigma=self.sigma | ||
) | ||
) | ||
|
||
def test_z_sigma_open(self): | ||
"""Tests for z-sigma-open paths.""" | ||
# this is a weird example since it backtracks | ||
path = [V1, V2, V3, V4, V5, V4, V6] | ||
self.assertFalse(is_z_sigma_open(graph, path, sigma=self.sigma)) | ||
self.assertTrue(is_z_sigma_open(graph, path, conditions={V3, V5}, sigma=self.sigma)) | ||
|
||
def test_separations_figure_3(self): | ||
"""Test comparisons of d-separation and sigma-separation.""" | ||
for left, right, conditions, d, s in [ | ||
(V2, V4, [V3, V5], True, False), | ||
(V1, V6, [], True, True), | ||
(V1, V6, [V3, V5], True, False), | ||
(V1, V8, [], False, False), | ||
(V1, V8, [V3, V5], True, False), | ||
(V1, V8, [V4], True, True), | ||
]: | ||
with self.subTest(left=left, right=right, conditions=conditions): | ||
self.assertEqual( | ||
d, are_d_separated(graph, left, right, conditions=conditions).separated | ||
) | ||
self.assertEqual(s, are_sigma_separated(graph, left, right, conditions=conditions)) |