Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make distributions and graph talk to each other #30

Open
wants to merge 33 commits into
base: main
Choose a base branch
from

Conversation

mscroggs
Copy link
Collaborator

@mscroggs mscroggs commented Mar 20, 2025

Builds on top of #16.

Removes placeholder distribution classes in graph, and plugs in the proper distribution classes in their place

@mscroggs mscroggs marked this pull request as draft March 20, 2025 11:14
@mscroggs mscroggs changed the base branch from main to mscroggs/normal-example March 20, 2025 11:14
outcome_node_label: str,
samples: int,
*,
rng_key: jax.Array,
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

See #24

concrete_dist = self._dist.construct(
**parameters, **self._constant_parameters
)
output[sample] = concrete_dist.sample(new_key[sample], 1)[0][0]
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#25

return concrete_dist.sample(rng_key, samples)
output = np.zeros(samples)
new_key = jax.random.split(rng_key, samples)
for sample in range(samples):
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#26

new_key = jax.random.split(rng_key, samples)
for sample in range(samples):
parameters = {
i: sampled_dependencies[j][sample] for i, j in self._parameters.items()
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

#29

Base automatically changed from mscroggs/normal-example to main March 20, 2025 14:09
@mscroggs mscroggs marked this pull request as ready for review March 20, 2025 14:58
@mscroggs mscroggs requested a review from willGraham01 March 20, 2025 14:58
@@ -76,7 +78,7 @@ def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal, family_name="Normal")

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

mypy was very unhappy with everything I tried putting here. @willGraham01: Any idea what it should be?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Also, we originally had this as the following, but either ruff or mypy was not happy that the base class didn't have the *,, but having *, **parameters in the base class is a syntax error...

Suggested change
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
def construct(self, *, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is mypy expects any overwritten methods to retain a signature that is compatible with the signature in the base class. Personally I would have also thought that *, mean, cov was compatible with **parameters (since they're all keyword arguments) but in hindsight I can see why it's not (since **parameters does not necessitate ANY arguments, but *, mean, cov necessitates 2).

I think (for now), leave things as they are (*, mean, cov). I can tackle this when I re-work the DistributionFamilys using the backend-agnostic classes (which should make this awkward hand-me-down of parameters redundant).

Copy link
Collaborator

@willGraham01 willGraham01 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks like a few changes from #16 are being re-overwritten here, but otherwise just a couple of shortenings.

With regards to the construct problems, I'll tackle them in my backend-agnostic PR anyway so happy to postpone with the # noqa until then.

@@ -76,7 +78,7 @@ def __init__(self) -> None:
"""Create a family of normal distributions."""
super().__init__(Normal, family_name="Normal")

def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal:
def construct(self, mean: ArrayCompatible, cov: ArrayCompatible) -> Normal: # type: ignore # noqa: PGH003
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this is mypy expects any overwritten methods to retain a signature that is compatible with the signature in the base class. Personally I would have also thought that *, mean, cov was compatible with **parameters (since they're all keyword arguments) but in hindsight I can see why it's not (since **parameters does not necessitate ANY arguments, but *, mean, cov necessitates 2).

I think (for now), leave things as they are (*, mean, cov). I can tackle this when I re-work the DistributionFamilys using the backend-agnostic classes (which should make this awkward hand-me-down of parameters redundant).

@@ -2,7 +2,7 @@
import pytest

from causalprog.distribution.base import SampleTranslator
from causalprog.distribution.family import DistributionFamily
from causalprog.distribution.normal import DistributionFamily
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
from causalprog.distribution.normal import DistributionFamily
from causalprog.distribution.family import DistributionFamily

Not sure why it's now importing from a different module.

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's nicer to import from family as that's where this is defined?

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants