diff --git a/src/jimgw/single_event/runManager.py b/src/jimgw/single_event/runManager.py index 3e5bc092..2fa00cd9 100644 --- a/src/jimgw/single_event/runManager.py +++ b/src/jimgw/single_event/runManager.py @@ -102,8 +102,10 @@ def __init__(self, **kwargs): raise ValueError("Injection mode requires injection parameters.") local_prior = self.initialize_prior() - local_likelihood = self.initialize_likelihood(local_prior) sample_transforms, likelihood_transforms = self.initialize_transforms() + local_likelihood = self.initialize_likelihood( + local_prior, sample_transforms, likelihood_transforms + ) self.jim = Jim( local_likelihood, local_prior, @@ -124,7 +126,12 @@ def load_from_path(self, path: str) -> SingleEventRun: ### Initialization functions ### - def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood: + def initialize_likelihood( + self, + prior: prior.CombinePrior, + sample_transforms: transforms.Transform, + likelihood_transforms: transforms.Transform, + ) -> SingleEventLiklihood: """ Since prior contains information about types, naming and ranges of parameters, some of the likelihood class require the prior to be initialized, such as the @@ -176,6 +183,8 @@ def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLikliho detectors, waveform, prior=prior, + sample_transforms=sample_transforms, + likelihood_transforms=likelihood_transforms, **self.run.likelihood_parameters, **self.run.data_parameters, ) @@ -221,6 +230,7 @@ def initialize_transforms( transform_class = getattr(transforms, transform["name"]) except AttributeError: raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() transform.pop("name") sample_transforms.append(transform_class(**transform)) if self.run.likelihood_transforms: @@ -239,6 +249,7 @@ def initialize_transforms( transform_class = getattr(transforms, transform["name"]) except AttributeError: raise ValueError(f"{transform['name']} not recognized.") + transform = transform.copy() transform.pop("name") likelihood_transforms.append(transform_class(**transform)) return sample_transforms, likelihood_transforms @@ -460,9 +471,13 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs): def save_summary(self, path: str = "", **kwargs): if path == "": path = self.run.path + "run_manager_summary.txt" + orig_stdout = sys.stdout sys.stdout = open(path, "wt") self.jim.print_summary() - for detector, SNR in zip(self.detectors, self.SNRs): - print("SNR of detector " + detector + " is " + str(SNR)) - networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) - print("network SNR is", networkSNR) + if self.run.injection: + for detector, SNR in zip(self.detectors, self.SNRs): + print("SNR of detector " + detector + " is " + str(SNR)) + networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5) + print("network SNR is", networkSNR) + sys.stdout.close() + sys.stdout = orig_stdout