Skip to content

Commit

Permalink
Implement sigma-separation (#150)
Browse files Browse the repository at this point in the history
Closes #120

This PR adds a high-level interface and implements tests for
sigma-separation, a generalization of d-separation that works not only
for directed acyclic graphs, but also for directed graphs containing
cycles. It was originally introduced in

> Constraint-based Causal Discovery for Non-Linear Structural Causal
Models with Cycles and Latent Confounders
> Forré and Mooij. 2019.
[arXiv:1807.03024](https://arxiv.org/abs/1807.03024)

and is an integral part of cyclic ID algorithm (see
#71) and the gID
algorithm (see #72)

## References/Notes

-
https://stats.stackexchange.com/questions/586810/sigma-separation-question-in-cyclic-causal-graph-understanding-sigma-separatio
- Author's implementation: https://github.com/caus-am/sigmasep
  • Loading branch information
cthoyt authored Aug 29, 2023
1 parent fa770da commit 92df010
Show file tree
Hide file tree
Showing 2 changed files with 394 additions and 0 deletions.
271 changes: 271 additions & 0 deletions src/y0/algorithm/sigma_separation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,271 @@
"""Implementation of sigma-separation."""

from typing import Iterable, Optional, Sequence

import networkx as nx
from more_itertools import triplewise

from y0.dsl import Variable
from y0.graph import NxMixedGraph

__all__ = [
"are_sigma_separated",
"is_z_sigma_open",
"get_equivalence_classes",
]


def are_sigma_separated(
graph: NxMixedGraph,
left: Variable,
right: Variable,
*,
conditions: Optional[Iterable[Variable]] = None,
cutoff: Optional[int] = None,
) -> bool:
"""Test if two variables are sigma-separated.
Sigma separation is a generalization of d-separation that
works not only for directed acyclic graphs, but also for
directed graphs containing cycles. It was originally introduced
in https://arxiv.org/abs/1807.03024.
We say that X and Y are σ-connected by Z or not
σ-separated by Z if there exists a path π (with some
n ≥ 1 nodes) in G with one endnode in X and
one endnode in Y that is Z-σ-open. σ-separated is the
opposite of σ-connected (logical not).
:param graph: Graph to test
:param left: A node in the graph
:param right: A node in the graph
:param conditions: A collection of graph nodes
:param cutoff: The maximum path length to check. By default, is unbounded.
:return: If a and b are sigma-separated.
"""
if conditions is None:
conditions = set()
else:
conditions = set(conditions)

sigma = get_equivalence_classes(graph)
return not any(
is_z_sigma_open(graph, path, conditions=conditions, sigma=sigma)
# Technically, this algorithm should generate all paths, which could include
# repeat visits to nodes and edges, but this is computationally intractable,
# so the is_z_sigma_open() subroutine contains a novel path augmentation
# algorithm. This might not be officially complete.
for path in nx.all_simple_paths(graph.disorient(), left, right, cutoff=cutoff)
)


def is_z_sigma_open(
graph: NxMixedGraph,
path: Sequence[Variable],
*,
sigma: dict[Variable, set[Variable]],
conditions: Optional[set[Variable]] = None,
) -> bool:
r"""Check if a path is Z-sigma-open.
:param graph: A mixed graph
:param path: A path in the graph. Denoted as $\pi$ in the paper. The
node in position $i$ in the path is denoted with $v_i$.
:param conditions : A set of nodes chosen as conditions, denoted by $Z$ in the paper
:param sigma: The set of equivalence classes. Can be calculated with
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper.
:returns: If the path is Z-sigma-open
A path is $Z-\sigma-\text{open}$ if:
1. The end nodes $v_1, v_n \notin Z$
2. Every triple of adjacent nodes in the path is of the form:
1. Collider (:func:`is_collider`)
2. (non-collider) left chain (:func:`is_non_collider_left_chain`)
3. (non-collider) right chain (:func:`is_non_collider_left_chain`)
4. (non-collider) fork (:func:`is_non_collider_fork`)
5. (non-collider) with undirected edge (:func:`is_non_collider_undirected`, not implemented)
"""
if conditions is None:
conditions = set()
if path[0] in conditions or path[-1] in conditions:
return False
return all(
_triple_has_correct_form(graph, left, middle, right, conditions, sigma)
for left, middle, right in triplewise(path)
)


def _triple_has_correct_form(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
sigma: dict[Variable, set[Variable]],
) -> bool:
if _triple_helper(graph, left, middle, right, conditions, sigma):
return True
# augment with backtracks, since you're allowed to go back (just like Season 5 of Lost).
# this is a better solution than generating infinite paths, but might still be mathematically
# incomplete. In this setup, 𝑣3→𝑣4↔𝑣6 becomes 𝑣3→𝑣4→𝑣5←𝑣4↔𝑣6 to get some sweet backtrack paths
# through the middle node to a neighbor and then back before going to the right node.
neighbors = {n for n in graph.disorient().neighbors(middle) if n != middle}
for neighbor in neighbors:
if (
_triple_helper(graph, left, middle, neighbor, conditions, sigma)
and _triple_helper(graph, middle, neighbor, middle, conditions, sigma)
and _triple_helper(graph, neighbor, middle, right, conditions, sigma)
):
return True
return False


def _triple_helper(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
sigma: dict[Variable, set[Variable]],
) -> bool:
return (
is_collider(graph, left, middle, right, conditions)
or is_non_collider_left_chain(graph, left, middle, right, conditions, sigma)
or is_non_collider_right_chain(graph, left, middle, right, conditions, sigma)
or is_non_collider_fork(graph, left, middle, right, conditions, sigma)
)


def _has_either_edge(graph: NxMixedGraph, u, v) -> bool:
return graph.directed.has_edge(u, v) or graph.undirected.has_edge(u, v)


def _only_directed_edge(graph, u, v) -> bool:
return graph.directed.has_edge(u, v) and not graph.undirected.has_edge(u, v)


def is_collider(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
) -> bool:
"""Check if three nodes form a collider under the given conditions.
:param graph: A mixed graph
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper
:param conditions: The conditional variables, denoted as $Z$ in the paper
:return: If the three nodes form a collider
"""
return (
_has_either_edge(graph, left, middle)
and _has_either_edge(graph, right, middle)
and middle in conditions
)


def is_non_collider_left_chain(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
sigma: dict[Variable, set[Variable]],
) -> bool:
r"""Check if three nodes form a non-collider (left chain) given the conditions.
:param graph: A mixed graph
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper
:param conditions: The conditional variables, denoted as $Z$ in the paper
:param sigma: The set of equivalence classes. Can be calculated with
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper.
:return: If the three nodes form a non-collider (left chain) given the conditions.
"""
return (
_only_directed_edge(graph, middle, left)
and _has_either_edge(graph, right, middle)
and (middle not in conditions or middle in conditions.intersection(sigma[left]))
)


def is_non_collider_right_chain(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
sigma: dict[Variable, set[Variable]],
) -> bool:
r"""Check if three nodes form a non-collider (right chain) given the conditions.
:param graph: A mixed graph
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper
:param conditions: The conditional variables, denoted as $Z$ in the paper
:param sigma: The set of equivalence classes. Can be calculated with
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper.
:return: If the three nodes form a non-collider (right chain) given the conditions.
"""
return (
_has_either_edge(graph, left, middle)
and _only_directed_edge(graph, middle, right)
and (middle not in conditions or middle in conditions.intersection(sigma[right]))
)


def is_non_collider_fork(
graph: NxMixedGraph,
left: Variable,
middle: Variable,
right: Variable,
conditions: set[Variable],
sigma: dict[Variable, set[Variable]],
) -> bool:
r"""Check if three nodes form a non-collider (fork) given the conditions.
:param graph: A mixed graph
:param left: The first node in the subsequence, denoted as $v_{i-1}$ in the paper
:param middle: The second node in the subsequence, denoted as $v_i$ in the paper
:param right: The third node in the subsequence, denoted as $v_{i+1}$ in the paper
:param conditions: The conditional variables, denoted as $Z$ in the paper
:param sigma: The set of equivalence classes. Can be calculated with
:func:`get_equivalence_classes`, denoted by $\sigma(v)$ in the paper.
:return: If the three nodes form a non-collider (fork) given the conditions.
"""
a = _only_directed_edge(graph, middle, left)
b = _only_directed_edge(graph, middle, right)
c = middle not in conditions
d = middle in conditions.intersection(sigma[left]).intersection(sigma[right])
return a and b and (c or d)


def get_equivalence_classes(graph: NxMixedGraph) -> dict[Variable, set[Variable]]:
"""Get equivalence classes.
:param graph: A mixed graph
:returns: A mapping from variables to their equivalence class,
defined as the second option from the paper (see below)
1. The finest/trivial σ-CG structure of
a mixed graph G is given by σ(v) := {v} for all
v ∈ V . In this way σ-separation in G coincides with
the usual notion of d-separation in a d-connection
graph (d-CG) G (see [19]). We will take this as the
definition of d-separation and d-CG in the following.
2. The coarsest σ-CG structure of a mixed graph G is
given by σ(v) := ScG(v) := AncG(v) ∩ DescG(v)
w.r.t. the underlying directed graph. Note that the
definition of strongly connected component totally
ignores the bi- and undirected edges of the σ-CG.
"""
return {
node: graph.ancestors_inclusive(node).intersection(graph.descendants_inclusive(node))
for node in graph.nodes()
}
123 changes: 123 additions & 0 deletions tests/test_algorithm/test_sigma_separation.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,123 @@
"""Test sigma separation."""

import unittest

from y0.algorithm.conditional_independencies import are_d_separated
from y0.algorithm.sigma_separation import (
are_sigma_separated,
get_equivalence_classes,
is_collider,
is_non_collider_fork,
is_non_collider_left_chain,
is_non_collider_right_chain,
is_z_sigma_open,
)
from y0.dsl import V1, V2, V3, V4, V5, V6, Variable
from y0.graph import NxMixedGraph

V7, V8 = map(Variable, ["V7", "V8"])

#: Figure 3 from https://arxiv.org/abs/1807.03024
graph = NxMixedGraph.from_edges(
directed=[
(V1, V2),
(V2, V3),
(V3, V4),
(V4, V5),
(V4, V8),
(V5, V2),
(V6, V7),
(V7, V6),
],
undirected=[
(V1, V2),
(V4, V6),
(V4, V7),
(V6, V7),
],
)


class TestSigmaSeparation(unittest.TestCase):
"""Test sigma separation.
These tests come from Table 1 in https://arxiv.org/abs/1807.03024.
The sigma equivalence classes in Figure 3 are {v1}, {v2, v3, v4, v5},
{v6, v7}, and {v8}.
"""

def setUp(self) -> None:
"""Set up the test case."""
self.sigma = get_equivalence_classes(graph)

def test_equivalence_classes(self):
"""Test getting equivalence classes."""
equivalent_classes = {
frozenset([V1]),
frozenset([V2, V3, V4, V5]),
frozenset([V6, V7]),
frozenset([V8]),
}
expected_equivalent_classes = {n: c for c in equivalent_classes for n in c}
self.assertEqual(expected_equivalent_classes, self.sigma)

def test_collider(self):
"""Test checking colliders."""
self.assertTrue(is_collider(graph, left=V4, middle=V5, right=V4, conditions={V3, V5}))

def test_left_chain(self):
"""Test checking non-colliders (left chain)."""
self.assertTrue(
is_non_collider_left_chain(
graph, left=V5, middle=V4, right=V6, conditions={V3, V5}, sigma=self.sigma
)
)

def test_right_chain(self):
"""Test checking non-colliders (right chain)."""
self.assertTrue(
is_non_collider_right_chain(
graph, left=V1, middle=V2, right=V3, conditions={V3, V5}, sigma=self.sigma
)
)
self.assertTrue(
is_non_collider_right_chain(
graph, left=V2, middle=V3, right=V4, conditions={V3, V5}, sigma=self.sigma
)
)
self.assertTrue(
is_non_collider_right_chain(
graph, left=V3, middle=V4, right=V5, conditions={V3, V5}, sigma=self.sigma
)
)

def test_fork(self):
"""Test checking non-colliders (fork)."""
self.assertTrue(
is_non_collider_fork(
graph, left=V5, middle=V4, right=V8, conditions={V3, V5}, sigma=self.sigma
)
)

def test_z_sigma_open(self):
"""Tests for z-sigma-open paths."""
# this is a weird example since it backtracks
path = [V1, V2, V3, V4, V5, V4, V6]
self.assertFalse(is_z_sigma_open(graph, path, sigma=self.sigma))
self.assertTrue(is_z_sigma_open(graph, path, conditions={V3, V5}, sigma=self.sigma))

def test_separations_figure_3(self):
"""Test comparisons of d-separation and sigma-separation."""
for left, right, conditions, d, s in [
(V2, V4, [V3, V5], True, False),
(V1, V6, [], True, True),
(V1, V6, [V3, V5], True, False),
(V1, V8, [], False, False),
(V1, V8, [V3, V5], True, False),
(V1, V8, [V4], True, True),
]:
with self.subTest(left=left, right=right, conditions=conditions):
self.assertEqual(
d, are_d_separated(graph, left, right, conditions=conditions).separated
)
self.assertEqual(s, are_sigma_separated(graph, left, right, conditions=conditions))

0 comments on commit 92df010

Please sign in to comment.