Skip to content

Commit

Permalink
Merge pull request #127 from thomasckng/transform
Browse files Browse the repository at this point in the history
Fix jim output functions
  • Loading branch information
kazewong authored Aug 4, 2024
2 parents e1800da + 0a9f58f commit 1b79a3a
Show file tree
Hide file tree
Showing 4 changed files with 10 additions and 5 deletions.
6 changes: 3 additions & 3 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,7 +155,7 @@ def print_summary(self, transform: bool = True):
training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.add_name(training_chain)
if transform:
for sample_transform in self.sample_transforms:
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
Expand All @@ -165,7 +165,7 @@ def print_summary(self, transform: bool = True):
production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.add_name(production_chain)
if transform:
for sample_transform in self.sample_transforms:
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
Expand Down Expand Up @@ -224,7 +224,7 @@ def get_samples(self, training: bool = False) -> dict:

chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in self.sample_transforms:
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
return chains

Expand Down
2 changes: 2 additions & 0 deletions test/integration/test_GW150914_D.py
Original file line number Diff line number Diff line change
Expand Up @@ -126,3 +126,5 @@
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
2 changes: 2 additions & 0 deletions test/integration/test_GW150914_D_heterodyne.py
Original file line number Diff line number Diff line change
Expand Up @@ -131,3 +131,5 @@
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()
Original file line number Diff line number Diff line change
Expand Up @@ -9,8 +9,7 @@
from jimgw.single_event.likelihood import TransientLikelihoodFD
from jimgw.single_event.waveform import RippleIMRPhenomD
from jimgw.transforms import BoundToUnbound
from jimgw.single_event.transforms import ComponentMassesToChirpMassSymmetricMassRatioTransform, SkyFrameToDetectorFrameSkyPositionTransform, ComponentMassesToChirpMassMassRatioTransform, MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform
from jimgw.single_event.utils import Mc_q_to_m1_m2
from jimgw.single_event.transforms import MassRatioToSymmetricMassRatioTransform, SpinToCartesianSpinTransform
from flowMC.strategy.optimization import optimization_Adam

jax.config.update("jax_enable_x64", True)
Expand Down Expand Up @@ -139,3 +138,5 @@
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()

0 comments on commit 1b79a3a

Please sign in to comment.