Skip to content

Commit

Permalink
Reformatted files
Browse files Browse the repository at this point in the history
  • Loading branch information
xuyuon committed Jul 25, 2024
1 parent 227fc98 commit 27998ca
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 20 deletions.
5 changes: 3 additions & 2 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,6 @@
]
)


run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
Expand Down Expand Up @@ -91,6 +90,8 @@

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()
run_manager.jim.get_samples()

41 changes: 23 additions & 18 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -67,16 +67,14 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
@dataclass
class SingleEventRun:
seed: int

detectors: list[str]
priors: dict[
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
path: str = "./experiment"
injection_parameters: dict[str, float] = field(
default_factory=lambda: {}
)
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
default_factory=lambda: {"name": "TransientLikelihoodFD"}
Expand Down Expand Up @@ -127,7 +125,7 @@ def __init__(self, **kwargs):
else:
print("Neither run instance nor path provided.")
raise ValueError

if self.run.injection and not self.run.injection_parameters:
raise ValueError("Injection mode requires injection parameters.")

Expand Down Expand Up @@ -359,14 +357,14 @@ def plot_data(self, path: str):
plt.ylabel("Amplitude")
plt.legend()
plt.savefig(path)

def sample(self):
self.jim.sample(jax.random.PRNGKey(self.run.seed))

def get_samples(self):
return self.jim.get_samples()
def plot_corner(self, path: str="corner.jpeg", **kwargs):

def plot_corner(self, path: str = "corner.jpeg", **kwargs):
"""
plot corner plot of the samples.
"""
Expand All @@ -375,29 +373,38 @@ def plot_corner(self, path: str="corner.jpeg", **kwargs):
show_titles = kwargs.get("show_titles", True)
title_fmt = kwargs.get("title_fmt", "g")
use_math_text = kwargs.get("use_math_text", True)

samples = self.jim.get_samples()
param_names = list(samples.keys())
samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T
corner.corner(samples, labels=param_names, plot_datapoints=plot_datapoint, title_quantiles=title_quantiles, show_titles=show_titles, title_fmt=title_fmt, use_math_text=use_math_text, **kwargs)
corner.corner(
samples,
labels=param_names,
plot_datapoints=plot_datapoint,
title_quantiles=title_quantiles,
show_titles=show_titles,
title_fmt=title_fmt,
use_math_text=use_math_text,
**kwargs,
)
plt.savefig(path)
plt.close()
def plot_diagnostic(self, path: str="diagnostic.jpeg", **kwargs):

def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

plt.figure(figsize=(6, 6))
axs = [plt.subplot(2, 2, i + 1) for i in range(4)]
plt.sca(axs[0])
plt.title("log probability")
plt.plot(log_prob.mean(0))
plt.xlabel("iteration")

plt.sca(axs[1])
plt.title("NF loss")
plt.plot(loss_vals.reshape(-1))
Expand All @@ -413,8 +420,6 @@ def plot_diagnostic(self, path: str="diagnostic.jpeg", **kwargs):
plt.plot(global_accs.mean(0))
plt.xlabel("iteration")
plt.tight_layout()

plt.savefig(path)
plt.close()


0 comments on commit 27998ca

Please sign in to comment.