diff --git a/example/Single_event_runManager.py b/example/Single_event_runManager.py index 3a89b561..5c678b22 100644 --- a/example/Single_event_runManager.py +++ b/example/Single_event_runManager.py @@ -28,7 +28,6 @@ ] ) - run = SingleEventRun( seed=0, detectors=["H1", "L1"], @@ -91,6 +90,8 @@ run_manager = SingleEventPERunManager(run=run) run_manager.jim.sample(jax.random.PRNGKey(42)) + +# plot the corner plot and diagnostic plot run_manager.plot_corner() run_manager.plot_diagnostic() -run_manager.jim.get_samples() + diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 124e9de5..66790c1d 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -67,16 +67,14 @@ def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl): @dataclass class SingleEventRun: seed: int - + detectors: list[str] priors: dict[ str, dict[str, Union[str, float, int, bool]] ] # Transform cannot be included in this way, add it to preset if used often. jim_parameters: dict[str, Union[str, float, int, bool, dict]] path: str = "./experiment" - injection_parameters: dict[str, float] = field( - default_factory=lambda: {} - ) + injection_parameters: dict[str, float] = field(default_factory=lambda: {}) injection: bool = False likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field( default_factory=lambda: {"name": "TransientLikelihoodFD"} @@ -127,7 +125,7 @@ def __init__(self, **kwargs): else: print("Neither run instance nor path provided.") raise ValueError - + if self.run.injection and not self.run.injection_parameters: raise ValueError("Injection mode requires injection parameters.") @@ -359,14 +357,14 @@ def plot_data(self, path: str): plt.ylabel("Amplitude") plt.legend() plt.savefig(path) - + def sample(self): self.jim.sample(jax.random.PRNGKey(self.run.seed)) - + def get_samples(self): return self.jim.get_samples() - - def plot_corner(self, path: str="corner.jpeg", **kwargs): + + def plot_corner(self, path: str = "corner.jpeg", **kwargs): """ plot corner plot of the samples. """ @@ -375,29 +373,38 @@ def plot_corner(self, path: str="corner.jpeg", **kwargs): show_titles = kwargs.get("show_titles", True) title_fmt = kwargs.get("title_fmt", "g") use_math_text = kwargs.get("use_math_text", True) - + samples = self.jim.get_samples() param_names = list(samples.keys()) samples = np.array(list(samples.values())).reshape(int(len(param_names)), -1).T - corner.corner(samples, labels=param_names, plot_datapoints=plot_datapoint, title_quantiles=title_quantiles, show_titles=show_titles, title_fmt=title_fmt, use_math_text=use_math_text, **kwargs) + corner.corner( + samples, + labels=param_names, + plot_datapoints=plot_datapoint, + title_quantiles=title_quantiles, + show_titles=show_titles, + title_fmt=title_fmt, + use_math_text=use_math_text, + **kwargs, + ) plt.savefig(path) plt.close() - - def plot_diagnostic(self, path: str="diagnostic.jpeg", **kwargs): + + def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): """ plot diagnostic plot of the samples. """ summary = self.jim.Sampler.get_sampler_state(training=True) chains, log_prob, local_accs, global_accs, loss_vals = summary.values() log_prob = np.array(log_prob) - + plt.figure(figsize=(6, 6)) axs = [plt.subplot(2, 2, i + 1) for i in range(4)] plt.sca(axs[0]) plt.title("log probability") plt.plot(log_prob.mean(0)) plt.xlabel("iteration") - + plt.sca(axs[1]) plt.title("NF loss") plt.plot(loss_vals.reshape(-1)) @@ -413,8 +420,6 @@ def plot_diagnostic(self, path: str="diagnostic.jpeg", **kwargs): plt.plot(global_accs.mean(0)) plt.xlabel("iteration") plt.tight_layout() - + plt.savefig(path) plt.close() - -