From 368d63d99938e6fa1120e5f43fa04d3ff6192e41 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Fri, 22 Jan 2021 03:32:58 +0100 Subject: [PATCH 1/3] Add outline for verma constraints References #25 --- src/y0/algorithm/verma_constraints.py | 44 +++++++++++++++++++ tests/test_algorithm/test_verma_constraint.py | 31 +++++++++++++ 2 files changed, 75 insertions(+) create mode 100644 src/y0/algorithm/verma_constraints.py create mode 100644 tests/test_algorithm/test_verma_constraint.py diff --git a/src/y0/algorithm/verma_constraints.py b/src/y0/algorithm/verma_constraints.py new file mode 100644 index 000000000..744b7fa52 --- /dev/null +++ b/src/y0/algorithm/verma_constraints.py @@ -0,0 +1,44 @@ +# -*- coding: utf-8 -*- + +"""Implementation to get Verma constraints on a graph.""" + +from __future__ import annotations + +from typing import Iterable, NamedTuple, Set, Tuple + +from ananke.graphs import ADMG + +from ..dsl import Expression + +__all__ = [ + 'VermaConstraint', + 'get_verma_constraints', +] + + +class VermaConstraint(NamedTuple): + """A Verma constraint.""" + + expression: Expression + nodes: Tuple[str, ...] + + @property + def is_canonical(self) -> bool: + """Return if the nodes are in a canonical order.""" + return tuple(sorted(self.nodes)) == self.nodes + + @classmethod + def create(cls, expression, nodes: Iterable[str]) -> VermaConstraint: + """Create a canonical Verma constraint.""" + return VermaConstraint(expression, tuple(sorted(set(nodes)))) + + +def get_verma_constraints(graph: ADMG) -> Set[VermaConstraint]: + """Get the Verma constraints on the graph. + + :param graph: An acyclic directed mixed graph + :return: A set of verma constraings, which are pairs of probability expressions and set of nodes. + + .. seealso:: Original issue https://github.com/y0-causal-inference/y0/issues/25 + """ + raise NotImplementedError diff --git a/tests/test_algorithm/test_verma_constraint.py b/tests/test_algorithm/test_verma_constraint.py new file mode 100644 index 000000000..8027e236b --- /dev/null +++ b/tests/test_algorithm/test_verma_constraint.py @@ -0,0 +1,31 @@ +# -*- coding: utf-8 -*- + +"""Test getting Verma constraints.""" + +import unittest +from typing import Set + +from y0.algorithm.verma_constraints import VermaConstraint, get_verma_constraints +from y0.graph import NxMixedGraph, napkin_graph + + +class TestVermaConstraints(unittest.TestCase): + """Test getting Verma constraints.""" + + def assert_verma_constraints(self, graph: NxMixedGraph, expected: Set[VermaConstraint]): + """Assert that the graph has the correct conditional independencies.""" + verma_constraints = get_verma_constraints(graph.to_admg()) + self.assertTrue( + all( + verma_constraint.is_canonical + for verma_constraint in verma_constraints + ), + msg='one or more of the returned VermaConstraint instances are not canonical', + ) + self.assertEqual(expected, verma_constraints) + + def test_napkin(self): + """Test getting Verma constraints on the napkin graph.""" + # TODO how is Q[Y](Y, X, V1) represented as an expression? + c1 = VermaConstraint(..., ('R',)) + self.assert_verma_constraints(napkin_graph, {c1}) From b42a0ff55a7dd52643255a0197a2ff7196fec76d Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Tue, 26 Jan 2021 21:01:51 +0100 Subject: [PATCH 2/3] Add extra note from causal fusion --- tests/test_algorithm/test_verma_constraint.py | 16 +++++++++++++--- 1 file changed, 13 insertions(+), 3 deletions(-) diff --git a/tests/test_algorithm/test_verma_constraint.py b/tests/test_algorithm/test_verma_constraint.py index 8027e236b..02ca82022 100644 --- a/tests/test_algorithm/test_verma_constraint.py +++ b/tests/test_algorithm/test_verma_constraint.py @@ -6,7 +6,11 @@ from typing import Set from y0.algorithm.verma_constraints import VermaConstraint, get_verma_constraints -from y0.graph import NxMixedGraph, napkin_graph +from y0.dsl import P, Sum, Variable, X, Y +from y0.examples import napkin +from y0.graph import NxMixedGraph + +Z1, Z2 = map(Variable, ('Z1', 'Z2')) class TestVermaConstraints(unittest.TestCase): @@ -26,6 +30,12 @@ def assert_verma_constraints(self, graph: NxMixedGraph, expected: Set[VermaConst def test_napkin(self): """Test getting Verma constraints on the napkin graph.""" - # TODO how is Q[Y](Y, X, V1) represented as an expression? + # TODO how is Q[Y](Y, X) <-> V1 represented as an expression? + # TODO it also can be represented as r1: + r1 = ( + Sum[Z2](P(Y | (Z2 | Z1, X)) * P(X | (Z1, Z2)) * P(Z2)) + / Sum[Z2](P(X | (Z2, Z1) * P(Z2))) + ) + c1 = VermaConstraint(..., ('R',)) - self.assert_verma_constraints(napkin_graph, {c1}) + self.assert_verma_constraints(napkin, {c1}) From c644a853a46ac66af456e49240ea853f8e601ff6 Mon Sep 17 00:00:00 2001 From: Charles Tapley Hoyt Date: Wed, 16 Aug 2023 18:20:58 +0200 Subject: [PATCH 3/3] Reorganize --- src/y0/algorithm/verma_constraints.py | 29 ++++++++++++++----- src/y0/causaleffect.py | 27 ----------------- src/y0/r_utils.py | 8 +++-- tests/test_algorithm/test_verma_constraint.py | 18 +++++------- tests/test_causaleffect.py | 2 +- 5 files changed, 35 insertions(+), 49 deletions(-) delete mode 100644 src/y0/causaleffect.py diff --git a/src/y0/algorithm/verma_constraints.py b/src/y0/algorithm/verma_constraints.py index 744b7fa52..62f87fc0d 100644 --- a/src/y0/algorithm/verma_constraints.py +++ b/src/y0/algorithm/verma_constraints.py @@ -4,18 +4,22 @@ from __future__ import annotations -from typing import Iterable, NamedTuple, Set, Tuple - -from ananke.graphs import ADMG +import logging +from typing import Iterable, List, NamedTuple, Set, Tuple from ..dsl import Expression +from ..graph import NxMixedGraph +from ..r_utils import uses_r __all__ = [ - 'VermaConstraint', - 'get_verma_constraints', + "VermaConstraint", + "get_verma_constraints", ] +logger = logging.getLogger(__name__) + + class VermaConstraint(NamedTuple): """A Verma constraint.""" @@ -33,12 +37,23 @@ def create(cls, expression, nodes: Iterable[str]) -> VermaConstraint: return VermaConstraint(expression, tuple(sorted(set(nodes)))) -def get_verma_constraints(graph: ADMG) -> Set[VermaConstraint]: +def get_verma_constraints(graph: NxMixedGraph) -> Set[VermaConstraint]: """Get the Verma constraints on the graph. :param graph: An acyclic directed mixed graph - :return: A set of verma constraings, which are pairs of probability expressions and set of nodes. + :return: A set of verma constraints, which are pairs of probability expressions and set of nodes. .. seealso:: Original issue https://github.com/y0-causal-inference/y0/issues/25 """ raise NotImplementedError + + +@uses_r +def r_get_verma_constraints(graph: NxMixedGraph) -> List[VermaConstraint]: + """Calculate the verma constraints on the graph using ``causaleffect``.""" + graph = graph.to_causaleffect() + + from rpy2 import robjects + + verma_constraints = robjects.r["verma.constraints"] + return [VermaConstraint.from_element(row) for row in verma_constraints(graph)] diff --git a/src/y0/causaleffect.py b/src/y0/causaleffect.py deleted file mode 100644 index 381cd86a7..000000000 --- a/src/y0/causaleffect.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Interface to the R causaleffect package via :mod:`rpy2`.""" - -from __future__ import annotations - -import logging -from typing import Sequence, Union - -from rpy2 import robjects - -from .graph import CausalEffectGraph, NxMixedGraph -from .r_utils import uses_r -from .struct import VermaConstraint - -logger = logging.getLogger(__name__) - - -@uses_r -def r_get_verma_constraints( - graph: Union[NxMixedGraph, CausalEffectGraph] -) -> Sequence[VermaConstraint]: - """Calculate the verma constraints on the graph using ``causaleffect``.""" - if isinstance(graph, NxMixedGraph): - graph = graph.to_causaleffect() - verma_constraints = robjects.r["verma.constraints"] - return [VermaConstraint.from_element(row) for row in verma_constraints(graph)] diff --git a/src/y0/r_utils.py b/src/y0/r_utils.py index ef1500ba7..34220dcf7 100644 --- a/src/y0/r_utils.py +++ b/src/y0/r_utils.py @@ -6,9 +6,6 @@ from functools import lru_cache, wraps from typing import Iterable, Tuple -from rpy2.robjects.packages import importr, isinstalled -from rpy2.robjects.vectors import StrVector - from .dsl import Variable __all__ = [ @@ -19,6 +16,7 @@ CAUSALEFFECT = "causaleffect" IGRAPH = "igraph" +#: A list of R packages that are required R_REQUIREMENTS = [ CAUSALEFFECT, IGRAPH, @@ -32,6 +30,8 @@ def prepare_renv(requirements: Iterable[str]) -> None: .. seealso:: https://rpy2.github.io/doc/v3.4.x/html/introduction.html#installing-packages """ + from rpy2.robjects.packages import importr, isinstalled + # import R's utility package utils = importr("utils") @@ -43,6 +43,8 @@ def prepare_renv(requirements: Iterable[str]) -> None: ] if uninstalled_requirements: logger.warning("installing R packages: %s", uninstalled_requirements) + from rpy2.robjects.vectors import StrVector + utils.install_packages(StrVector(uninstalled_requirements)) for requirement in requirements: diff --git a/tests/test_algorithm/test_verma_constraint.py b/tests/test_algorithm/test_verma_constraint.py index 02ca82022..0ffd26df7 100644 --- a/tests/test_algorithm/test_verma_constraint.py +++ b/tests/test_algorithm/test_verma_constraint.py @@ -10,7 +10,7 @@ from y0.examples import napkin from y0.graph import NxMixedGraph -Z1, Z2 = map(Variable, ('Z1', 'Z2')) +Z1, Z2 = map(Variable, ("Z1", "Z2")) class TestVermaConstraints(unittest.TestCase): @@ -18,13 +18,10 @@ class TestVermaConstraints(unittest.TestCase): def assert_verma_constraints(self, graph: NxMixedGraph, expected: Set[VermaConstraint]): """Assert that the graph has the correct conditional independencies.""" - verma_constraints = get_verma_constraints(graph.to_admg()) + verma_constraints = get_verma_constraints(graph) self.assertTrue( - all( - verma_constraint.is_canonical - for verma_constraint in verma_constraints - ), - msg='one or more of the returned VermaConstraint instances are not canonical', + all(verma_constraint.is_canonical for verma_constraint in verma_constraints), + msg="one or more of the returned VermaConstraint instances are not canonical", ) self.assertEqual(expected, verma_constraints) @@ -32,10 +29,9 @@ def test_napkin(self): """Test getting Verma constraints on the napkin graph.""" # TODO how is Q[Y](Y, X) <-> V1 represented as an expression? # TODO it also can be represented as r1: - r1 = ( - Sum[Z2](P(Y | (Z2 | Z1, X)) * P(X | (Z1, Z2)) * P(Z2)) - / Sum[Z2](P(X | (Z2, Z1) * P(Z2))) + r1 = Sum[Z2](P(Y | (Z2 | Z1, X)) * P(X | (Z1, Z2)) * P(Z2)) / Sum[Z2]( + P(X | (Z2, Z1) * P(Z2)) ) - c1 = VermaConstraint(..., ('R',)) + c1 = VermaConstraint(..., ("R",)) self.assert_verma_constraints(napkin, {c1}) diff --git a/tests/test_causaleffect.py b/tests/test_causaleffect.py index 9ffa33860..6fcf95d2b 100644 --- a/tests/test_causaleffect.py +++ b/tests/test_causaleffect.py @@ -10,7 +10,7 @@ from y0.examples import examples, verma_1 try: - from y0.causaleffect import r_get_verma_constraints + from y0.algorithm.verma_constraints import r_get_verma_constraints from y0.r_utils import CAUSALEFFECT, IGRAPH from y0.struct import VermaConstraint except ImportError: # rpy2 is not installed