From 92df01086da5f7c053704e5b1534d242f4653c95 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 29 Aug 2023 18:13:34 +0200 Subject: [PATCH] Implement sigma-separation (#150) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Closes #120 This PR adds a high-level interface and implements tests for sigma-separation, a generalization of d-separation that works not only for directed acyclic graphs, but also for directed graphs containing cycles. It was originally introduced in > Constraint-based Causal Discovery for Non-Linear Structural Causal Models with Cycles and Latent Confounders > Forré and Mooij. 2019. [arXiv:1807.03024](https://arxiv.org/abs/1807.03024) and is an integral part of cyclic ID algorithm (see https://github.com/y0-causal-inference/y0/issues/71) and the gID algorithm (see #72) ## References/Notes - https://stats.stackexchange.com/questions/586810/sigma-separation-question-in-cyclic-causal-graph-understanding-sigma-separatio - Author's implementation: https://github.com/caus-am/sigmasep --- src/y0/algorithm/sigma_separation.py | 271 ++++++++++++++++++ tests/test_algorithm/test_sigma_separation.py | 123 ++++++++ 2 files changed, 394 insertions(+) create mode 100644 src/y0/algorithm/sigma_separation.py create mode 100644 tests/test_algorithm/test_sigma_separation.py diff --git a/src/y0/algorithm/sigma_separation.py b/src/y0/algorithm/sigma_separation.py new file mode 100644 index 000000000..c673ec06f --- /dev/null +++ b/src/y0/algorithm/sigma_separation.py @@ -0,0 +1,271 @@ +"""Implementation of sigma-separation.""" + +from typing import Iterable, Optional, Sequence + +import networkx as nx +from more_itertools import triplewise + +from y0.dsl import Variable +from y0.graph import NxMixedGraph + +__all__ = [ + "are_sigma_separated", + "is_z_sigma_open", + "get_equivalence_classes", +] + + +def are_sigma_separated( + graph: NxMixedGraph, + left: Variable, + right: Variable, + *, + conditions: 
Optional[Iterable[Variable]] = None, + cutoff: Optional[int] = None, +) -> bool: + """Test if two variables are sigma-separated. + + Sigma separation is a generalization of d-separation that + works not only for directed acyclic graphs, but also for + directed graphs containing cycles. It was originally introduced + in https://arxiv.org/abs/1807.03024. + + We say that X and Y are σ-connected by Z or not + σ-separated by Z if there exists a path π (with some + n ≥ 1 nodes) in G with one endnode in X and + one endnode in Y that is Z-σ-open. σ-separated is the + opposite of σ-connected (logical not). + + :param graph: Graph to test + :param left: A node in the graph + :param right: A node in the graph + :param conditions: A collection of graph nodes + :param cutoff: The maximum path length to check. By default, is unbounded. + :return: If left and right are sigma-separated. + """ + if conditions is None: + conditions = set() + else: + conditions = set(conditions) + + sigma = get_equivalence_classes(graph) + return not any( + is_z_sigma_open(graph, path, conditions=conditions, sigma=sigma) + # Technically, this algorithm should generate all paths, which could include + # repeat visits to nodes and edges, but this is computationally intractable, + # so the is_z_sigma_open() subroutine contains a novel path augmentation + # algorithm. This might not be officially complete. + for path in nx.all_simple_paths(graph.disorient(), left, right, cutoff=cutoff) + ) + + +def is_z_sigma_open( + graph: NxMixedGraph, + path: Sequence[Variable], + *, + sigma: dict[Variable, set[Variable]], + conditions: Optional[set[Variable]] = None, +) -> bool: + r"""Check if a path is Z-sigma-open. + + :param graph: A mixed graph + :param path: A path in the graph. Denoted as $\pi$ in the paper. The + node in position $i$ in the path is denoted with $v_i$. + :param conditions: A set of nodes chosen as conditions, denoted by $Z$ in the paper + :param sigma: The set of equivalence classes. 
Can be calculated with + :func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. + :returns: If the path is Z-sigma-open + + A path is $Z-\sigma-\text{open}$ if: + + 1. The end nodes $v_1, v_n \notin Z$ + 2. Every triple of adjacent nodes in the path is of the form: + 1. Collider (:func:`is_collider`) + 2. (non-collider) left chain (:func:`is_non_collider_left_chain`) + 3. (non-collider) right chain (:func:`is_non_collider_right_chain`) + 4. (non-collider) fork (:func:`is_non_collider_fork`) + 5. (non-collider) with undirected edge (:func:`is_non_collider_undirected`, not implemented) + """ + if conditions is None: + conditions = set() + if path[0] in conditions or path[-1] in conditions: + return False + return all( + _triple_has_correct_form(graph, left, middle, right, conditions, sigma) + for left, middle, right in triplewise(path) + ) + + +def _triple_has_correct_form( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], + sigma: dict[Variable, set[Variable]], +) -> bool: + if _triple_helper(graph, left, middle, right, conditions, sigma): + return True + # augment with backtracks, since you're allowed to go back (just like Season 5 of Lost). + # this is a better solution than generating infinite paths, but might still be mathematically + # incomplete. In this setup, 𝑣3→𝑣4↔𝑣6 becomes 𝑣3→𝑣4→𝑣5←𝑣4↔𝑣6 to get some sweet backtrack paths + # through the middle node to a neighbor and then back before going to the right node. 
+ neighbors = {n for n in graph.disorient().neighbors(middle) if n != middle} + for neighbor in neighbors: + if ( + _triple_helper(graph, left, middle, neighbor, conditions, sigma) + and _triple_helper(graph, middle, neighbor, middle, conditions, sigma) + and _triple_helper(graph, neighbor, middle, right, conditions, sigma) + ): + return True + return False + + +def _triple_helper( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], + sigma: dict[Variable, set[Variable]], +) -> bool: + return ( + is_collider(graph, left, middle, right, conditions) + or is_non_collider_left_chain(graph, left, middle, right, conditions, sigma) + or is_non_collider_right_chain(graph, left, middle, right, conditions, sigma) + or is_non_collider_fork(graph, left, middle, right, conditions, sigma) + ) + + +def _has_either_edge(graph: NxMixedGraph, u, v) -> bool: + return graph.directed.has_edge(u, v) or graph.undirected.has_edge(u, v) + + +def _only_directed_edge(graph, u, v) -> bool: + return graph.directed.has_edge(u, v) and not graph.undirected.has_edge(u, v) + + +def is_collider( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], +) -> bool: + """Check if three nodes form a collider under the given conditions. 
+ + :param graph: A mixed graph + :param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper + :param middle: The second node in the subsequence, denoted as $v_i$ in the paper + :param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper + :param conditions: The conditional variables, denoted as $Z$ in the paper + :return: If the three nodes form a collider + """ + return ( + _has_either_edge(graph, left, middle) + and _has_either_edge(graph, right, middle) + and middle in conditions + ) + + +def is_non_collider_left_chain( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], + sigma: dict[Variable, set[Variable]], +) -> bool: + r"""Check if three nodes form a non-collider (left chain) given the conditions. + + :param graph: A mixed graph + :param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper + :param middle: The second node in the subsequence, denoted as $v_i$ in the paper + :param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper + :param conditions: The conditional variables, denoted as $Z$ in the paper + :param sigma: The set of equivalence classes. Can be calculated with + :func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. + :return: If the three nodes form a non-collider (left chain) given the conditions. + """ + return ( + _only_directed_edge(graph, middle, left) + and _has_either_edge(graph, right, middle) + and (middle not in conditions or middle in conditions.intersection(sigma[left])) + ) + + +def is_non_collider_right_chain( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], + sigma: dict[Variable, set[Variable]], +) -> bool: + r"""Check if three nodes form a non-collider (right chain) given the conditions. 
+ + :param graph: A mixed graph + :param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper + :param middle: The second node in the subsequence, denoted as $v_i$ in the paper + :param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper + :param conditions: The conditional variables, denoted as $Z$ in the paper + :param sigma: The set of equivalence classes. Can be calculated with + :func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. + :return: If the three nodes form a non-collider (right chain) given the conditions. + """ + return ( + _has_either_edge(graph, left, middle) + and _only_directed_edge(graph, middle, right) + and (middle not in conditions or middle in conditions.intersection(sigma[right])) + ) + + +def is_non_collider_fork( + graph: NxMixedGraph, + left: Variable, + middle: Variable, + right: Variable, + conditions: set[Variable], + sigma: dict[Variable, set[Variable]], +) -> bool: + r"""Check if three nodes form a non-collider (fork) given the conditions. + + :param graph: A mixed graph + :param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper + :param middle: The second node in the subsequence, denoted as $v_i$ in the paper + :param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper + :param conditions: The conditional variables, denoted as $Z$ in the paper + :param sigma: The set of equivalence classes. Can be calculated with + :func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper. + :return: If the three nodes form a non-collider (fork) given the conditions. + """ + a = _only_directed_edge(graph, middle, left) + b = _only_directed_edge(graph, middle, right) + c = middle not in conditions + d = middle in conditions.intersection(sigma[left]).intersection(sigma[right]) + return a and b and (c or d) + + +def get_equivalence_classes(graph: NxMixedGraph) -> dict[Variable, set[Variable]]: + """Get equivalence classes. 
+ + :param graph: A mixed graph + :returns: A mapping from variables to their equivalence class, + defined as the second option from the paper (see below) + + 1. The finest/trivial σ-CG structure of + a mixed graph G is given by σ(v) := {v} for all + v ∈ V . In this way σ-separation in G coincides with + the usual notion of d-separation in a d-connection + graph (d-CG) G (see [19]). We will take this as the + definition of d-separation and d-CG in the following. + 2. The coarsest σ-CG structure of a mixed graph G is + given by σ(v) := ScG(v) := AncG(v) ∩ DescG(v) + w.r.t. the underlying directed graph. Note that the + definition of strongly connected component totally + ignores the bi- and undirected edges of the σ-CG. + """ + return { + node: graph.ancestors_inclusive(node).intersection(graph.descendants_inclusive(node)) + for node in graph.nodes() + } diff --git a/tests/test_algorithm/test_sigma_separation.py b/tests/test_algorithm/test_sigma_separation.py new file mode 100644 index 000000000..6dba6c401 --- /dev/null +++ b/tests/test_algorithm/test_sigma_separation.py @@ -0,0 +1,123 @@ +"""Test sigma separation.""" + +import unittest + +from y0.algorithm.conditional_independencies import are_d_separated +from y0.algorithm.sigma_separation import ( + are_sigma_separated, + get_equivalence_classes, + is_collider, + is_non_collider_fork, + is_non_collider_left_chain, + is_non_collider_right_chain, + is_z_sigma_open, +) +from y0.dsl import V1, V2, V3, V4, V5, V6, Variable +from y0.graph import NxMixedGraph + +V7, V8 = map(Variable, ["V7", "V8"]) + +#: Figure 3 from https://arxiv.org/abs/1807.03024 +graph = NxMixedGraph.from_edges( + directed=[ + (V1, V2), + (V2, V3), + (V3, V4), + (V4, V5), + (V4, V8), + (V5, V2), + (V6, V7), + (V7, V6), + ], + undirected=[ + (V1, V2), + (V4, V6), + (V4, V7), + (V6, V7), + ], +) + + +class TestSigmaSeparation(unittest.TestCase): + """Test sigma separation. + + These tests come from Table 1 in https://arxiv.org/abs/1807.03024. 
+ The sigma equivalence classes in Figure 3 are {v1}, {v2, v3, v4, v5}, + {v6, v7}, and {v8}. + """ + + def setUp(self) -> None: + """Set up the test case.""" + self.sigma = get_equivalence_classes(graph) + + def test_equivalence_classes(self): + """Test getting equivalence classes.""" + equivalent_classes = { + frozenset([V1]), + frozenset([V2, V3, V4, V5]), + frozenset([V6, V7]), + frozenset([V8]), + } + expected_equivalent_classes = {n: c for c in equivalent_classes for n in c} + self.assertEqual(expected_equivalent_classes, self.sigma) + + def test_collider(self): + """Test checking colliders.""" + self.assertTrue(is_collider(graph, left=V4, middle=V5, right=V4, conditions={V3, V5})) + + def test_left_chain(self): + """Test checking non-colliders (left chain).""" + self.assertTrue( + is_non_collider_left_chain( + graph, left=V5, middle=V4, right=V6, conditions={V3, V5}, sigma=self.sigma + ) + ) + + def test_right_chain(self): + """Test checking non-colliders (right chain).""" + self.assertTrue( + is_non_collider_right_chain( + graph, left=V1, middle=V2, right=V3, conditions={V3, V5}, sigma=self.sigma + ) + ) + self.assertTrue( + is_non_collider_right_chain( + graph, left=V2, middle=V3, right=V4, conditions={V3, V5}, sigma=self.sigma + ) + ) + self.assertTrue( + is_non_collider_right_chain( + graph, left=V3, middle=V4, right=V5, conditions={V3, V5}, sigma=self.sigma + ) + ) + + def test_fork(self): + """Test checking non-colliders (fork).""" + self.assertTrue( + is_non_collider_fork( + graph, left=V5, middle=V4, right=V8, conditions={V3, V5}, sigma=self.sigma + ) + ) + + def test_z_sigma_open(self): + """Tests for z-sigma-open paths.""" + # this is a weird example since it backtracks + path = [V1, V2, V3, V4, V5, V4, V6] + self.assertFalse(is_z_sigma_open(graph, path, sigma=self.sigma)) + self.assertTrue(is_z_sigma_open(graph, path, conditions={V3, V5}, sigma=self.sigma)) + + def test_separations_figure_3(self): + """Test comparisons of d-separation and 
sigma-separation.""" + for left, right, conditions, d, s in [ + (V2, V4, [V3, V5], True, False), + (V1, V6, [], True, True), + (V1, V6, [V3, V5], True, False), + (V1, V8, [], False, False), + (V1, V8, [V3, V5], True, False), + (V1, V8, [V4], True, True), + ]: + with self.subTest(left=left, right=right, conditions=conditions): + self.assertEqual( + d, are_d_separated(graph, left, right, conditions=conditions).separated + ) + self.assertEqual(s, are_sigma_separated(graph, left, right, conditions=conditions))