Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make distributions and graph talk to each other #30

Merged
merged 34 commits into from
Mar 24, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
34 commits
Select commit Hold shift + click to select a range
03a751f
start working on node that store distributions
mscroggs Mar 13, 2025
ddb7d52
refactor graph, add add_node and add_edge methods
mscroggs Mar 14, 2025
e8a4592
get multiple samples at onde
mscroggs Mar 14, 2025
d7c5a60
parametrise test
mscroggs Mar 14, 2025
bd50473
make two distribution example
mscroggs Mar 14, 2025
02695eb
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 14, 2025
e78a06c
Merge branch 'main' into mscroggs/normal-example
mscroggs Mar 19, 2025
bbe9284
| None
mscroggs Mar 19, 2025
0a9a7b2
remove irrelevant nodes from test
mscroggs Mar 19, 2025
711ca46
add stdev to test
mscroggs Mar 19, 2025
e460f4f
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
ef667db
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
87fa819
make roots_down_to_outcome a method of the graph
mscroggs Mar 19, 2025
df339e7
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
4b457fc
Update src/causalprog/graph/graph.py
mscroggs Mar 19, 2025
b222449
default false
mscroggs Mar 19, 2025
ac5dbdc
Merge branch 'mscroggs/normal-example' of github.com:UCL/causalprog i…
mscroggs Mar 19, 2025
fa7bfe7
don't allow temporary None labels
mscroggs Mar 19, 2025
ed93108
simpler test
mscroggs Mar 19, 2025
ed02c31
improve tests, and simplify iteration code
mscroggs Mar 19, 2025
cc2f1ec
reduce number of tests
mscroggs Mar 19, 2025
8c01b7d
number of samples must be int, don't use 0 for mean so that relative …
mscroggs Mar 20, 2025
06a3946
ruff
mscroggs Mar 20, 2025
8c7a517
add moved file
mscroggs Mar 20, 2025
c519c6e
Merge branch 'main' into mscroggs/distributions
mscroggs Mar 20, 2025
5651633
fix merge
mscroggs Mar 20, 2025
6fb79be
Update src/causalprog/_abc/labelled.py
mscroggs Mar 21, 2025
969f0fe
Update src/causalprog/graph/graph.py
mscroggs Mar 21, 2025
66d9317
Update src/causalprog/graph/node.py
mscroggs Mar 21, 2025
ee6cfef
Update src/causalprog/graph/node.py
mscroggs Mar 21, 2025
caba090
Update tests/test_graph.py
mscroggs Mar 21, 2025
b45ce2c
Merge branch 'main' into mscroggs/distributions
mscroggs Mar 21, 2025
f6e69f9
noqa
mscroggs Mar 21, 2025
0707856
Update tests/test_distributions/test_family.py
mscroggs Mar 24, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 19 additions & 11 deletions src/causalprog/algorithms/expectation.py
Original file line number Diff line number Diff line change
@@ -1,15 +1,17 @@
"""Algorithms for estimating the expectation and standard deviation."""

import numpy as np
import jax
import numpy.typing as npt

from causalprog.graph import Graph


def sample(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #24

) -> npt.NDArray[float]:
"""Sample data from a graph."""
if outcome_node_label is None:
Expand All @@ -18,24 +20,30 @@ def sample(
nodes = graph.roots_down_to_outcome(outcome_node_label)

values: dict[str, npt.NDArray[float]] = {}
for node in nodes:
values[node.label] = node.sample(values, samples)
keys = jax.random.split(rng_key, len(nodes))

for node, key in zip(nodes, keys, strict=False):
values[node.label] = node.sample(values, samples, rng_key=key)
return values[outcome_node_label]


def expectation(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
) -> float:
"""Estimate the expectation of a graph."""
return sample(graph, outcome_node_label, samples).mean()
return sample(graph, outcome_node_label, samples, rng_key=rng_key).mean()


def standard_deviation(
graph: Graph,
outcome_node_label: str | None = None,
samples: int = 1000,
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
) -> float:
"""Estimate the standard deviation of a graph."""
return np.std(sample(graph, outcome_node_label, samples))
return sample(graph, outcome_node_label, samples, rng_key=rng_key).std()
10 changes: 5 additions & 5 deletions src/causalprog/distribution/family.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,8 @@ class DistributionFamily(Generic[CreatesDistribution], Labelled):
@property
def _member(self) -> Callable[..., Distribution]:
"""Constructor method for family members, given parameters."""
return lambda *parameters: Distribution(
self._family(*parameters),
return lambda **parameters: Distribution(
self._family(**parameters),
backend_translator=self._family_translator,
)

Expand All @@ -67,13 +67,13 @@ def __init__(
self._family = backend_family
self._family_translator = backend_translator

def construct(self, *parameters: ArrayLike) -> Distribution:
def construct(self, **parameters: ArrayLike) -> Distribution:
"""
Create a distribution from an explicit set of parameters.

Args:
*parameters (ArrayLike): Parameters that define a member of this family,
**parameters (ArrayLike): Parameters that define a member of this family,
passed as sequential arguments.

"""
return self._member(*parameters)
return self._member(**parameters)
6 changes: 4 additions & 2 deletions src/causalprog/distribution/normal.py
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,8 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None:
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.

"""
mean = jnp.atleast_1d(mean)
cov = jnp.atleast_2d(cov)
super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal")


Expand All @@ -76,7 +78,7 @@ def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal, family_name="Normal")

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy was very unhappy with everything I tried putting here. @willGraham01: Any idea what it should be?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we originally had this as the following, but either ruff or mypy was not happy that the base class didn't have the *,, but having *, **parameters in the base class is a syntax error...

Suggested change
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
def construct(self, *, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is mypy expects any overwritten methods to retain a signature that is compatible with the signature in the base class. Personally I would have also thought that *, mean, cov was compatible with **parameters (since they're all keyword arguments) but in hindsight I can see why it's not (since **parameters does not necessitate ANY arguments, but *, mean, cov necessitates 2).

I think (for now), leave things as they are (*, mean, cov). I can tackle this when I re-work the DistributionFamilys using the backend-agnostic classes (which should make this awkward hand-me-down of parameters redundant).

r"""
Construct a normal distribution with the given mean and covariates.

Expand All @@ -85,4 +87,4 @@ def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
cov (ArrayCompatible): Matrix of covariates, $\Sigma$.

"""
return super().construct(mean, cov)
return super().construct(mean=mean, cov=cov)
72 changes: 32 additions & 40 deletions src/causalprog/graph/node.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,48 +3,17 @@
from __future__ import annotations

import typing
from abc import ABC, abstractmethod
from abc import abstractmethod

import jax
import numpy as np

if typing.TYPE_CHECKING:
import numpy.typing as npt

from causalprog._abc.labelled import Labelled


class Distribution(ABC):
"""Placeholder class."""

@abstractmethod
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> npt.NDArray[float]:
"""Sample."""


class NormalDistribution(Distribution):
"""Normal distribution."""
from causalprog.distribution.family import DistributionFamily

def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None:
"""Initialise."""
self.mean = mean
self.std_dev = std_dev

def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> npt.NDArray[float]:
"""Sample a normal distribution with mean 1."""
values = np.random.normal(0.0, 1.0, samples) # noqa: NPY002
if isinstance(self.std_dev, str):
values *= sampled_dependencies[self.std_dev]
else:
values *= self.std_dev
if isinstance(self.mean, str):
values += sampled_dependencies[self.mean]
else:
values += self.mean
return values
from causalprog._abc.labelled import Labelled


class Node(Labelled):
Expand All @@ -57,7 +26,10 @@ def __init__(self, label: str, *, is_outcome: bool = False) -> None:

@abstractmethod
def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
self,
sampled_dependencies: dict[str, npt.NDArray[float]],
samples: int,
rng_key: jax.Array,
) -> float:
"""Sample a value from the node."""

Expand All @@ -72,20 +44,40 @@ class DistributionNode(Node):

def __init__(
self,
distribution: Distribution,
distribution: DistributionFamily,
label: str,
*,
parameters: dict[str, str] | None = None,
constant_parameters: dict[str, float] | None = None,
is_outcome: bool = False,
) -> None:
"""Initialise."""
self._dist = distribution
self._constant_parameters = constant_parameters if constant_parameters else {}
self._parameters = parameters if parameters else {}
super().__init__(label, is_outcome=is_outcome)

def sample(
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int
) -> float:
self,
sampled_dependencies: dict[str, npt.NDArray[float]],
samples: int,
rng_key: jax.Array,
) -> npt.NDArray[float]:
"""Sample a value from the node."""
return self._dist.sample(sampled_dependencies, samples)
if not self._parameters:
concrete_dist = self._dist.construct(**self._constant_parameters)
return concrete_dist.sample(rng_key, samples)
output = np.zeros(samples)
new_key = jax.random.split(rng_key, samples)
for sample in range(samples):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#26

parameters = {
i: sampled_dependencies[j][sample] for i, j in self._parameters.items()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#29

}
concrete_dist = self._dist.construct(
**parameters, **self._constant_parameters
)
output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#25

return output

def __repr__(self) -> str:
return f'DistributionNode("{self.label}")'
4 changes: 2 additions & 2 deletions tests/test_distributions/conftest.py → tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@ def rng_key(seed: int):


@pytest.fixture
def n_dim_std_normal(request) -> tuple[Array, Array]:
def n_dim_std_normal(request) -> dict[str, Array]:
"""
Mean and covariance matrix of the n-dimensional standard normal distribution.
Expand All @@ -24,4 +24,4 @@ def n_dim_std_normal(request) -> tuple[Array, Array]:
n_dims = request.param
mean = jnp.array([0.0] * n_dims)
cov = jnp.diag(jnp.array([1.0] * n_dims))
return mean, cov
return {"mean": mean, "cov": cov}
6 changes: 4 additions & 2 deletions tests/test_distributions/test_family.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,8 +19,10 @@ def test_builder_matches_backend(n_dim_std_normal) -> None:
mnv = distrax.MultivariateNormalFullCovariance

mnv_family = DistributionFamily(mnv, SampleTranslator(rng_key="seed"))
via_family = mnv_family.construct(*n_dim_std_normal)
via_backend = mnv(*n_dim_std_normal)
via_family = mnv_family.construct(
loc=n_dim_std_normal["mean"], covariance_matrix=n_dim_std_normal["cov"]
)
via_backend = mnv(n_dim_std_normal["mean"], n_dim_std_normal["cov"])

assert via_backend.kl_divergence(via_family.get_dist()) == pytest.approx(0.0)
assert via_family.get_dist().kl_divergence(via_backend) == pytest.approx(0.0)
Loading