diff --git a/src/y0/algorithm/verma_constraints.py b/src/y0/algorithm/verma_constraints.py new file mode 100644 index 000000000..62f87fc0d --- /dev/null +++ b/src/y0/algorithm/verma_constraints.py @@ -0,0 +1,59 @@ +# -*- coding: utf-8 -*- + +"""Implementation to get Verma constraints on a graph.""" + +from __future__ import annotations + +import logging +from typing import Iterable, List, NamedTuple, Set, Tuple + +from ..dsl import Expression +from ..graph import NxMixedGraph +from ..r_utils import uses_r + +__all__ = [ + "VermaConstraint", + "get_verma_constraints", +] + + +logger = logging.getLogger(__name__) + + +class VermaConstraint(NamedTuple): + """A Verma constraint.""" + + expression: Expression + nodes: Tuple[str, ...] + + @property + def is_canonical(self) -> bool: + """Return if the nodes are in a canonical order.""" + return tuple(sorted(self.nodes)) == self.nodes + + @classmethod + def create(cls, expression, nodes: Iterable[str]) -> VermaConstraint: + """Create a canonical Verma constraint.""" + return VermaConstraint(expression, tuple(sorted(set(nodes)))) + + +def get_verma_constraints(graph: NxMixedGraph) -> Set[VermaConstraint]: + """Get the Verma constraints on the graph. + + :param graph: An acyclic directed mixed graph + :return: A set of verma constraints, which are pairs of probability expressions and set of nodes. + + .. seealso:: Original issue https://github.com/y0-causal-inference/y0/issues/25 + """ + raise NotImplementedError + + +@uses_r +def r_get_verma_constraints(graph: NxMixedGraph) -> List[VermaConstraint]: + """Calculate the verma constraints on the graph using ``causaleffect``.""" + graph = graph.to_causaleffect() + + from rpy2 import robjects + + verma_constraints = robjects.r["verma.constraints"] + return [VermaConstraint.from_element(row) for row in verma_constraints(graph)] diff --git a/src/y0/causaleffect.py b/src/y0/causaleffect.py deleted file mode 100644 index 381cd86a7..000000000 --- a/src/y0/causaleffect.py +++ /dev/null @@ -1,27 +0,0 @@ -# -*- coding: utf-8 -*- - -"""Interface to the R causaleffect package via :mod:`rpy2`.""" - -from __future__ import annotations - -import logging -from typing import Sequence, Union - -from rpy2 import robjects - -from .graph import CausalEffectGraph, NxMixedGraph -from .r_utils import uses_r -from .struct import VermaConstraint - -logger = logging.getLogger(__name__) - - -@uses_r -def r_get_verma_constraints( - graph: Union[NxMixedGraph, CausalEffectGraph] -) -> Sequence[VermaConstraint]: - """Calculate the verma constraints on the graph using ``causaleffect``.""" - if isinstance(graph, NxMixedGraph): - graph = graph.to_causaleffect() - verma_constraints = robjects.r["verma.constraints"] - return [VermaConstraint.from_element(row) for row in verma_constraints(graph)] diff --git a/src/y0/r_utils.py b/src/y0/r_utils.py index ef1500ba7..34220dcf7 100644 --- a/src/y0/r_utils.py +++ b/src/y0/r_utils.py @@ -6,9 +6,6 @@ from functools import lru_cache, wraps from typing import Iterable, Tuple -from rpy2.robjects.packages import importr, isinstalled -from rpy2.robjects.vectors import StrVector - from .dsl import Variable __all__ = [ @@ -19,6 +16,7 @@ CAUSALEFFECT = "causaleffect" IGRAPH = "igraph" +#: A list of R packages that are required R_REQUIREMENTS = [ CAUSALEFFECT, IGRAPH, @@ -32,6 +30,8 @@ def prepare_renv(requirements: Iterable[str]) -> None: .. seealso:: https://rpy2.github.io/doc/v3.4.x/html/introduction.html#installing-packages """ + from rpy2.robjects.packages import importr, isinstalled + # import R's utility package utils = importr("utils") @@ -43,6 +43,8 @@ def prepare_renv(requirements: Iterable[str]) -> None: ] if uninstalled_requirements: logger.warning("installing R packages: %s", uninstalled_requirements) + from rpy2.robjects.vectors import StrVector + utils.install_packages(StrVector(uninstalled_requirements)) for requirement in requirements: diff --git a/tests/test_algorithm/test_verma_constraint.py b/tests/test_algorithm/test_verma_constraint.py new file mode 100644 index 000000000..0ffd26df7 --- /dev/null +++ b/tests/test_algorithm/test_verma_constraint.py @@ -0,0 +1,37 @@ +# -*- coding: utf-8 -*- + +"""Test getting Verma constraints.""" + +import unittest +from typing import Set + +from y0.algorithm.verma_constraints import VermaConstraint, get_verma_constraints +from y0.dsl import P, Sum, Variable, X, Y +from y0.examples import napkin +from y0.graph import NxMixedGraph + +Z1, Z2 = map(Variable, ("Z1", "Z2")) + + +class TestVermaConstraints(unittest.TestCase): + """Test getting Verma constraints.""" + + def assert_verma_constraints(self, graph: NxMixedGraph, expected: Set[VermaConstraint]): + """Assert that the graph has the correct conditional independencies.""" + verma_constraints = get_verma_constraints(graph) + self.assertTrue( + all(verma_constraint.is_canonical for verma_constraint in verma_constraints), + msg="one or more of the returned VermaConstraint instances are not canonical", + ) + self.assertEqual(expected, verma_constraints) + + def test_napkin(self): + """Test getting Verma constraints on the napkin graph.""" + # TODO how is Q[Y](Y, X) <-> V1 represented as an expression? + # TODO it also can be represented as r1: + r1 = Sum[Z2](P(Y | (Z2 | Z1, X)) * P(X | (Z1, Z2)) * P(Z2)) / Sum[Z2]( + P(X | (Z2, Z1) * P(Z2)) + ) + + c1 = VermaConstraint(..., ("R",)) + self.assert_verma_constraints(napkin, {c1}) diff --git a/tests/test_causaleffect.py b/tests/test_causaleffect.py index 9ffa33860..6fcf95d2b 100644 --- a/tests/test_causaleffect.py +++ b/tests/test_causaleffect.py @@ -10,7 +10,7 @@ from y0.examples import examples, verma_1 try: - from y0.causaleffect import r_get_verma_constraints + from y0.algorithm.verma_constraints import r_get_verma_constraints from y0.r_utils import CAUSALEFFECT, IGRAPH from y0.struct import VermaConstraint except ImportError: # rpy2 is not installed