Skip to content

Commit

Permalink
Merge pull request #135 from thomasckng/run-manager
Browse files Browse the repository at this point in the history
Update Run Manager
  • Loading branch information
ThibeauWouters authored Aug 22, 2024
2 parents 38bb277 + e5b6e84 commit 15a103d
Show file tree
Hide file tree
Showing 8 changed files with 327 additions and 418 deletions.
108 changes: 53 additions & 55 deletions example/Single_event_runManager.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,3 @@

import jax
import jax.numpy as jnp

Expand All @@ -12,57 +11,50 @@
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
mass_matrix = mass_matrix * 3e-3
local_sampler_arg = {"step_size": mass_matrix}
bounds = jnp.array(
[
[10.0, 40.0],
[0.125, 1.0],
[-1.0, 1.0],
[-1.0, 1.0],
[0.0, 2000.0],
[-0.05, 0.05],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
[0.0, jnp.pi],
[0.0, 2 * jnp.pi],
[-1.0, 1.0],
]
)

run = SingleEventRun(
seed=0,
detectors=["H1", "L1"],
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
},
priors={
"M_c": {"name": "Unconstrained_Uniform", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "MassRatio"},
"s1_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "Unconstrained_Uniform", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2000.0},
"t_c": {"name": "Unconstrained_Uniform", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"cos_iota": {"name": "CosIota"},
"psi": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "Unconstrained_Uniform", "xmin": 0.0, "xmax": 2 * jnp.pi},
"sin_dec": {"name": "SinDec"},
"M_c": {"name": "UniformPrior", "xmin": 10.0, "xmax": 80.0},
"q": {"name": "UniformPrior", "xmin": 0.0, "xmax": 1.0},
"s1_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"s2_z": {"name": "UniformPrior", "xmin": -1.0, "xmax": 1.0},
"d_L": {"name": "UniformPrior", "xmin": 1.0, "xmax": 2000.0},
"t_c": {"name": "UniformPrior", "xmin": -0.05, "xmax": 0.05},
"phase_c": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"iota": {"name": "SinePrior"},
"psi": {"name": "UniformPrior", "xmin": 0.0, "xmax": jnp.pi},
"ra": {"name": "UniformPrior", "xmin": 0.0, "xmax": 2 * jnp.pi},
"dec": {"name": "CosinePrior"},
},
waveform_parameters={"name": "RippleIMRPhenomD", "f_ref": 20.0},
jim_parameters={
"n_loop_training": 10,
"n_loop_production": 10,
"n_local_steps": 15,
"n_global_steps": 15,
"n_chains": 500,
"n_epochs": 10,
"learning_rate": 0.001,
"n_max_examples": 45000,
"momentum": 0.9,
"batch_size": 50000,
"use_global": True,
"keep_quantile": 0.0,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
likelihood_parameters={"name": "TransientLikelihoodFD", "bounds": bounds},
likelihood_parameters={"name": "TransientLikelihoodFD"},
sample_transforms=[
{"name": "BoundToUnbound", "name_mapping": [["M_c"], ["M_c_unbounded"]], "original_lower_bound": 10.0, "original_upper_bound": 80.0,},
{"name": "BoundToUnbound", "name_mapping": [["q"], ["q_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s1_z"], ["s1_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["s2_z"], ["s2_z_unbounded"]], "original_lower_bound": -1.0, "original_upper_bound": 1.0,},
{"name": "BoundToUnbound", "name_mapping": [["d_L"], ["d_L_unbounded"]], "original_lower_bound": 1.0, "original_upper_bound": 2000.0,},
{"name": "BoundToUnbound", "name_mapping": [["t_c"], ["t_c_unbounded"]], "original_lower_bound": -0.05, "original_upper_bound": 0.05,},
{"name": "BoundToUnbound", "name_mapping": [["phase_c"], ["phase_c_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["iota"], ["iota_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["psi"], ["psi_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["ra"], ["ra_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": 2 * jnp.pi,},
{"name": "BoundToUnbound", "name_mapping": [["dec"], ["dec_unbounded"]], "original_lower_bound": 0.0, "original_upper_bound": jnp.pi,},
],
likelihood_transforms=[
{"name": "MassRatioToSymmetricMassRatioTransform", "name_mapping": [["q"], ["eta"]]},
],
injection=True,
injection_parameters={
"M_c": 28.6,
Expand All @@ -77,22 +69,28 @@
"ra": 1.2,
"dec": 0.3,
},
data_parameters={
"trigger_time": 1126259462.4,
"duration": 4,
"post_trigger_duration": 2,
"f_min": 20.0,
"f_max": 1024.0,
"tukey_alpha": 0.2,
"f_sampling": 4096.0,
jim_parameters={
"n_loop_training": 100,
"n_loop_production": 20,
"n_local_steps": 10,
"n_global_steps": 1000,
"n_chains": 500,
"n_epochs": 30,
"learning_rate": 1e-4,
"n_max_examples": 30000,
"momentum": 0.9,
"batch_size": 30000,
"use_global": True,
"train_thinning": 1,
"output_thinning": 10,
"local_sampler_arg": local_sampler_arg,
},
)

run_manager = SingleEventPERunManager(run=run)
run_manager.jim.sample(jax.random.PRNGKey(42))
run_manager.sample()

# plot the corner plot and diagnostic plot
run_manager.plot_corner()
run_manager.plot_diagnostic()
run_manager.save_summary()

2 changes: 1 addition & 1 deletion src/jimgw/single_event/detector.py
Original file line number Diff line number Diff line change
Expand Up @@ -373,7 +373,7 @@ def inject_signal(
h_sky: dict[str, Float[Array, " n_sample"]],
params: dict[str, Float],
psd_file: str = "",
) -> None:
) -> tuple[Float, Float]:
"""
Inject a signal into the detector data.
Expand Down
154 changes: 87 additions & 67 deletions src/jimgw/single_event/runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,59 +12,22 @@
from jaxlib.xla_extension import ArrayImpl
from jaxtyping import Array, Float, PyTree

from jimgw import prior
from jimgw import prior, transforms
from jimgw.single_event import prior as single_event_prior
from jimgw.single_event import transforms as single_event_transforms
from jimgw.base import RunManager
from jimgw.jim import Jim
from jimgw.single_event.detector import Detector, detector_preset
from jimgw.single_event.likelihood import SingleEventLiklihood, likelihood_presets
from jimgw.single_event.waveform import Waveform, waveform_preset



def jaxarray_representer(dumper: yaml.Dumper, data: ArrayImpl):
return dumper.represent_list(data.tolist())


yaml.add_representer(ArrayImpl, jaxarray_representer) # type: ignore

prior_presets = {
"Unconstrained_Uniform": prior.Unconstrained_Uniform,
"Uniform": prior.Uniform,
"Sphere": prior.Sphere,
"AlignedSpin": prior.AlignedSpin,
"PowerLaw": prior.PowerLaw,
"Composite": prior.Composite,
"MassRatio": lambda **kwargs: prior.Uniform(
0.125,
1.0,
naming=["q"],
transforms={"q": ("eta", lambda params: params["q"] / (1 + params["q"]) ** 2)},
),
"CosIota": lambda **kwargs: prior.Uniform(
-1.0,
1.0,
naming=["cos_iota"],
transforms={
"cos_iota": (
"iota",
lambda params: jnp.arccos(params["cos_iota"]),
)
},
),
"SinDec": lambda **kwargs: prior.Uniform(
-1.0,
1.0,
naming=["sin_dec"],
transforms={
"sin_dec": (
"dec",
lambda params: jnp.arcsin(params["sin_dec"]),
)
},
),
"EarthFrame": prior.EarthFrame,
}


@dataclass
class SingleEventRun:
Expand All @@ -75,7 +38,7 @@ class SingleEventRun:
str, dict[str, Union[str, float, int, bool]]
] # Transform cannot be included in this way, add it to preset if used often.
jim_parameters: dict[str, Union[str, float, int, bool, dict]]
path: str = "./experiment"
path: str = "single_event_run"
injection_parameters: dict[str, float] = field(default_factory=lambda: {})
injection: bool = False
likelihood_parameters: dict[str, Union[str, float, int, bool, PyTree]] = field(
Expand All @@ -95,7 +58,12 @@ class SingleEventRun:
"f_sampling": 4096.0,
}
)

sample_transforms: list[dict[str, Union[str, float, int, bool]]] = field(
default_factory=lambda: []
)
likelihood_transforms: list[dict[str, Union[str, float, int, bool]]] = field(
default_factory=lambda: []
)


class SingleEventPERunManager(RunManager):
Expand Down Expand Up @@ -135,7 +103,14 @@ def __init__(self, **kwargs):

local_prior = self.initialize_prior()
local_likelihood = self.initialize_likelihood(local_prior)
self.jim = Jim(local_likelihood, local_prior, **self.run.jim_parameters)
sample_transforms, likelihood_transforms = self.initialize_transforms()
self.jim = Jim(
local_likelihood,
local_prior,
sample_transforms,
likelihood_transforms,
**self.run.jim_parameters,
)

def save(self, path: str):
output_dict = asdict(self.run)
Expand All @@ -149,7 +124,7 @@ def load_from_path(self, path: str) -> SingleEventRun:

### Initialization functions ###

def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
def initialize_likelihood(self, prior: prior.CombinePrior) -> SingleEventLiklihood:
"""
Since prior contains information about types, naming and ranges of parameters,
some of the likelihood class require the prior to be initialized, such as the
Expand Down Expand Up @@ -192,11 +167,11 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
key, subkey = jax.random.split(jax.random.PRNGKey(self.run.seed + 1901))
SNRs = []
for detector in detectors:
optimal_SNR,_ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore
optimal_SNR, _ = detector.inject_signal(subkey, freqs, h_sky, detector_parameters) # type: ignore
SNRs.append(optimal_SNR)
key, subkey = jax.random.split(key)
self.SNRs = SNRs

return likelihood_presets[name](
detectors,
waveform,
Expand All @@ -205,23 +180,67 @@ def initialize_likelihood(self, prior: prior.Prior) -> SingleEventLiklihood:
**self.run.data_parameters,
)

def initialize_prior(self) -> prior.Prior:
def initialize_prior(self) -> prior.CombinePrior:
priors = []
for name, parameters in self.run.priors.items():
if parameters["name"] not in prior_presets:
raise ValueError(f"Prior {name} not recognized.")
if parameters["name"] == "EarthFrame":
priors.append(
prior.EarthFrame(
gps=self.run.data_parameters["trigger_time"],
ifos=self.run.detectors,
assert isinstance(
parameters, dict
), "Prior parameters must be a dictionary."
assert "name" in parameters, "Prior name must be provided."
assert isinstance(parameters["name"], str), "Prior name must be a string."
try:
prior_class = getattr(single_event_prior, parameters["name"])
except AttributeError:
try:
prior_class = getattr(prior, parameters["name"])
except AttributeError:
raise ValueError(f"{parameters['name']} not recognized.")
parameters.pop("name")
priors.append(prior_class(parameter_names=[name], **parameters))
return prior.CombinePrior(priors)

def initialize_transforms(
self,
) -> tuple[list[transforms.BijectiveTransform], list[transforms.NtoMTransform]]:
sample_transforms = []
likelihood_transforms = []
if self.run.sample_transforms:
for transform in self.run.sample_transforms:
assert isinstance(transform, dict), "Transform must be a dictionary."
assert "name" in transform, "Transform name must be provided."
assert isinstance(
transform["name"], str
), "Transform name must be a string."
try:
transform_class = getattr(
single_event_transforms, transform["name"]
)
)
else:
priors.append(
prior_presets[parameters["name"]](naming=[name], **parameters)
)
return prior.Composite(priors)
except AttributeError:
try:
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform.pop("name")
sample_transforms.append(transform_class(**transform))
if self.run.likelihood_transforms:
for transform in self.run.likelihood_transforms:
assert isinstance(transform, dict), "Transform must be a dictionary."
assert "name" in transform, "Transform name must be provided."
assert isinstance(
transform["name"], str
), "Transform name must be a string."
try:
transform_class = getattr(
single_event_transforms, transform["name"]
)
except AttributeError:
try:
transform_class = getattr(transforms, transform["name"])
except AttributeError:
raise ValueError(f"{transform['name']} not recognized.")
transform.pop("name")
likelihood_transforms.append(transform_class(**transform))
return sample_transforms, likelihood_transforms

def initialize_detector(self) -> list[Detector]:
"""
Expand Down Expand Up @@ -403,7 +422,7 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
"""
plot diagnostic plot of the samples.
"""
summary = self.jim.Sampler.get_sampler_state(training=True)
summary = self.jim.sampler.get_sampler_state(training=True)
chains, log_prob, local_accs, global_accs, loss_vals = summary.values()
log_prob = np.array(log_prob)

Expand Down Expand Up @@ -437,11 +456,12 @@ def plot_diagnostic(self, path: str = "diagnostic.jpeg", **kwargs):
plt.savefig(path)
plt.close()

def save_summary(self, path: str = "run_manager_summary.txt", **kwargs):
sys.stdout = open(path,'wt')
def save_summary(self, path: str = "", **kwargs):
if path == "":
path = self.run.path + "run_manager_summary.txt"
sys.stdout = open(path, "wt")
self.jim.print_summary()
#print(self.SNRs)
for detector, SNR in zip(self.detectors, self.SNRs):
print('SNR of detector ' + detector + ' is ' + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs)**2) ** (0.5)
print('network SNR is', networkSNR)
print("SNR of detector " + detector + " is " + str(SNR))
networkSNR = jnp.sum(jnp.array(self.SNRs) ** 2) ** (0.5)
print("network SNR is", networkSNR)
Loading

0 comments on commit 15a103d

Please sign in to comment.