Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make distributions and graph talk to each other #30

Merged
merged 34 commits into from
Mar 24, 2025
Merged
Changes from 1 commit
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
03a751f
start working on node that store distributions
mscroggs Mar 13, 2025
ddb7d52
refactor graph, add add_node and add_edge methods
mscroggs Mar 14, 2025
e8a4592
get multiple samples at onde
mscroggs Mar 14, 2025
d7c5a60
parametrise test
mscroggs Mar 14, 2025
bd50473
make two distribution example
mscroggs Mar 14, 2025
02695eb
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 14, 2025
e78a06c
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 19, 2025
bbe9284
| None
mscroggs Mar 19, 2025
0a9a7b2
remove irrelevant nodes from test
mscroggs Mar 19, 2025
711ca46
add stdev to test
mscroggs Mar 19, 2025
e460f4f
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
ef667db
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
87fa819
make roots_down_to_outcome a method of the graph
mscroggs Mar 19, 2025
df339e7
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
4b457fc
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
b222449
default false
mscroggs Mar 19, 2025
ac5dbdc
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
fa7bfe7
don't allow temporary None labels
mscroggs Mar 19, 2025
ed93108
simpler test
mscroggs Mar 19, 2025
ed02c31
improve tests, and simplify iteration code
mscroggs Mar 19, 2025
cc2f1ec
reduce number of tests
mscroggs Mar 19, 2025
8c01b7d
number of samples must be int, don't use 0 for mean so that relative …
mscroggs Mar 20, 2025
06a3946
ruff
mscroggs Mar 20, 2025
8c7a517
add moved file
mscroggs Mar 20, 2025
c519c6e
Merge branch 'main' into mscroggs/distributions
mscroggs Mar 20, 2025
5651633
fix merge
mscroggs Mar 20, 2025
6fb79be
Update src/causalprog/_abc/labelled.py
mscroggs Mar 21, 2025
969f0fe
Update src/causalprog/graph/graph.py
mscroggs Mar 21, 2025
66d9317
Update src/causalprog/graph/node.py
mscroggs Mar 21, 2025
ee6cfef
Update src/causalprog/graph/node.py
mscroggs Mar 21, 2025
caba090
Update tests/test_graph.py
mscroggs Mar 21, 2025
b45ce2c
Merge branch 'main' into mscroggs/distributions
mscroggs Mar 21, 2025
f6e69f9
noqa
mscroggs Mar 21, 2025
0707856
Update tests/test_distributions/test_family.py
mscroggs Mar 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
make two distribution example
mscroggs committed Mar 14, 2025
commit bd504737c2762b0554216c58222ce28296f908d5
2 changes: 1 addition & 1 deletion src/causalprog/algorithms/__init__.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,3 @@
"""Algorithms."""

from .expectation import expectation
from .expectation import expectation, standard_deviation
34 changes: 25 additions & 9 deletions src/causalprog/algorithms/expectation.py
Original file line number Diff line number Diff line change
@@ -1,21 +1,19 @@
"""Algorithm for estimating the expectation of a process represented by a graph."""
"""Algorithms for estimating the expectation and standard deviation."""

import typing

if typing.TYPE_CHECKING:
import numpy.typing as npt
import numpy as np
import numpy.typing as npt

from causalprog.graph import Graph

from .iteration import roots_down_to_outcome


def expectation(
def sample(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
) -> float:
"""Estimate the expectation of a graph."""
) -> npt.NDArray[float]:
"""Sample data from a graph."""
if outcome_node_label is None:
outcome_node_label = graph.outcome.label

@@ -24,4 +22,22 @@ def expectation(
values: dict[str, npt.NDArray[float]] = {}
for node in nodes:
values[node.label] = node.sample(values, samples)
return values[outcome_node_label].sum() / samples
return values[outcome_node_label]


def expectation(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
) -> float:
"""Estimate the expectation of a graph."""
return sample(graph, outcome_node_label, samples).mean()


def standard_deviation(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
) -> float:
"""Estimate the standard deviation of a graph."""
return np.std(sample(graph, outcome_node_label, samples))
10 changes: 5 additions & 5 deletions src/causalprog/algorithms/iteration.py
Original file line number Diff line number Diff line change
@@ -19,10 +19,10 @@ def roots_down_to_outcome(
while n < len(nodes_need_sampling):
new_n = len(nodes_need_sampling)
for node in nodes_need_sampling[n:]:
if node in pre:
for parent in pre[node]:
if parent not in nodes_need_sampling:
nodes_need_sampling.append(parent)
if node in pre and pre[node] not in nodes_need_sampling:
nodes_need_sampling.append(pre[node])
n = new_n

return [node for node in graph.depth_first_nodes if node in nodes_need_sampling]
return [
node for node in graph.depth_first_nodes[::-1] if node in nodes_need_sampling
]
2 changes: 1 addition & 1 deletion src/causalprog/graph/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,4 @@
"""Creation and storage of graphs."""

from .graph import Graph
from .node import DistributionNode, Node, RootDistributionNode
from .node import DistributionNode, Node
2 changes: 1 addition & 1 deletion src/causalprog/graph/graph.py
Original file line number Diff line number Diff line change
@@ -13,7 +13,7 @@ class Graph:
def __init__(self, label: str) -> None:
"""Create end empty graph."""
self._label = label
self._graph = nx.Graph()
self._graph = nx.DiGraph()
self._nodes_by_label = {}
self._node_index = 0

87 changes: 40 additions & 47 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
@@ -2,21 +2,47 @@

from __future__ import annotations

import typing
from abc import ABC, abstractmethod

import numpy as np

if typing.TYPE_CHECKING:
import numpy.typing as npt

class DistributionFamily:

class Distribution(ABC):
"""Placeholder class."""

@abstractmethod
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> npt.NDArray[float]:
"""Sample."""


class Distribution:
"""Placeholder class."""
class NormalDistribution(Distribution):
"""Normal distribution."""

def sample(self, samples: int) -> float:
def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None:
"""Initialise."""
self.mean = mean
self.std_dev = std_dev

def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> npt.NDArray[float]:
"""Sample a normal distribution with mean 1."""
return np.random.normal(1.0, 1.0, samples) # noqa: NPY002
values = np.random.normal(0.0, 1.0, samples) # noqa: NPY002
if isinstance(self.std_dev, str):
values *= sampled_dependencies[self.std_dev]
else:
values *= self.std_dev
if isinstance(self.mean, str):
values += sampled_dependencies[self.mean]
else:
values += self.mean
return values


class Node(ABC):
@@ -36,22 +62,19 @@ def label(self) -> str:
return self._label

@abstractmethod
def sample(self, sampled_dependencies: dict[str, float], samples: int) -> float:
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> float:
"""Sample a value from the node."""

@property
@abstractmethod
def is_root(self) -> bool:
"""Identify if the node is a root."""

@property
def is_outcome(self) -> bool:
"""Identify if the node is an outcome."""
return self._is_outcome


class RootDistributionNode(Node):
"""A root node containing a distribution family."""
class DistributionNode(Node):
"""A node containing a distribution."""

def __init__(
self,
@@ -64,41 +87,11 @@ def __init__(
self._dist = distribution
super().__init__(label, is_outcome=is_outcome)

def sample(self, _sampled_dependencies: dict[str, float], samples: int) -> float:
"""Sample a value from the node."""
return self._dist.sample(samples)

def __repr__(self) -> str:
return f'RootDistributionNode("{self.label}")'

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
return True


class DistributionNode(Node):
"""A node containing a distribution family that depends on its parents."""

def __init__(
self,
family: DistributionFamily,
label: str | None = None,
*,
is_outcome: bool = False,
) -> None:
"""Initialise."""
self._dfamily = family
super().__init__(label, is_outcome=is_outcome)

def sample(self, sampled_dependencies: dict[str, float], samples: int) -> float:
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> float:
"""Sample a value from the node."""
raise NotImplementedError
return self._dist.sample(sampled_dependencies, samples)

def __repr__(self) -> str:
return f'DistributionNode("{self.label}")'

@property
def is_root(self) -> bool:
"""Identify if the node is a root."""
return False
89 changes: 62 additions & 27 deletions tests/test_graph.py
Original file line number Diff line number Diff line change
@@ -9,11 +9,11 @@


def test_label():
family = causalprog.graph.node.DistributionFamily()
node = causalprog.graph.RootDistributionNode(family)
node2 = causalprog.graph.RootDistributionNode(family, "node1")
node3 = causalprog.graph.RootDistributionNode(family, "Y")
node4 = causalprog.graph.DistributionNode(family)
d = causalprog.graph.node.NormalDistribution()
node = causalprog.graph.DistributionNode(d)
node2 = causalprog.graph.DistributionNode(d, "node1")
node3 = causalprog.graph.DistributionNode(d, "Y")
node4 = causalprog.graph.DistributionNode(d)
node_copy = node

assert node._label == node_copy._label # noqa: SLF001
@@ -41,22 +41,22 @@ def test_label():


def test_duplicate_label():
family = causalprog.graph.node.DistributionFamily()
d = causalprog.graph.node.NormalDistribution()

graph = causalprog.graph.Graph("G0")
graph.add_node(causalprog.graph.RootDistributionNode(family, "X"))
graph.add_node(causalprog.graph.DistributionNode(d, "X"))
with pytest.raises(ValueError, match=re.escape("Duplicate node label: X")):
graph.add_node(causalprog.graph.RootDistributionNode(family, "X"))
graph.add_node(causalprog.graph.DistributionNode(d, "X"))


def test_simple_graph():
family = causalprog.graph.node.DistributionFamily()
n_x = causalprog.graph.RootDistributionNode(family, "N_X")
n_m = causalprog.graph.RootDistributionNode(family, "N_M")
u_y = causalprog.graph.RootDistributionNode(family, "U_Y")
x = causalprog.graph.DistributionNode(family, "X")
m = causalprog.graph.DistributionNode(family, "M")
y = causalprog.graph.DistributionNode(family, "Y", is_outcome=True)
d = causalprog.graph.node.NormalDistribution()
n_x = causalprog.graph.DistributionNode(d, "N_X")
n_m = causalprog.graph.DistributionNode(d, "N_M")
u_y = causalprog.graph.DistributionNode(d, "U_Y")
x = causalprog.graph.DistributionNode(d, "X")
m = causalprog.graph.DistributionNode(d, "M")
y = causalprog.graph.DistributionNode(d, "Y", is_outcome=True)

graph = causalprog.graph.Graph("G0")
graph.add_edge(n_x, x)
@@ -69,15 +69,15 @@ def test_simple_graph():


def test_simple_graph_build_using_labels():
family = causalprog.graph.node.DistributionFamily()
d = causalprog.graph.node.NormalDistribution()

graph = causalprog.graph.Graph("G0")
graph.add_node(causalprog.graph.RootDistributionNode(family, "N_X"))
graph.add_node(causalprog.graph.RootDistributionNode(family, "N_M"))
graph.add_node(causalprog.graph.RootDistributionNode(family, "U_Y"))
graph.add_node(causalprog.graph.DistributionNode(family, "X"))
graph.add_node(causalprog.graph.DistributionNode(family, "M"))
graph.add_node(causalprog.graph.DistributionNode(family, "Y", is_outcome=True))
graph.add_node(causalprog.graph.DistributionNode(d, "N_X"))
graph.add_node(causalprog.graph.DistributionNode(d, "N_M"))
graph.add_node(causalprog.graph.DistributionNode(d, "U_Y"))
graph.add_node(causalprog.graph.DistributionNode(d, "X"))
graph.add_node(causalprog.graph.DistributionNode(d, "M"))
graph.add_node(causalprog.graph.DistributionNode(d, "Y", is_outcome=True))

graph.add_edge("N_X", "X")
graph.add_edge("N_M", "M")
@@ -88,22 +88,57 @@ def test_simple_graph_build_using_labels():
assert graph.label == "G0"


@pytest.mark.parametrize("mean", [1.0, 2.0])
@pytest.mark.parametrize("stdev", [0.8, 1.0])
@pytest.mark.parametrize(
("samples", "rtol"),
[
(10, 1),
(1000, 1e-1),
(100000, 1e-2),
(100000, 1e-3),
(10000000, 1e-3),
],
)
def test_single_normal_node(samples, rtol):
normal = causalprog.graph.node.Distribution()
node = causalprog.graph.RootDistributionNode(normal, "X", is_outcome=True)
def test_single_normal_node(samples, rtol, mean, stdev):
normal = causalprog.graph.node.NormalDistribution(mean, stdev)
node = causalprog.graph.DistributionNode(normal, "X", is_outcome=True)

graph = causalprog.graph.Graph("G0")
graph.add_node(node)

assert np.isclose(
causalprog.algorithms.expectation(graph, samples=samples), 1.0, rtol=rtol
causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol
)
assert np.isclose(
causalprog.algorithms.standard_deviation(graph, samples=samples),
stdev,
rtol=rtol,
)


@pytest.mark.parametrize("mean", [1.0, 2.0])
@pytest.mark.parametrize("stdev", [0.8, 1.0])
@pytest.mark.parametrize("stdev2", [0.8, 1.0])
@pytest.mark.parametrize(
("samples", "rtol"),
[
(100, 1),
(10000, 1e-1),
(1000000, 1e-2),
],
)
def test_two_node_graph(samples, rtol, mean, stdev, stdev2):
normal = causalprog.graph.node.NormalDistribution(mean, stdev)
normal2 = causalprog.graph.node.NormalDistribution("UX", stdev2)

graph = causalprog.graph.Graph("G0")
graph.add_node(causalprog.graph.DistributionNode(normal, "UX"))
graph.add_node(causalprog.graph.DistributionNode(normal, "Y"))
graph.add_node(causalprog.graph.DistributionNode(normal, "Z"))
graph.add_node(causalprog.graph.DistributionNode(normal2, "X", is_outcome=True))
graph.add_edge("UX", "X")
graph.add_edge("Y", "Z")

assert np.isclose(
causalprog.algorithms.expectation(graph, samples=samples), mean, rtol=rtol
)