From 7afa8b681f6dafce7d5bad44fa7de7be759602c5 Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Fri, 30 Aug 2024 12:55:09 +0800 Subject: [PATCH 01/11] add population likelihood and model --- src/jimgw/population/population_analysis.py | 169 ++++++++++++++++++ src/jimgw/population/population_likelihood.py | 21 +++ src/jimgw/population/population_model.py | 39 ++++ src/jimgw/population/utils.py | 15 ++ 4 files changed, 244 insertions(+) create mode 100644 src/jimgw/population/population_analysis.py create mode 100644 src/jimgw/population/population_likelihood.py create mode 100644 src/jimgw/population/population_model.py create mode 100644 src/jimgw/population/utils.py diff --git a/src/jimgw/population/population_analysis.py b/src/jimgw/population/population_analysis.py new file mode 100644 index 00000000..44ac7ebe --- /dev/null +++ b/src/jimgw/population/population_analysis.py @@ -0,0 +1,169 @@ +import argparse +import pandas as pd +import numpy as np +import jax +import jax.numpy as jnp +import glob +from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline +from flowMC.sampler.MALA import MALA +from flowMC.sampler.Sampler import Sampler +from flowMC.utils.PRNG_keys import initialize_rng_keys +import corner +from jimgw.population.population_likelihood import PopulationLikelihood +from jimgw.population.utils import create_model + + +def parse_args(): + parser = argparse.ArgumentParser(description='Run population likelihood sampling.') + parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') + parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.') + return parser.parse_args() + +def mtotal_from_mchirp_eta(mchirp, eta): + """Returns the total mass from the chirp mass and symmetric mass ratio.""" + return mchirp / eta**(3./5.) 
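# A self-contained sanity check of the chirp-mass conversions above (an illustrative
# sketch with arbitrary toy values, not part of the committed file): for an equal-mass
# binary, eta = 0.25 and the primary mass equals half the total mass.
import numpy as np

def _mtotal_from_mchirp_eta(mchirp, eta):
    return mchirp / eta**(3.0 / 5.0)

def _mass1_from_mtotal_eta(mtotal, eta):
    return 0.5 * mtotal * (1.0 + (1.0 - 4.0 * eta)**0.5)

m1 = m2 = 10.0
eta = (m1 * m2) / (m1 + m2)**2           # symmetric mass ratio = 0.25 for equal masses
mchirp = (m1 + m2) * eta**(3.0 / 5.0)    # chirp mass from total mass and eta
assert np.isclose(_mtotal_from_mchirp_eta(mchirp, eta), m1 + m2)
assert np.isclose(_mass1_from_mtotal_eta(_mtotal_from_mchirp_eta(mchirp, eta), eta), m1)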
+ +def mass1_from_mtotal_eta(mtotal, eta): + """Returns the primary mass from the total mass and symmetric mass ratio.""" + return 0.5 * mtotal * (1.0 + (1.0 - 4.0 * eta)**0.5) + +def mass1_from_mchirp_eta(mchirp, eta): + """Returns the primary mass from the chirp mass and symmetric mass ratio.""" + mtotal = mtotal_from_mchirp_eta(mchirp, eta) + return mass1_from_mtotal_eta(mtotal, eta) + +def prior_alpha(alpha): + return jax.lax.cond(alpha > 0, lambda: 0.0, lambda: -jnp.inf) + +def prior_x_min_x_max(x_min, x_max): + cond_1 = (x_max > x_min) + cond_2 = (x_min >= 5) & (x_min <= 20) + cond_3 = (x_max >= 50) & (x_max <= 100) + + return jax.lax.cond(cond_1 & cond_2 & cond_3, lambda: 0.0, lambda: -jnp.inf) + +def main(): + # Parse command-line arguments + args = parse_args() + + # For sampling events + directory = args.data_dir # Use the data directory from command-line argument + key = jax.random.PRNGKey(42) + mass_result_dict = [] + npz_files = glob.glob(directory + '/*.npz') + + num_files_to_sample = 100 + key, subkey = jax.random.split(key) + sample_indices = jax.random.choice(subkey, len(npz_files), shape=(num_files_to_sample,), replace=False) + sampled_npz_files = [npz_files[i] for i in sample_indices] + + for npz_file in sampled_npz_files: + print("Loading file:", npz_file) + with np.load(npz_file, allow_pickle=True) as data: + chains = data['chains'] + reshaped_chains = chains.reshape(-1, 11) + event_df = pd.DataFrame(reshaped_chains, columns=[ + 'M_c', 'eta', 's1_z', 's2_z', 'd_L', 't_c', 'phase_c', + 'iota', 'psi', 'ra', 'dec' + ]) + + # Randomly sample rows within each file in a reproducible manner + key, subkey = jax.random.split(key) + sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(5000,), replace=False) + sampled_df = event_df.iloc[sample_indices] + + # Extract M_c and eta using sampled indices + mc_sampled = sampled_df['M_c'].values + eta_sampled = sampled_df['eta'].values + + # Compute mass1 + mass1_sampled = mass1_from_mchirp_eta(mc_sampled, eta_sampled) + + # Append to the result dictionary + mass_array = jnp.array(mass1_sampled) + mass_result_dict.append(mass_array) + + # Stack all results into a single array + mass_array = jnp.stack(mass_result_dict) + + def pop_likelihood(pop_params ,data): + model = create_model(args.pop_model) + likelihood = PopulationLikelihood(mass_array, model, pop_params) + log_likelihood = likelihood.evaluate(mass_array, pop_params) + return log_likelihood + + + # def log_likelihood(pop_params, data): + # likelihood = PopulationLikelihood(mass_array,TruncatedPowerLawModel, pop_params) + # log_likelihood = likelihood.evaluate(mass_array, pop_params) + # return log_likelihood + + + n_dim = 3 + n_chains = 1000 + + rng_key = jax.random.PRNGKey(42) + + minval_0th_dim = 5 + maxval_0th_dim = 20 + + minval_1st_dim = 50 + maxval_1st_dim = 100 + + minval_2nd_dim = 0 + maxval_2nd_dim = 4 + + initial_positions = [] + + while len(initial_positions) < n_chains: + rng_key, subkey = jax.random.split(rng_key) + samples_0th_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_0th_dim, maxval=maxval_0th_dim) + rng_key, subkey = jax.random.split(rng_key) + samples_1st_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_1st_dim, maxval=maxval_1st_dim) + + valid_indices = jnp.where((samples_1st_dim >= samples_0th_dim))[0] + valid_positions = jnp.column_stack([samples_0th_dim[valid_indices], samples_1st_dim[valid_indices]]) + + remaining_chains_needed = n_chains - len(initial_positions) + if len(valid_positions) >= 
remaining_chains_needed: + valid_positions = valid_positions[:remaining_chains_needed] + + initial_positions.extend(valid_positions.tolist()) + + positions = jnp.column_stack([ + jnp.array(initial_positions), + jax.random.uniform(rng_key, shape=(n_chains,), minval=minval_2nd_dim, maxval=maxval_2nd_dim) + ]) + + model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) + + step_size = 1 + MALA_Sampler = MALA(pop_likelihood, True, {"step_size": step_size}) + + rng_key_set = initialize_rng_keys(n_chains, seed=42) + + nf_sampler = Sampler(n_dim, + rng_key_set, + pop_likelihood, + MALA_Sampler, + model, + n_local_steps=1000, + n_global_steps=1000, + n_epochs=30, + learning_rate=1e-3, + batch_size=1000, + n_chains=n_chains, + use_global=True) + + nf_sampler.sample(positions, data=None) + chains, log_prob, local_accs, global_accs = nf_sampler.get_sampler_state().values() + + corner.corner(np.array(chains.reshape(-1, n_dim))).savefig("corner.png") + # np.savez("pop_chains/pop_chain.npz", chains=chains, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs) + print("local:", local_accs) + print("global:", global_accs) + print("chains:", chains) + print("log_prob:", log_prob) + +if __name__ == "__main__": + main() \ No newline at end of file diff --git a/src/jimgw/population/population_likelihood.py b/src/jimgw/population/population_likelihood.py new file mode 100644 index 00000000..8c830a41 --- /dev/null +++ b/src/jimgw/population/population_likelihood.py @@ -0,0 +1,21 @@ +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.base import LikelihoodBase + +class PopulationLikelihood(LikelihoodBase): + def __init__(self, mass_array, model_class, pop_params): + self.mass_array = mass_array + self.population_model = model_class(*pop_params) + + def evaluate(self, posteriors: dict, pop_params: dict) -> Float: + model_output = self.population_model.evaluate(posteriors, pop_params) + log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1))) + return log_likelihood + + + + + + \ No newline at end of file diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py new file mode 100644 index 00000000..a94d7dce --- /dev/null +++ b/src/jimgw/population/population_model.py @@ -0,0 +1,39 @@ +from abc import ABC, abstractmethod +import jax.numpy as jnp +from jaxtyping import Float + +class PopulationModelBase(ABC): + @abstractmethod + def __init__(self, *params): + self.params = params + + @abstractmethod + def evaluate(self,data: dict, pop_params: dict) -> Float: + """ + Evaluate the likelihood for a given set of parameters. 
+ """ + raise NotImplementedError + +class TruncatedPowerLawModel(PopulationModelBase): + def __init__(self, *params): + super().__init__(*params) + + def truncated_power_law(self, x, x_min, x_max, alpha): + valid_indices = (x >= x_min) & (x <= x_max) + C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha)) + + # Ensure x is treated properly and avoid non-concrete indexing + pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x + pdf = jnp.where(valid_indices, C / (x ** alpha), pdf) + + return pdf + + def evaluate(self,data: dict, pop_params: dict) -> Float: + """Evaluate the truncated power law model with dynamic parameters.""" + x_min = pop_params[0] + x_max = pop_params[1] + alpha = pop_params[2] + + return self.truncated_power_law(data, x_min, x_max, alpha) + + diff --git a/src/jimgw/population/utils.py b/src/jimgw/population/utils.py new file mode 100644 index 00000000..3816ebcd --- /dev/null +++ b/src/jimgw/population/utils.py @@ -0,0 +1,15 @@ +import importlib + +def create_model(model_name): + try: + module = importlib.import_module('population_model') + + # Check if model_name is a string + if not isinstance(model_name, str): + raise ValueError("model_name must be a string") + + model_class = getattr(module, model_name) + return model_class + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import model '{model_name}': {str(e)}") + \ No newline at end of file From d83805c9b85524f15c31adb53c4482b696000fc7 Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Sat, 31 Aug 2024 23:55:21 +0800 Subject: [PATCH 02/11] adding __init__.py and the test script --- src/jimgw/population/__init__.py | 1 + .../population/population_analysis_test.py | 130 ++++++++++++++++++ 2 files changed, 131 insertions(+) create mode 100644 src/jimgw/population/__init__.py create mode 100644 src/jimgw/population/population_analysis_test.py diff --git a/src/jimgw/population/__init__.py b/src/jimgw/population/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/jimgw/population/__init__.py @@ -0,0 +1 @@ + diff --git a/src/jimgw/population/population_analysis_test.py b/src/jimgw/population/population_analysis_test.py new file mode 100644 index 00000000..39e151d7 --- /dev/null +++ b/src/jimgw/population/population_analysis_test.py @@ -0,0 +1,130 @@ +import argparse +import numpy as np +import jax +import jax.numpy as jnp +from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline +from flowMC import Sampler +from flowMC.proposal.MALA import MALA +import corner +from jimgw.population.population_likelihood import PopulationLikelihood +from jimgw.population.utils import create_model + +def parse_args(): + parser = argparse.ArgumentParser(description='Run population likelihood sampling.') + parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') + return parser.parse_args() + +def mtotal_from_mchirp_eta(mchirp, eta): + """Returns the total mass from the chirp mass and symmetric mass ratio.""" + return mchirp / eta**(3./5.) 
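# The hierarchical likelihood used by these scripts (patch 01's PopulationLikelihood)
# averages the population model over each event's posterior samples and sums the logs
# across events: jnp.sum(jnp.log(jnp.mean(model_output, axis=1))).  A minimal NumPy
# sketch of that reduction with made-up posterior samples and made-up hyperparameters;
# as in the patch, no selection-effect term is included.
import numpy as np

def _truncated_power_law_pdf(x, x_min, x_max, alpha):
    # normalised C * x**(-alpha) on [x_min, x_max], zero outside (mirrors population_model.py)
    C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha))
    return np.where((x >= x_min) & (x <= x_max), C / x**alpha, 0.0)

rng = np.random.default_rng(0)
posterior_m1 = rng.uniform(6.0, 60.0, size=(3, 500))        # 3 events x 500 posterior samples
per_sample_pdf = _truncated_power_law_pdf(posterior_m1, x_min=5.0, x_max=80.0, alpha=2.3)
log_likelihood = np.sum(np.log(np.mean(per_sample_pdf, axis=1)))  # log of per-event means, summed
print(log_likelihood)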
+ +def mass1_from_mtotal_eta(mtotal, eta): + """Returns the primary mass from the total mass and symmetric mass ratio.""" + return 0.5 * mtotal * (1.0 + (1.0 - 4.0 * eta)**0.5) + +def mass1_from_mchirp_eta(mchirp, eta): + """Returns the primary mass from the chirp mass and symmetric mass ratio.""" + mtotal = mtotal_from_mchirp_eta(mchirp, eta) + return mass1_from_mtotal_eta(mtotal, eta) + +def prior_alpha(alpha): + return jax.lax.cond(alpha > 0, lambda: 0.0, lambda: -jnp.inf) + +def prior_x_min_x_max(x_min, x_max): + cond_1 = (x_max > x_min) + cond_2 = (x_min >= 5) & (x_min <= 20) + cond_3 = (x_max >= 50) & (x_max <= 100) + + return jax.lax.cond(cond_1 & cond_2 & cond_3, lambda: 0.0, lambda: -jnp.inf) + +def main(): + # Parse command-line arguments + args = parse_args() + + # Randomly generate mass arrays for population analysis + num_samples = 5000 # Number of samples to generate + mass_c_samples = jax.random.uniform(jax.random.PRNGKey(0), shape=(num_samples,), minval=5, maxval=20) # M_c samples + eta_samples = jax.random.uniform(jax.random.PRNGKey(1), shape=(num_samples,), minval=0.1, maxval=0.25) # eta samples (0 < eta < 0.25) + + # Compute mass1 from generated M_c and eta samples + mass1_samples = mass1_from_mchirp_eta(mass_c_samples, eta_samples) + + # Convert to JAX array + mass_array = jnp.array(mass1_samples) + + def pop_likelihood(pop_params, data): + model = create_model(args.pop_model) + likelihood = PopulationLikelihood(mass_array, model, pop_params) + log_likelihood = likelihood.evaluate(mass_array, pop_params) + return log_likelihood + + n_dim = create_model(args.pop_model).get_pop_params_dimension() + n_chains = 1000 + + rng_key = jax.random.PRNGKey(42) + + minval_0th_dim = 5 + maxval_0th_dim = 20 + + minval_1st_dim = 50 + maxval_1st_dim = 100 + + minval_2nd_dim = 0 + maxval_2nd_dim = 4 + + initial_positions = [] + + while len(initial_positions) < n_chains: + rng_key, subkey = jax.random.split(rng_key) + samples_0th_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_0th_dim, maxval=maxval_0th_dim) + rng_key, subkey = jax.random.split(rng_key) + samples_1st_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_1st_dim, maxval=maxval_1st_dim) + + valid_indices = jnp.where((samples_1st_dim >= samples_0th_dim))[0] + valid_positions = jnp.column_stack([samples_0th_dim[valid_indices], samples_1st_dim[valid_indices]]) + + remaining_chains_needed = n_chains - len(initial_positions) + if len(valid_positions) >= remaining_chains_needed: + valid_positions = valid_positions[:remaining_chains_needed] + + initial_positions.extend(valid_positions.tolist()) + + positions = jnp.column_stack([ + jnp.array(initial_positions), + jax.random.uniform(rng_key, shape=(n_chains,), minval=minval_2nd_dim, maxval=maxval_2nd_dim) + ]) + + model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) + + step_size = 1 + MALA_Sampler = MALA(pop_likelihood, True, {"step_size": step_size}) + + rng_key, subkey = jax.random.split(jax.random.PRNGKey(42)) + initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 + + + nf_sampler = Sampler(n_dim, + subkey, + pop_likelihood, + MALA_Sampler, + model, + n_local_steps=1000, + n_global_steps=1000, + n_epochs=30, + learning_rate=1e-3, + batch_size=1000, + n_chains=n_chains, + use_global=True) + + nf_sampler.sample(positions, data=None) + chains, log_prob, local_accs, global_accs = nf_sampler.get_sampler_state().values() + + corner.corner(np.array(chains.reshape(-1, 
n_dim))).savefig("corner.png") + # np.savez("pop_chains/pop_chain.npz", chains=chains, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs) + print("local:", local_accs) + print("global:", global_accs) + print("chains:", chains) + print("log_prob:", log_prob) + +if __name__ == "__main__": + main() \ No newline at end of file From 2704c40ca897b2588ec3b74c2b84e502c587a9d2 Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Mon, 2 Sep 2024 11:22:53 +0800 Subject: [PATCH 03/11] update population analysis script --- src/jimgw/population/population_analysis.py | 7 +- .../population/population_analysis_test.py | 75 ++----------------- 2 files changed, 9 insertions(+), 73 deletions(-) diff --git a/src/jimgw/population/population_analysis.py b/src/jimgw/population/population_analysis.py index 44ac7ebe..89ed230e 100644 --- a/src/jimgw/population/population_analysis.py +++ b/src/jimgw/population/population_analysis.py @@ -1,13 +1,12 @@ import argparse -import pandas as pd import numpy as np import jax import jax.numpy as jnp +import pandas as pd import glob from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline -from flowMC.sampler.MALA import MALA -from flowMC.sampler.Sampler import Sampler -from flowMC.utils.PRNG_keys import initialize_rng_keys +from flowMC import Sampler +from flowMC.proposal.MALA import MALA import corner from jimgw.population.population_likelihood import PopulationLikelihood from jimgw.population.utils import create_model diff --git a/src/jimgw/population/population_analysis_test.py b/src/jimgw/population/population_analysis_test.py index 39e151d7..50f29c1d 100644 --- a/src/jimgw/population/population_analysis_test.py +++ b/src/jimgw/population/population_analysis_test.py @@ -14,43 +14,11 @@ def parse_args(): parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') return parser.parse_args() -def mtotal_from_mchirp_eta(mchirp, eta): - """Returns the total mass from the chirp mass and symmetric mass ratio.""" - return mchirp / eta**(3./5.) 
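# Both scripts resolve the population model class by name at run time through
# utils.create_model (importlib.import_module + getattr).  A sketch of the same lookup;
# note that the bare module name 'population_model' used in the patch only resolves when
# that directory is on sys.path, so the fully qualified path in the commented call is an
# assumption about how the installed package would be imported.
import importlib

def _load_model_class(module_path: str, model_name: str):
    if not isinstance(model_name, str):
        raise ValueError("model_name must be a string")
    module = importlib.import_module(module_path)
    return getattr(module, model_name)   # AttributeError if the class does not exist

# e.g.
# ModelClass = _load_model_class("jimgw.population.population_model", "TruncatedPowerLawModel")
# model = ModelClass()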
- -def mass1_from_mtotal_eta(mtotal, eta): - """Returns the primary mass from the total mass and symmetric mass ratio.""" - return 0.5 * mtotal * (1.0 + (1.0 - 4.0 * eta)**0.5) - -def mass1_from_mchirp_eta(mchirp, eta): - """Returns the primary mass from the chirp mass and symmetric mass ratio.""" - mtotal = mtotal_from_mchirp_eta(mchirp, eta) - return mass1_from_mtotal_eta(mtotal, eta) - -def prior_alpha(alpha): - return jax.lax.cond(alpha > 0, lambda: 0.0, lambda: -jnp.inf) - -def prior_x_min_x_max(x_min, x_max): - cond_1 = (x_max > x_min) - cond_2 = (x_min >= 5) & (x_min <= 20) - cond_3 = (x_max >= 50) & (x_max <= 100) - - return jax.lax.cond(cond_1 & cond_2 & cond_3, lambda: 0.0, lambda: -jnp.inf) - def main(): - # Parse command-line arguments args = parse_args() - - # Randomly generate mass arrays for population analysis - num_samples = 5000 # Number of samples to generate - mass_c_samples = jax.random.uniform(jax.random.PRNGKey(0), shape=(num_samples,), minval=5, maxval=20) # M_c samples - eta_samples = jax.random.uniform(jax.random.PRNGKey(1), shape=(num_samples,), minval=0.1, maxval=0.25) # eta samples (0 < eta < 0.25) - - # Compute mass1 from generated M_c and eta samples - mass1_samples = mass1_from_mchirp_eta(mass_c_samples, eta_samples) - - # Convert to JAX array - mass_array = jnp.array(mass1_samples) + num_samples = 500 + mass_samples = jax.random.uniform(jax.random.PRNGKey(0), shape=(num_samples,), minval=5, maxval=20) # M_c samples + mass_array = jnp.array(mass_samples) def pop_likelihood(pop_params, data): model = create_model(args.pop_model) @@ -59,48 +27,17 @@ def pop_likelihood(pop_params, data): return log_likelihood n_dim = create_model(args.pop_model).get_pop_params_dimension() - n_chains = 1000 - - rng_key = jax.random.PRNGKey(42) + n_chains = 10 - minval_0th_dim = 5 - maxval_0th_dim = 20 - - minval_1st_dim = 50 - maxval_1st_dim = 100 - - minval_2nd_dim = 0 - maxval_2nd_dim = 4 - - initial_positions = [] - - while len(initial_positions) < n_chains: - rng_key, subkey = jax.random.split(rng_key) - samples_0th_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_0th_dim, maxval=maxval_0th_dim) - rng_key, subkey = jax.random.split(rng_key) - samples_1st_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_1st_dim, maxval=maxval_1st_dim) - - valid_indices = jnp.where((samples_1st_dim >= samples_0th_dim))[0] - valid_positions = jnp.column_stack([samples_0th_dim[valid_indices], samples_1st_dim[valid_indices]]) - - remaining_chains_needed = n_chains - len(initial_positions) - if len(valid_positions) >= remaining_chains_needed: - valid_positions = valid_positions[:remaining_chains_needed] - - initial_positions.extend(valid_positions.tolist()) - - positions = jnp.column_stack([ - jnp.array(initial_positions), - jax.random.uniform(rng_key, shape=(n_chains,), minval=minval_2nd_dim, maxval=maxval_2nd_dim) - ]) model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) step_size = 1 MALA_Sampler = MALA(pop_likelihood, True, {"step_size": step_size}) + rng_key, subkey = jax.random.split(jax.random.PRNGKey(42)) - initial_position = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 + positions = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 nf_sampler = Sampler(n_dim, From a6ebf383f7b9abd240e322194f61eb6a9c62a1af Mon Sep 17 00:00:00 2001 From: Wong Shee Man Date: Mon, 2 Sep 2024 23:17:47 +0800 Subject: [PATCH 04/11] Add function for converting chains --- 
src/jimgw/population/population_analysis.py | 87 ++---------------- src/jimgw/population/utils.py | 99 ++++++++++++++++++++- 2 files changed, 107 insertions(+), 79 deletions(-) diff --git a/src/jimgw/population/population_analysis.py b/src/jimgw/population/population_analysis.py index 89ed230e..b9ba0092 100644 --- a/src/jimgw/population/population_analysis.py +++ b/src/jimgw/population/population_analysis.py @@ -5,11 +5,11 @@ import pandas as pd import glob from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline -from flowMC import Sampler +from flowMC.Sampler import Sampler from flowMC.proposal.MALA import MALA import corner from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model +from jimgw.population.utils import create_model, extract_data_from_npz_files, extract_data_from_npz_files_m1_m2 def parse_args(): @@ -18,88 +18,19 @@ def parse_args(): parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.') return parser.parse_args() -def mtotal_from_mchirp_eta(mchirp, eta): - """Returns the total mass from the chirp mass and symmetric mass ratio.""" - return mchirp / eta**(3./5.) - -def mass1_from_mtotal_eta(mtotal, eta): - """Returns the primary mass from the total mass and symmetric mass ratio.""" - return 0.5 * mtotal * (1.0 + (1.0 - 4.0 * eta)**0.5) - -def mass1_from_mchirp_eta(mchirp, eta): - """Returns the primary mass from the chirp mass and symmetric mass ratio.""" - mtotal = mtotal_from_mchirp_eta(mchirp, eta) - return mass1_from_mtotal_eta(mtotal, eta) - -def prior_alpha(alpha): - return jax.lax.cond(alpha > 0, lambda: 0.0, lambda: -jnp.inf) - -def prior_x_min_x_max(x_min, x_max): - cond_1 = (x_max > x_min) - cond_2 = (x_min >= 5) & (x_min <= 20) - cond_3 = (x_max >= 50) & (x_max <= 100) - - return jax.lax.cond(cond_1 & cond_2 & cond_3, lambda: 0.0, lambda: -jnp.inf) - def main(): - # Parse command-line arguments args = parse_args() - - # For sampling events - directory = args.data_dir # Use the data directory from command-line argument - key = jax.random.PRNGKey(42) - mass_result_dict = [] - npz_files = glob.glob(directory + '/*.npz') - - num_files_to_sample = 100 - key, subkey = jax.random.split(key) - sample_indices = jax.random.choice(subkey, len(npz_files), shape=(num_files_to_sample,), replace=False) - sampled_npz_files = [npz_files[i] for i in sample_indices] - - for npz_file in sampled_npz_files: - print("Loading file:", npz_file) - with np.load(npz_file, allow_pickle=True) as data: - chains = data['chains'] - reshaped_chains = chains.reshape(-1, 11) - event_df = pd.DataFrame(reshaped_chains, columns=[ - 'M_c', 'eta', 's1_z', 's2_z', 'd_L', 't_c', 'phase_c', - 'iota', 'psi', 'ra', 'dec' - ]) - - # Randomly sample rows within each file in a reproducible manner - key, subkey = jax.random.split(key) - sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(5000,), replace=False) - sampled_df = event_df.iloc[sample_indices] - - # Extract M_c and eta using sampled indices - mc_sampled = sampled_df['M_c'].values - eta_sampled = sampled_df['eta'].values - - # Compute mass1 - mass1_sampled = mass1_from_mchirp_eta(mc_sampled, eta_sampled) - - # Append to the result dictionary - mass_array = jnp.array(mass1_sampled) - mass_result_dict.append(mass_array) - - # Stack all results into a single array - mass_array = jnp.stack(mass_result_dict) + mass1_array, mass2_array = extract_data_from_npz_files_m1_m2(args.data_dir, num_samples=5000) def 
pop_likelihood(pop_params ,data): model = create_model(args.pop_model) - likelihood = PopulationLikelihood(mass_array, model, pop_params) - log_likelihood = likelihood.evaluate(mass_array, pop_params) + likelihood = PopulationLikelihood(mass1_array, model, pop_params) + log_likelihood = likelihood.evaluate(mass1_array, pop_params) return log_likelihood - - # def log_likelihood(pop_params, data): - # likelihood = PopulationLikelihood(mass_array,TruncatedPowerLawModel, pop_params) - # log_likelihood = likelihood.evaluate(mass_array, pop_params) - # return log_likelihood - n_dim = 3 - n_chains = 1000 + n_chains = 10 rng_key = jax.random.PRNGKey(42) @@ -137,12 +68,12 @@ def pop_likelihood(pop_params ,data): model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) step_size = 1 - MALA_Sampler = MALA(pop_likelihood, True, {"step_size": step_size}) + MALA_Sampler = MALA(pop_likelihood, True, step_size= step_size) - rng_key_set = initialize_rng_keys(n_chains, seed=42) + rng_key, subkey = jax.random.split(jax.random.PRNGKey(42)) nf_sampler = Sampler(n_dim, - rng_key_set, + subkey, pop_likelihood, MALA_Sampler, model, diff --git a/src/jimgw/population/utils.py b/src/jimgw/population/utils.py index 3816ebcd..01e0433e 100644 --- a/src/jimgw/population/utils.py +++ b/src/jimgw/population/utils.py @@ -1,4 +1,9 @@ import importlib +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +from jimgw.single_event.utils import Mc_eta_to_m1_m2 def create_model(model_name): try: @@ -12,4 +17,96 @@ def create_model(model_name): return model_class except (ImportError, AttributeError) as e: raise ImportError(f"Could not import model '{model_name}': {str(e)}") - \ No newline at end of file + + +def extract_data_from_npz_files(npz_files, column_name, num_samples=5000, random_seed=42): + """ + Extracts specified column data from the given .npz files. + + Parameters: + - npz_files (list of str): List of paths to .npz files. + - column_name (str): The name of the column to extract from the DataFrame. + - num_samples (int): Number of samples to extract from each file. + - random_seed (int): Seed for random number generation. + + Returns: + - jnp.array: Stacked array of extracted data. + """ + + key = jax.random.PRNGKey(random_seed) + result_dict = [] + + for npz_file in npz_files: + print("Loading file:", npz_file) + with np.load(npz_file, allow_pickle=True) as data: + chains = data['chains'] + reshaped_chains = chains.reshape(-1, 11) + event_df = pd.DataFrame(reshaped_chains) + + # Check if the specified column exists + if column_name not in event_df.columns: + raise ValueError(f"Column '{column_name}' not found in the data.") + + key, subkey = jax.random.split(key) + sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(num_samples,), replace=True) + sampled_df = event_df.iloc[sample_indices] + + extracted_data = sampled_df[column_name].values + + data_array = jnp.array(extracted_data) + result_dict.append(data_array) + + stacked_array = jnp.stack(result_dict) + return stacked_array + +def extract_data_from_npz_files_m1_m2(npz_files, num_samples=5000, random_seed=42): + """ + Extracts specified column data from the given .npz files and computes masses. + + Parameters + - npz_files (list of str): List of paths to .npz files. + - num_samples (int): Number of samples to extract from each file. + - random_seed (int): Seed for random number generation. + + Returns + - m1_array (jnp.array): Stacked array of primary masses. 
+ - m2_array (jnp.array): Stacked array of secondary masses. + """ + + key = jax.random.PRNGKey(random_seed) + m1_results = [] + m2_results = [] + + for npz_file in npz_files: + print("Loading file:", npz_file) + with np.load(npz_file, allow_pickle=True) as data: + chains = data['chains'] + reshaped_chains = chains.reshape(-1, 11) + event_df = pd.DataFrame(reshaped_chains) + + # Check if the specified columns exist + if 'M_c' not in event_df.columns: + raise ValueError(f" M_c not found in the data.") + if 'eta' not in event_df.columns: + raise ValueError(f"Eta not found in the data.") + + key, subkey = jax.random.split(key) + sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(num_samples,), replace=True) + sampled_df = event_df.iloc[sample_indices] + + # Extract M_c and eta + M_c_sampled = sampled_df[M_c_column].values + eta_sampled = sampled_df[eta_column].values + + # Transform M_c and eta to m1 and m2 + m1_sampled, m2_sampled = Mc_eta_to_m1_m2(M_c_sampled, eta_sampled) + + # Convert to jax arrays and append to results + m1_results.append(jnp.array(m1_sampled)) + m2_results.append(jnp.array(m2_sampled)) + + # Stack all results into single arrays + m1_array = jnp.stack(m1_results) + m2_array = jnp.stack(m2_results) + + return m1_array, m2_array From d782833493de76417c5986969f4a434d344904ca Mon Sep 17 00:00:00 2001 From: Wong Shee Man Date: Wed, 4 Sep 2024 17:15:31 +0800 Subject: [PATCH 05/11] delete test script --- .../population/population_analysis_test.py | 67 ------------------- 1 file changed, 67 deletions(-) delete mode 100644 src/jimgw/population/population_analysis_test.py diff --git a/src/jimgw/population/population_analysis_test.py b/src/jimgw/population/population_analysis_test.py deleted file mode 100644 index 50f29c1d..00000000 --- a/src/jimgw/population/population_analysis_test.py +++ /dev/null @@ -1,67 +0,0 @@ -import argparse -import numpy as np -import jax -import jax.numpy as jnp -from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline -from flowMC import Sampler -from flowMC.proposal.MALA import MALA -import corner -from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model - -def parse_args(): - parser = argparse.ArgumentParser(description='Run population likelihood sampling.') - parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') - return parser.parse_args() - -def main(): - args = parse_args() - num_samples = 500 - mass_samples = jax.random.uniform(jax.random.PRNGKey(0), shape=(num_samples,), minval=5, maxval=20) # M_c samples - mass_array = jnp.array(mass_samples) - - def pop_likelihood(pop_params, data): - model = create_model(args.pop_model) - likelihood = PopulationLikelihood(mass_array, model, pop_params) - log_likelihood = likelihood.evaluate(mass_array, pop_params) - return log_likelihood - - n_dim = create_model(args.pop_model).get_pop_params_dimension() - n_chains = 10 - - - model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) - - step_size = 1 - MALA_Sampler = MALA(pop_likelihood, True, {"step_size": step_size}) - - - rng_key, subkey = jax.random.split(jax.random.PRNGKey(42)) - positions = jax.random.normal(subkey, shape=(n_chains, n_dim)) * 1 - - - nf_sampler = Sampler(n_dim, - subkey, - pop_likelihood, - MALA_Sampler, - model, - n_local_steps=1000, - n_global_steps=1000, - n_epochs=30, - learning_rate=1e-3, - batch_size=1000, - n_chains=n_chains, - use_global=True) - 
- nf_sampler.sample(positions, data=None) - chains, log_prob, local_accs, global_accs = nf_sampler.get_sampler_state().values() - - corner.corner(np.array(chains.reshape(-1, n_dim))).savefig("corner.png") - # np.savez("pop_chains/pop_chain.npz", chains=chains, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs) - print("local:", local_accs) - print("global:", global_accs) - print("chains:", chains) - print("log_prob:", log_prob) - -if __name__ == "__main__": - main() \ No newline at end of file From 7d4b0dc975d268e034f3629cd79608a7165484ad Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Sat, 7 Sep 2024 23:27:58 +0800 Subject: [PATCH 06/11] Updating the loading of chains as dict --- src/jimgw/population/population_analysis.py | 10 +-- src/jimgw/population/utils.py | 91 +++++---------------- 2 files changed, 25 insertions(+), 76 deletions(-) diff --git a/src/jimgw/population/population_analysis.py b/src/jimgw/population/population_analysis.py index b9ba0092..c224ab91 100644 --- a/src/jimgw/population/population_analysis.py +++ b/src/jimgw/population/population_analysis.py @@ -2,14 +2,12 @@ import numpy as np import jax import jax.numpy as jnp -import pandas as pd -import glob from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline from flowMC.Sampler import Sampler from flowMC.proposal.MALA import MALA import corner from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model, extract_data_from_npz_files, extract_data_from_npz_files_m1_m2 +from jimgw.population.utils import create_model, extract_data_from_npz_files def parse_args(): @@ -20,7 +18,7 @@ def parse_args(): def main(): args = parse_args() - mass1_array, mass2_array = extract_data_from_npz_files_m1_m2(args.data_dir, num_samples=5000) + mass1_array = extract_data_from_npz_files(args.data_dir,"m_1", num_samples=5000, random_seed=42) def pop_likelihood(pop_params ,data): model = create_model(args.pop_model) @@ -77,8 +75,8 @@ def pop_likelihood(pop_params ,data): pop_likelihood, MALA_Sampler, model, - n_local_steps=1000, - n_global_steps=1000, + n_local_steps=10, + n_global_steps=10, n_epochs=30, learning_rate=1e-3, batch_size=1000, diff --git a/src/jimgw/population/utils.py b/src/jimgw/population/utils.py index 01e0433e..6ac29f39 100644 --- a/src/jimgw/population/utils.py +++ b/src/jimgw/population/utils.py @@ -4,6 +4,7 @@ import numpy as np import pandas as pd from jimgw.single_event.utils import Mc_eta_to_m1_m2 +import glob def create_model(model_name): try: @@ -19,12 +20,12 @@ def create_model(model_name): raise ImportError(f"Could not import model '{model_name}': {str(e)}") -def extract_data_from_npz_files(npz_files, column_name, num_samples=5000, random_seed=42): +def extract_data_from_npz_files(data_dir, column_name, num_samples=50, random_seed=42): """ Extracts specified column data from the given .npz files. Parameters: - - npz_files (list of str): List of paths to .npz files. + - data_dir (str): The directory containing all the .npz files. - column_name (str): The name of the column to extract from the DataFrame. - num_samples (int): Number of samples to extract from each file. - random_seed (int): Seed for random number generation. @@ -32,81 +33,31 @@ def extract_data_from_npz_files(npz_files, column_name, num_samples=5000, random Returns: - jnp.array: Stacked array of extracted data. 
""" - + + npz_files = glob.glob(f"{data_dir}/*.npz") key = jax.random.PRNGKey(random_seed) - result_dict = [] + result_list = [] for npz_file in npz_files: - print("Loading file:", npz_file) + print(f"Loading file: {npz_file}") + with np.load(npz_file, allow_pickle=True) as data: - chains = data['chains'] - reshaped_chains = chains.reshape(-1, 11) - event_df = pd.DataFrame(reshaped_chains) - - # Check if the specified column exists - if column_name not in event_df.columns: + data_dict = data['arr_0'].item() + if column_name not in data_dict: raise ValueError(f"Column '{column_name}' not found in the data.") - - key, subkey = jax.random.split(key) - sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(num_samples,), replace=True) - sampled_df = event_df.iloc[sample_indices] - - extracted_data = sampled_df[column_name].values - data_array = jnp.array(extracted_data) - result_dict.append(data_array) + extracted_data = data_dict[column_name].reshape(-1,) + print(extracted_data) + print(extracted_data.shape) - stacked_array = jnp.stack(result_dict) - return stacked_array - -def extract_data_from_npz_files_m1_m2(npz_files, num_samples=5000, random_seed=42): - """ - Extracts specified column data from the given .npz files and computes masses. - - Parameters - - npz_files (list of str): List of paths to .npz files. - - num_samples (int): Number of samples to extract from each file. - - random_seed (int): Seed for random number generation. - - Returns - - m1_array (jnp.array): Stacked array of primary masses. - - m2_array (jnp.array): Stacked array of secondary masses. - """ - - key = jax.random.PRNGKey(random_seed) - m1_results = [] - m2_results = [] - - for npz_file in npz_files: - print("Loading file:", npz_file) - with np.load(npz_file, allow_pickle=True) as data: - chains = data['chains'] - reshaped_chains = chains.reshape(-1, 11) - event_df = pd.DataFrame(reshaped_chains) - - # Check if the specified columns exist - if 'M_c' not in event_df.columns: - raise ValueError(f" M_c not found in the data.") - if 'eta' not in event_df.columns: - raise ValueError(f"Eta not found in the data.") + if isinstance(extracted_data, np.ndarray): + extracted_data = jax.device_put(extracted_data) key, subkey = jax.random.split(key) - sample_indices = jax.random.choice(subkey, event_df.shape[0], shape=(num_samples,), replace=True) - sampled_df = event_df.iloc[sample_indices] - - # Extract M_c and eta - M_c_sampled = sampled_df[M_c_column].values - eta_sampled = sampled_df[eta_column].values - - # Transform M_c and eta to m1 and m2 - m1_sampled, m2_sampled = Mc_eta_to_m1_m2(M_c_sampled, eta_sampled) + sample_indices = jax.random.choice(subkey, extracted_data.shape[0], shape=(num_samples,), replace=True) - # Convert to jax arrays and append to results - m1_results.append(jnp.array(m1_sampled)) - m2_results.append(jnp.array(m2_sampled)) - - # Stack all results into single arrays - m1_array = jnp.stack(m1_results) - m2_array = jnp.stack(m2_results) - - return m1_array, m2_array + sampled_data = extracted_data[sample_indices] + result_list.append(sampled_data) + stacked_array = jnp.stack(result_list) + + return stacked_array \ No newline at end of file From 5347df372b907f472c103e9a97198c872662870d Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Sun, 8 Sep 2024 00:13:48 +0800 Subject: [PATCH 07/11] try to implement jim for population --- src/jimgw/population/example_population.py | 90 ++++++++++++++++++++++ src/jimgw/population/transform.py | 21 +++++ 2 files changed, 111 insertions(+) create mode 
100644 src/jimgw/population/example_population.py create mode 100644 src/jimgw/population/transform.py diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py new file mode 100644 index 00000000..45349881 --- /dev/null +++ b/src/jimgw/population/example_population.py @@ -0,0 +1,90 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from flowMC.strategy.optimization import optimization_Adam +from jimgw.population.population_likelihood import PopulationLikelihood +from jimgw.population.utils import create_model, extract_data_from_npz_files +import argparse +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.transforms import BoundToUnbound +from jimgw.population.transform import NullTransform + +jax.config.update("jax_enable_x64", True) + +def parse_args(): + parser = argparse.ArgumentParser(description='Run population likelihood sampling.') + parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') + parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.') + return parser.parse_args() + +def main(): + args = parse_args() + mass1_array = extract_data_from_npz_files(args.data_dir,"m_1", num_samples=5000, random_seed=42) + + """ + need changes for the pop_likelihood + """ + def pop_likelihood(pop_params ,data): + model = create_model(args.pop_model) + likelihood = PopulationLikelihood(mass1_array, model, pop_params) + log_likelihood = likelihood.evaluate(mass1_array, pop_params) + return log_likelihood + + mass_matrix = jnp.eye(11) + mass_matrix = mass_matrix.at[1, 1].set(1e-3) + mass_matrix = mass_matrix.at[5, 5].set(1e-3) + local_sampler_arg = {"step_size": mass_matrix * 3e-3} + + Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + + """" + The following needs changing + """ + m_min_prior = UniformPrior(10.,80.,parameter_names = ["m_min"]) + m_max_prior = UniformPrior(10.,80.,parameter_names = ["m_max"]) + alpha_prior = UniformPrior(0.,10.,parameter_names = ["alpha"]) + prior = CombinePrior([m_min_prior, m_max_prior, alpha_prior]) + sample_transforms = [BoundToUnbound(name_mapping = [["m_min"], ["m_min_unbounded"]], original_lower_bound=10, original_upper_bound=80), + BoundToUnbound(name_mapping = [["m_max"], ["m_max_unbounded"]], original_lower_bound=10, original_upper_bound=80), + BoundToUnbound(name_mapping = [["alpha"], ["alpha_unbounded"]], original_lower_bound=0, original_upper_bound =10)] + name_mapping = ( + ["m_min", "m_max", "alpha"], + ["m_min", "m_max", "alpha"] + ) + likelihood_transforms = [NullTransform(name_mapping)] + + n_epochs = 2 + n_loop_training = 1 + learning_rate = 1e-4 + + + jim = Jim( + pop_likelihood, + prior, + sample_transforms, + likelihood_transforms , + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], + ) + + jim.sample(jax.random.PRNGKey(42)) + samples =jim.get_samples() + jim.print_summary() + +if __name__ == "__main__": + main() diff --git a/src/jimgw/population/transform.py b/src/jimgw/population/transform.py new file mode 100644 index 00000000..9b159ce2 --- /dev/null +++ b/src/jimgw/population/transform.py @@ -0,0 +1,21 @@ +from typing import 
Tuple, List, Dict +from jimgw.transforms import NtoNTransform, Float + + +class NullTransform(NtoNTransform): + """ + Null transformation that does nothing to the input data. + """ + + def __init__(self, name_mapping: Tuple[List[str], List[str]]): + super().__init__(name_mapping) + + # Ensure that the input and output name mappings are the same length + if len(name_mapping[0]) != len(name_mapping[1]): + raise ValueError("Input and output name mappings must have the same length.") + + # The transform function simply returns the input as-is + def null_transform(x: Dict[str, Float]) -> Dict[str, Float]: + return {key: x[key] for key in name_mapping[0]} + + self.transform_func = null_transform \ No newline at end of file From 800f1275e3547466133cf3a52e65fa5ca40260e9 Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Mon, 9 Sep 2024 12:39:10 +0800 Subject: [PATCH 08/11] Update example --- src/jimgw/population/example_population.py | 16 +++++--- src/jimgw/population/population_likelihood.py | 7 ++-- src/jimgw/population/population_model.py | 38 +++++++++++++++---- 3 files changed, 45 insertions(+), 16 deletions(-) diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py index 45349881..44a008a8 100644 --- a/src/jimgw/population/example_population.py +++ b/src/jimgw/population/example_population.py @@ -25,12 +25,16 @@ def main(): """ need changes for the pop_likelihood """ - def pop_likelihood(pop_params ,data): - model = create_model(args.pop_model) - likelihood = PopulationLikelihood(mass1_array, model, pop_params) - log_likelihood = likelihood.evaluate(mass1_array, pop_params) - return log_likelihood - + # def pop_likelihood(pop_params ,data): + # model = create_model(args.pop_model) + # likelihood = PopulationLikelihood(mass1_array, model, pop_params) + # log_likelihood = likelihood.evaluate(mass1_array, pop_params) + # return log_likelihood + + model = create_model(args.pop_model) + pop_params = ["m_min",1,2] + pop_likelihood = PopulationLikelihood(mass1_array, model, pop_params) + mass_matrix = jnp.eye(11) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) diff --git a/src/jimgw/population/population_likelihood.py b/src/jimgw/population/population_likelihood.py index 8c830a41..1a61fde9 100644 --- a/src/jimgw/population/population_likelihood.py +++ b/src/jimgw/population/population_likelihood.py @@ -9,11 +9,12 @@ def __init__(self, mass_array, model_class, pop_params): self.mass_array = mass_array self.population_model = model_class(*pop_params) - def evaluate(self, posteriors: dict, pop_params: dict) -> Float: - model_output = self.population_model.evaluate(posteriors, pop_params) + def evaluate(self, pop_params: dict[str, Float],posteriors: dict) -> Float: + model_output = self.population_model.evaluate(pop_params, posteriors) log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1))) return log_likelihood - + + diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py index a94d7dce..a531d8a4 100644 --- a/src/jimgw/population/population_model.py +++ b/src/jimgw/population/population_model.py @@ -14,6 +14,28 @@ def evaluate(self,data: dict, pop_params: dict) -> Float: """ raise NotImplementedError +# class TruncatedPowerLawModel(PopulationModelBase): +# def __init__(self, *params): +# super().__init__(*params) + +# def truncated_power_law(self, x, x_min, x_max, alpha): +# valid_indices = (x >= x_min) & (x <= x_max) +# C = (1 - alpha) / (x_max**(1 - alpha) - 
x_min**(1 - alpha)) + +# # Ensure x is treated properly and avoid non-concrete indexing +# pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x +# pdf = jnp.where(valid_indices, C / (x ** alpha), pdf) + +# return pdf + +# def evaluate(self,data: dict, pop_params: dict) -> Float: +# """Evaluate the truncated power law model with dynamic parameters.""" +# x_min = pop_params[0] +# x_max = pop_params[1] +# alpha = pop_params[2] + +# return self.truncated_power_law(data, x_min, x_max, alpha) + class TruncatedPowerLawModel(PopulationModelBase): def __init__(self, *params): super().__init__(*params) @@ -28,12 +50,14 @@ def truncated_power_law(self, x, x_min, x_max, alpha): return pdf - def evaluate(self,data: dict, pop_params: dict) -> Float: + def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: """Evaluate the truncated power law model with dynamic parameters.""" - x_min = pop_params[0] - x_max = pop_params[1] - alpha = pop_params[2] - - return self.truncated_power_law(data, x_min, x_max, alpha) - + print("pop_parmas",pop_params) + m_min = pop_params["m_min"] + m_max = pop_params["m_max"] + alpha = pop_params["alpha"] + return self.truncated_power_law(data, m_min, m_max, alpha) + + + From 29b42362c29ad4a8ea2f8c35572839b3c281d651 Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Mon, 9 Sep 2024 17:22:52 +0800 Subject: [PATCH 09/11] Update the example script. Pop analysis now works with jim --- src/jimgw/population/example_population.py | 20 +++------ src/jimgw/population/population_likelihood.py | 11 ++--- src/jimgw/population/population_model.py | 42 ++++--------------- src/jimgw/population/utils.py | 2 - 4 files changed, 19 insertions(+), 56 deletions(-) diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py index 44a008a8..b6ea942b 100644 --- a/src/jimgw/population/example_population.py +++ b/src/jimgw/population/example_population.py @@ -20,22 +20,10 @@ def parse_args(): def main(): args = parse_args() - mass1_array = extract_data_from_npz_files(args.data_dir,"m_1", num_samples=5000, random_seed=42) - - """ - need changes for the pop_likelihood - """ - # def pop_likelihood(pop_params ,data): - # model = create_model(args.pop_model) - # likelihood = PopulationLikelihood(mass1_array, model, pop_params) - # log_likelihood = likelihood.evaluate(mass1_array, pop_params) - # return log_likelihood - model = create_model(args.pop_model) - pop_params = ["m_min",1,2] - pop_likelihood = PopulationLikelihood(mass1_array, model, pop_params) + pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model) - mass_matrix = jnp.eye(11) + mass_matrix = jnp.eye(model.get_pop_params_dimension()) mass_matrix = mass_matrix.at[1, 1].set(1e-3) mass_matrix = mass_matrix.at[5, 5].set(1e-3) local_sampler_arg = {"step_size": mass_matrix * 3e-3} @@ -48,7 +36,9 @@ def main(): m_min_prior = UniformPrior(10.,80.,parameter_names = ["m_min"]) m_max_prior = UniformPrior(10.,80.,parameter_names = ["m_max"]) alpha_prior = UniformPrior(0.,10.,parameter_names = ["alpha"]) + prior = CombinePrior([m_min_prior, m_max_prior, alpha_prior]) + sample_transforms = [BoundToUnbound(name_mapping = [["m_min"], ["m_min_unbounded"]], original_lower_bound=10, original_upper_bound=80), BoundToUnbound(name_mapping = [["m_max"], ["m_max_unbounded"]], original_lower_bound=10, original_upper_bound=80), BoundToUnbound(name_mapping = [["alpha"], ["alpha_unbounded"]], original_lower_bound=0, original_upper_bound =10)] @@ -87,7 +77,7 @@ def main(): ) 
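# From patch 08 onward the model is evaluated with a named-parameter dict
# (pop_params["m_min"], pop_params["m_max"], pop_params["alpha"]), matching the prior
# names defined above.  A standalone sketch of that call convention plus a numerical
# check that the truncated power law integrates to one on [m_min, m_max]; the
# hyperparameter values are arbitrary and the closed-form constant assumes alpha != 1.
import numpy as np

def _evaluate(pop_params, data):
    m_min, m_max, alpha = pop_params["m_min"], pop_params["m_max"], pop_params["alpha"]
    C = (1 - alpha) / (m_max**(1 - alpha) - m_min**(1 - alpha))
    return np.where((data >= m_min) & (data <= m_max), C / data**alpha, 0.0)

params = {"m_min": 10.0, "m_max": 80.0, "alpha": 2.5}
grid = np.linspace(params["m_min"], params["m_max"], 20001)
print(np.sum(_evaluate(params, grid)) * (grid[1] - grid[0]))   # ~ 1.0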
jim.sample(jax.random.PRNGKey(42)) - samples =jim.get_samples() + jim.get_samples() jim.print_summary() if __name__ == "__main__": diff --git a/src/jimgw/population/population_likelihood.py b/src/jimgw/population/population_likelihood.py index 1a61fde9..f88a196d 100644 --- a/src/jimgw/population/population_likelihood.py +++ b/src/jimgw/population/population_likelihood.py @@ -3,14 +3,15 @@ from jaxtyping import Float from jimgw.base import LikelihoodBase +from jimgw.population.utils import extract_data_from_npz_files class PopulationLikelihood(LikelihoodBase): - def __init__(self, mass_array, model_class, pop_params): - self.mass_array = mass_array - self.population_model = model_class(*pop_params) + def __init__(self, data_dir, column_name, num_samples, model_class): + self.posteriors = extract_data_from_npz_files(data_dir, column_name, num_samples, random_seed=42) + self.population_model = model_class() - def evaluate(self, pop_params: dict[str, Float],posteriors: dict) -> Float: - model_output = self.population_model.evaluate(pop_params, posteriors) + def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: + model_output = self.population_model.evaluate(pop_params, self.posteriors) log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1))) return log_likelihood diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py index a531d8a4..9f8b404d 100644 --- a/src/jimgw/population/population_model.py +++ b/src/jimgw/population/population_model.py @@ -8,55 +8,29 @@ def __init__(self, *params): self.params = params @abstractmethod - def evaluate(self,data: dict, pop_params: dict) -> Float: + def evaluate(self, pop_params: dict, data: dict) -> Float: """ Evaluate the likelihood for a given set of parameters. 
""" raise NotImplementedError -# class TruncatedPowerLawModel(PopulationModelBase): -# def __init__(self, *params): -# super().__init__(*params) - -# def truncated_power_law(self, x, x_min, x_max, alpha): -# valid_indices = (x >= x_min) & (x <= x_max) -# C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha)) - -# # Ensure x is treated properly and avoid non-concrete indexing -# pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x -# pdf = jnp.where(valid_indices, C / (x ** alpha), pdf) - -# return pdf - -# def evaluate(self,data: dict, pop_params: dict) -> Float: -# """Evaluate the truncated power law model with dynamic parameters.""" -# x_min = pop_params[0] -# x_max = pop_params[1] -# alpha = pop_params[2] - -# return self.truncated_power_law(data, x_min, x_max, alpha) class TruncatedPowerLawModel(PopulationModelBase): - def __init__(self, *params): - super().__init__(*params) + def __init__(self): + super().__init__() def truncated_power_law(self, x, x_min, x_max, alpha): valid_indices = (x >= x_min) & (x <= x_max) C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha)) - - # Ensure x is treated properly and avoid non-concrete indexing - pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x + pdf = jnp.zeros_like(x) pdf = jnp.where(valid_indices, C / (x ** alpha), pdf) - return pdf def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: - """Evaluate the truncated power law model with dynamic parameters.""" - print("pop_parmas",pop_params) - m_min = pop_params["m_min"] - m_max = pop_params["m_max"] - alpha = pop_params["alpha"] - return self.truncated_power_law(data, m_min, m_max, alpha) + return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"]) + + def get_pop_params_dimension(): + return 3 diff --git a/src/jimgw/population/utils.py b/src/jimgw/population/utils.py index 6ac29f39..7dc5b241 100644 --- a/src/jimgw/population/utils.py +++ b/src/jimgw/population/utils.py @@ -47,8 +47,6 @@ def extract_data_from_npz_files(data_dir, column_name, num_samples=50, random_se raise ValueError(f"Column '{column_name}' not found in the data.") extracted_data = data_dict[column_name].reshape(-1,) - print(extracted_data) - print(extracted_data.shape) if isinstance(extracted_data, np.ndarray): extracted_data = jax.device_put(extracted_data) From 669b21bcf9aadcd36b20c81fb66c7f1d0c4f58ed Mon Sep 17 00:00:00 2001 From: CharmaineW Date: Mon, 9 Sep 2024 21:39:51 +0800 Subject: [PATCH 10/11] Del mass_matrix --- src/jimgw/population/example_population.py | 7 ++----- src/jimgw/population/population_model.py | 2 -- 2 files changed, 2 insertions(+), 7 deletions(-) diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py index b6ea942b..6baa4c09 100644 --- a/src/jimgw/population/example_population.py +++ b/src/jimgw/population/example_population.py @@ -4,7 +4,7 @@ from jimgw.jim import Jim from flowMC.strategy.optimization import optimization_Adam from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model, extract_data_from_npz_files +from jimgw.population.utils import create_model import argparse from jimgw.prior import UniformPrior, CombinePrior from jimgw.transforms import BoundToUnbound @@ -23,10 +23,7 @@ def main(): model = create_model(args.pop_model) pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model) - mass_matrix = jnp.eye(model.get_pop_params_dimension()) - mass_matrix = mass_matrix.at[1, 
1].set(1e-3) - mass_matrix = mass_matrix.at[5, 5].set(1e-3) - local_sampler_arg = {"step_size": mass_matrix * 3e-3} + local_sampler_arg = {"step_size": 3e-3} Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py index 9f8b404d..71ebad60 100644 --- a/src/jimgw/population/population_model.py +++ b/src/jimgw/population/population_model.py @@ -29,8 +29,6 @@ def truncated_power_law(self, x, x_min, x_max, alpha): def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"]) - def get_pop_params_dimension(): - return 3 From dd10ce3d6ddd2497c49a63dfffc24fb3555d04d5 Mon Sep 17 00:00:00 2001 From: Wong Shee Man Date: Tue, 10 Sep 2024 15:44:27 +0800 Subject: [PATCH 11/11] update script --- src/jimgw/population/population_analysis.py | 97 --------------------- 1 file changed, 97 deletions(-) delete mode 100644 src/jimgw/population/population_analysis.py diff --git a/src/jimgw/population/population_analysis.py b/src/jimgw/population/population_analysis.py deleted file mode 100644 index c224ab91..00000000 --- a/src/jimgw/population/population_analysis.py +++ /dev/null @@ -1,97 +0,0 @@ -import argparse -import numpy as np -import jax -import jax.numpy as jnp -from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline -from flowMC.Sampler import Sampler -from flowMC.proposal.MALA import MALA -import corner -from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model, extract_data_from_npz_files - - -def parse_args(): - parser = argparse.ArgumentParser(description='Run population likelihood sampling.') - parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') - parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.') - return parser.parse_args() - -def main(): - args = parse_args() - mass1_array = extract_data_from_npz_files(args.data_dir,"m_1", num_samples=5000, random_seed=42) - - def pop_likelihood(pop_params ,data): - model = create_model(args.pop_model) - likelihood = PopulationLikelihood(mass1_array, model, pop_params) - log_likelihood = likelihood.evaluate(mass1_array, pop_params) - return log_likelihood - - - n_dim = 3 - n_chains = 10 - - rng_key = jax.random.PRNGKey(42) - - minval_0th_dim = 5 - maxval_0th_dim = 20 - - minval_1st_dim = 50 - maxval_1st_dim = 100 - - minval_2nd_dim = 0 - maxval_2nd_dim = 4 - - initial_positions = [] - - while len(initial_positions) < n_chains: - rng_key, subkey = jax.random.split(rng_key) - samples_0th_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_0th_dim, maxval=maxval_0th_dim) - rng_key, subkey = jax.random.split(rng_key) - samples_1st_dim = jax.random.uniform(subkey, shape=(n_chains,), minval=minval_1st_dim, maxval=maxval_1st_dim) - - valid_indices = jnp.where((samples_1st_dim >= samples_0th_dim))[0] - valid_positions = jnp.column_stack([samples_0th_dim[valid_indices], samples_1st_dim[valid_indices]]) - - remaining_chains_needed = n_chains - len(initial_positions) - if len(valid_positions) >= remaining_chains_needed: - valid_positions = valid_positions[:remaining_chains_needed] - - initial_positions.extend(valid_positions.tolist()) - - positions = jnp.column_stack([ - jnp.array(initial_positions), - jax.random.uniform(rng_key, shape=(n_chains,), minval=minval_2nd_dim, 
maxval=maxval_2nd_dim) - ]) - - model = MaskedCouplingRQSpline(n_layers=3, hidden_size=[64, 64], num_bins=8, n_features=n_dim, key=jax.random.PRNGKey(0)) - - step_size = 1 - MALA_Sampler = MALA(pop_likelihood, True, step_size= step_size) - - rng_key, subkey = jax.random.split(jax.random.PRNGKey(42)) - - nf_sampler = Sampler(n_dim, - subkey, - pop_likelihood, - MALA_Sampler, - model, - n_local_steps=10, - n_global_steps=10, - n_epochs=30, - learning_rate=1e-3, - batch_size=1000, - n_chains=n_chains, - use_global=True) - - nf_sampler.sample(positions, data=None) - chains, log_prob, local_accs, global_accs = nf_sampler.get_sampler_state().values() - - corner.corner(np.array(chains.reshape(-1, n_dim))).savefig("corner.png") - # np.savez("pop_chains/pop_chain.npz", chains=chains, log_prob=log_prob, local_accs=local_accs, global_accs=global_accs) - print("local:", local_accs) - print("global:", global_accs) - print("chains:", chains) - print("log_prob:", log_prob) - -if __name__ == "__main__": - main() \ No newline at end of file
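After the final patch the package is driven through example_population.py, which reads per-event posterior samples from .npz files via utils.extract_data_from_npz_files; that loader unpickles a dict stored under NumPy's default 'arr_0' key and pulls a single column from it. The snippet below is a hedged sketch of producing a compatible toy input file and of the round trip the loader performs; the file name, column name, sample values and command-line invocation are illustrative assumptions inferred from the loader and parse_args, not anything specified in the patches.

import numpy as np

# Write one toy event file in the layout data['arr_0'].item()[column] that the loader expects.
rng = np.random.default_rng(0)
posterior = {"m_1": rng.uniform(10.0, 60.0, size=5000)}   # toy posterior samples for one event
np.savez("event_000.npz", posterior)                      # positional arg is stored as 'arr_0'

# Round trip mirroring utils.extract_data_from_npz_files:
with np.load("event_000.npz", allow_pickle=True) as data:
    recovered = data["arr_0"].item()["m_1"].reshape(-1,)
assert recovered.shape == (5000,)

# The surviving driver would then be run along these lines (paths assumed):
#   python -m jimgw.population.example_population \
#       --pop_model TruncatedPowerLawModel --data_dir .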