Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implementation to calculate Verma constraints #30

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
59 changes: 59 additions & 0 deletions src/y0/algorithm/verma_constraints.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-

"""Implementation to get Verma constraints on a graph."""

from __future__ import annotations

import logging
from typing import Iterable, List, NamedTuple, Set, Tuple

from ..dsl import Expression
from ..graph import NxMixedGraph
from ..r_utils import uses_r

__all__ = [
"VermaConstraint",
"get_verma_constraints",
]


logger = logging.getLogger(__name__)


class VermaConstraint(NamedTuple):
"""A Verma constraint."""

expression: Expression
nodes: Tuple[str, ...]

@property
def is_canonical(self) -> bool:
"""Return if the nodes are in a canonical order."""
return tuple(sorted(self.nodes)) == self.nodes

@classmethod
def create(cls, expression, nodes: Iterable[str]) -> VermaConstraint:
"""Create a canonical Verma constraint."""
return VermaConstraint(expression, tuple(sorted(set(nodes))))


def get_verma_constraints(graph: NxMixedGraph) -> Set[VermaConstraint]:
"""Get the Verma constraints on the graph.

:param graph: An acyclic directed mixed graph
:return: A set of verma constraints, which are pairs of probability expressions and set of nodes.

.. seealso:: Original issue https://github.com/y0-causal-inference/y0/issues/25
"""
raise NotImplementedError


@uses_r
def r_get_verma_constraints(graph: NxMixedGraph) -> List[VermaConstraint]:
"""Calculate the verma constraints on the graph using ``causaleffect``."""
graph = graph.to_causaleffect()

from rpy2 import robjects

verma_constraints = robjects.r["verma.constraints"]
return [VermaConstraint.from_element(row) for row in verma_constraints(graph)]
27 changes: 0 additions & 27 deletions src/y0/causaleffect.py

This file was deleted.

8 changes: 5 additions & 3 deletions src/y0/r_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,9 +6,6 @@
from functools import lru_cache, wraps
from typing import Iterable, Tuple

from rpy2.robjects.packages import importr, isinstalled
from rpy2.robjects.vectors import StrVector

from .dsl import Variable

__all__ = [
Expand All @@ -19,6 +16,7 @@

CAUSALEFFECT = "causaleffect"
IGRAPH = "igraph"
#: A list of R packages that are required
R_REQUIREMENTS = [
CAUSALEFFECT,
IGRAPH,
Expand All @@ -32,6 +30,8 @@ def prepare_renv(requirements: Iterable[str]) -> None:

.. seealso:: https://rpy2.github.io/doc/v3.4.x/html/introduction.html#installing-packages
"""
from rpy2.robjects.packages import importr, isinstalled

# import R's utility package
utils = importr("utils")

Expand All @@ -43,6 +43,8 @@ def prepare_renv(requirements: Iterable[str]) -> None:
]
if uninstalled_requirements:
logger.warning("installing R packages: %s", uninstalled_requirements)
from rpy2.robjects.vectors import StrVector

utils.install_packages(StrVector(uninstalled_requirements))

for requirement in requirements:
Expand Down
37 changes: 37 additions & 0 deletions tests/test_algorithm/test_verma_constraint.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# -*- coding: utf-8 -*-

"""Test getting Verma constraints."""

import unittest
from typing import Set

from y0.algorithm.verma_constraints import VermaConstraint, get_verma_constraints
from y0.dsl import P, Sum, Variable, X, Y
from y0.examples import napkin
from y0.graph import NxMixedGraph

Z1, Z2 = map(Variable, ("Z1", "Z2"))


class TestVermaConstraints(unittest.TestCase):
"""Test getting Verma constraints."""

def assert_verma_constraints(self, graph: NxMixedGraph, expected: Set[VermaConstraint]):
"""Assert that the graph has the correct conditional independencies."""
verma_constraints = get_verma_constraints(graph)
self.assertTrue(
all(verma_constraint.is_canonical for verma_constraint in verma_constraints),
msg="one or more of the returned VermaConstraint instances are not canonical",
)
self.assertEqual(expected, verma_constraints)

def test_napkin(self):
"""Test getting Verma constraints on the napkin graph."""
# TODO how is Q[Y](Y, X) <-> V1 represented as an expression?
# TODO it also can be represented as r1:
r1 = Sum[Z2](P(Y | (Z2 | Z1, X)) * P(X | (Z1, Z2)) * P(Z2)) / Sum[Z2](
P(X | (Z2, Z1) * P(Z2))
)

c1 = VermaConstraint(..., ("R",))
self.assert_verma_constraints(napkin, {c1})
2 changes: 1 addition & 1 deletion tests/test_causaleffect.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@
from y0.examples import examples, verma_1

try:
from y0.causaleffect import r_get_verma_constraints
from y0.algorithm.verma_constraints import r_get_verma_constraints
from y0.r_utils import CAUSALEFFECT, IGRAPH
from y0.struct import VermaConstraint
except ImportError: # rpy2 is not installed
Expand Down