-
Notifications
You must be signed in to change notification settings - Fork 0
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Make distributions and graph talk to each other #30
Changes from all commits
03a751f
ddb7d52
e8a4592
d7c5a60
bd50473
02695eb
e78a06c
bbe9284
0a9a7b2
711ca46
e460f4f
ef667db
87fa819
df339e7
4b457fc
b222449
ac5dbdc
fa7bfe7
ed93108
ed02c31
cc2f1ec
8c01b7d
06a3946
8c7a517
c519c6e
5651633
6fb79be
969f0fe
66d9317
ee6cfef
caba090
b45ce2c
f6e69f9
0707856
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change | ||||
---|---|---|---|---|---|---|
|
@@ -58,6 +58,8 @@ def __init__(self, mean: ArrayCompatible, cov: ArrayCompatible) -> None: | |||||
cov (ArrayCompatible): Matrix of covariates, $\Sigma$. | ||||||
|
||||||
""" | ||||||
mean = jnp.atleast_1d(mean) | ||||||
cov = jnp.atleast_2d(cov) | ||||||
super().__init__(_Normal(mean, cov), label=f"({mean.ndim}-dim) Normal") | ||||||
|
||||||
|
||||||
|
@@ -76,7 +78,7 @@ def __init__(self) -> None: | |||||
"""Create a family of normal distributions.""" | ||||||
super().__init__(Normal, family_name="Normal") | ||||||
|
||||||
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: | ||||||
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003 | ||||||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. mypy was very unhappy with everything I tried putting here. @willGraham01: Any idea what it should be? There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Also, we originally had this as the following, but either ruff or mypy was not happy that the base class didn't have the
Suggested change
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I think this is I think (for now), leave things as they are ( |
||||||
r""" | ||||||
Construct a normal distribution with the given mean and covariates. | ||||||
|
||||||
|
@@ -85,4 +87,4 @@ def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: | |||||
cov (ArrayCompatible): Matrix of covariates, $\Sigma$. | ||||||
|
||||||
""" | ||||||
return super().construct(mean, cov) | ||||||
return super().construct(mean=mean, cov=cov) |
Original file line number | Diff line number | Diff line change |
---|---|---|
|
@@ -3,48 +3,17 @@ | |
from __future__ import annotations | ||
|
||
import typing | ||
from abc import ABC, abstractmethod | ||
from abc import abstractmethod | ||
|
||
import jax | ||
import numpy as np | ||
|
||
if typing.TYPE_CHECKING: | ||
import numpy.typing as npt | ||
|
||
from causalprog._abc.labelled import Labelled | ||
|
||
|
||
class Distribution(ABC): | ||
"""Placeholder class.""" | ||
|
||
@abstractmethod | ||
def sample( | ||
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int | ||
) -> npt.NDArray[float]: | ||
"""Sample.""" | ||
|
||
|
||
class NormalDistribution(Distribution): | ||
"""Normal distribution.""" | ||
from causalprog.distribution.family import DistributionFamily | ||
|
||
def __init__(self, mean: str | float = 0.0, std_dev: str | float = 1.0) -> None: | ||
"""Initialise.""" | ||
self.mean = mean | ||
self.std_dev = std_dev | ||
|
||
def sample( | ||
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int | ||
) -> npt.NDArray[float]: | ||
"""Sample a normal distribution with mean 1.""" | ||
values = np.random.normal(0.0, 1.0, samples) # noqa: NPY002 | ||
if isinstance(self.std_dev, str): | ||
values *= sampled_dependencies[self.std_dev] | ||
else: | ||
values *= self.std_dev | ||
if isinstance(self.mean, str): | ||
values += sampled_dependencies[self.mean] | ||
else: | ||
values += self.mean | ||
return values | ||
from causalprog._abc.labelled import Labelled | ||
|
||
|
||
class Node(Labelled): | ||
|
@@ -57,7 +26,10 @@ def __init__(self, label: str, *, is_outcome: bool = False) -> None: | |
|
||
@abstractmethod | ||
def sample( | ||
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int | ||
self, | ||
sampled_dependencies: dict[str, npt.NDArray[float]], | ||
samples: int, | ||
rng_key: jax.Array, | ||
) -> float: | ||
"""Sample a value from the node.""" | ||
|
||
|
@@ -72,20 +44,40 @@ class DistributionNode(Node): | |
|
||
def __init__( | ||
self, | ||
distribution: Distribution, | ||
distribution: DistributionFamily, | ||
label: str, | ||
*, | ||
parameters: dict[str, str] | None = None, | ||
constant_parameters: dict[str, float] | None = None, | ||
is_outcome: bool = False, | ||
) -> None: | ||
"""Initialise.""" | ||
self._dist = distribution | ||
self._constant_parameters = constant_parameters if constant_parameters else {} | ||
self._parameters = parameters if parameters else {} | ||
super().__init__(label, is_outcome=is_outcome) | ||
|
||
def sample( | ||
self, sampled_dependencies: dict[str, npt.NDArray[float]], samples: int | ||
) -> float: | ||
self, | ||
sampled_dependencies: dict[str, npt.NDArray[float]], | ||
samples: int, | ||
rng_key: jax.Array, | ||
) -> npt.NDArray[float]: | ||
"""Sample a value from the node.""" | ||
return self._dist.sample(sampled_dependencies, samples) | ||
if not self._parameters: | ||
concrete_dist = self._dist.construct(**self._constant_parameters) | ||
return concrete_dist.sample(rng_key, samples) | ||
output = np.zeros(samples) | ||
new_key = jax.random.split(rng_key, samples) | ||
for sample in range(samples): | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
parameters = { | ||
i: sampled_dependencies[j][sample] for i, j in self._parameters.items() | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
} | ||
concrete_dist = self._dist.construct( | ||
**parameters, **self._constant_parameters | ||
) | ||
output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. |
||
return output | ||
|
||
def __repr__(self) -> str: | ||
return f'DistributionNode("{self.label}")' |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
See #24