diff --git a/src/jimgw/population/__init__.py b/src/jimgw/population/__init__.py new file mode 100644 index 00000000..8b137891 --- /dev/null +++ b/src/jimgw/population/__init__.py @@ -0,0 +1 @@ + diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py new file mode 100644 index 00000000..6baa4c09 --- /dev/null +++ b/src/jimgw/population/example_population.py @@ -0,0 +1,81 @@ +import jax +import jax.numpy as jnp + +from jimgw.jim import Jim +from flowMC.strategy.optimization import optimization_Adam +from jimgw.population.population_likelihood import PopulationLikelihood +from jimgw.population.utils import create_model +import argparse +from jimgw.prior import UniformPrior, CombinePrior +from jimgw.transforms import BoundToUnbound +from jimgw.population.transform import NullTransform + +jax.config.update("jax_enable_x64", True) + +def parse_args(): + parser = argparse.ArgumentParser(description='Run population likelihood sampling.') + parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.') + parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.') + return parser.parse_args() + +def main(): + args = parse_args() + model = create_model(args.pop_model) + pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model) + + local_sampler_arg = {"step_size": 3e-3} + + Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) + + """" + The following needs changing + """ + m_min_prior = UniformPrior(10.,80.,parameter_names = ["m_min"]) + m_max_prior = UniformPrior(10.,80.,parameter_names = ["m_max"]) + alpha_prior = UniformPrior(0.,10.,parameter_names = ["alpha"]) + + prior = CombinePrior([m_min_prior, m_max_prior, alpha_prior]) + + sample_transforms = [BoundToUnbound(name_mapping = [["m_min"], ["m_min_unbounded"]], original_lower_bound=10, original_upper_bound=80), + BoundToUnbound(name_mapping = [["m_max"], ["m_max_unbounded"]], original_lower_bound=10, original_upper_bound=80), + BoundToUnbound(name_mapping = [["alpha"], ["alpha_unbounded"]], original_lower_bound=0, original_upper_bound =10)] + name_mapping = ( + ["m_min", "m_max", "alpha"], + ["m_min", "m_max", "alpha"] + ) + likelihood_transforms = [NullTransform(name_mapping)] + + n_epochs = 2 + n_loop_training = 1 + learning_rate = 1e-4 + + + jim = Jim( + pop_likelihood, + prior, + sample_transforms, + likelihood_transforms , + n_loop_training=n_loop_training, + n_loop_production=1, + n_local_steps=5, + n_global_steps=5, + n_chains=4, + n_epochs=n_epochs, + learning_rate=learning_rate, + n_max_examples=30, + n_flow_samples=100, + momentum=0.9, + batch_size=100, + use_global=True, + train_thinning=1, + output_thinning=1, + local_sampler_arg=local_sampler_arg, + strategies=[Adam_optimizer, "default"], + ) + + jim.sample(jax.random.PRNGKey(42)) + jim.get_samples() + jim.print_summary() + +if __name__ == "__main__": + main() diff --git a/src/jimgw/population/population_likelihood.py b/src/jimgw/population/population_likelihood.py new file mode 100644 index 00000000..f88a196d --- /dev/null +++ b/src/jimgw/population/population_likelihood.py @@ -0,0 +1,23 @@ +import jax +import jax.numpy as jnp +from jaxtyping import Float + +from jimgw.base import LikelihoodBase +from jimgw.population.utils import extract_data_from_npz_files + +class PopulationLikelihood(LikelihoodBase): + def __init__(self, data_dir, column_name, num_samples, model_class): + self.posteriors = extract_data_from_npz_files(data_dir, column_name, num_samples, random_seed=42) + self.population_model = model_class() + + def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: + model_output = self.population_model.evaluate(pop_params, self.posteriors) + log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1))) + return log_likelihood + + + + + + + \ No newline at end of file diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py new file mode 100644 index 00000000..71ebad60 --- /dev/null +++ b/src/jimgw/population/population_model.py @@ -0,0 +1,35 @@ +from abc import ABC, abstractmethod +import jax.numpy as jnp +from jaxtyping import Float + +class PopulationModelBase(ABC): + @abstractmethod + def __init__(self, *params): + self.params = params + + @abstractmethod + def evaluate(self, pop_params: dict, data: dict) -> Float: + """ + Evaluate the likelihood for a given set of parameters. + """ + raise NotImplementedError + + +class TruncatedPowerLawModel(PopulationModelBase): + def __init__(self): + super().__init__() + + def truncated_power_law(self, x, x_min, x_max, alpha): + valid_indices = (x >= x_min) & (x <= x_max) + C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha)) + pdf = jnp.zeros_like(x) + pdf = jnp.where(valid_indices, C / (x ** alpha), pdf) + return pdf + + def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: + return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"]) + + + + + diff --git a/src/jimgw/population/transform.py b/src/jimgw/population/transform.py new file mode 100644 index 00000000..9b159ce2 --- /dev/null +++ b/src/jimgw/population/transform.py @@ -0,0 +1,21 @@ +from typing import Tuple, List, Dict +from jimgw.transforms import NtoNTransform, Float + + +class NullTransform(NtoNTransform): + """ + Null transformation that does nothing to the input data. + """ + + def __init__(self, name_mapping: Tuple[List[str], List[str]]): + super().__init__(name_mapping) + + # Ensure that the input and output name mappings are the same length + if len(name_mapping[0]) != len(name_mapping[1]): + raise ValueError("Input and output name mappings must have the same length.") + + # The transform function simply returns the input as-is + def null_transform(x: Dict[str, Float]) -> Dict[str, Float]: + return {key: x[key] for key in name_mapping[0]} + + self.transform_func = null_transform \ No newline at end of file diff --git a/src/jimgw/population/utils.py b/src/jimgw/population/utils.py new file mode 100644 index 00000000..7dc5b241 --- /dev/null +++ b/src/jimgw/population/utils.py @@ -0,0 +1,61 @@ +import importlib +import jax +import jax.numpy as jnp +import numpy as np +import pandas as pd +from jimgw.single_event.utils import Mc_eta_to_m1_m2 +import glob + +def create_model(model_name): + try: + module = importlib.import_module('population_model') + + # Check if model_name is a string + if not isinstance(model_name, str): + raise ValueError("model_name must be a string") + + model_class = getattr(module, model_name) + return model_class + except (ImportError, AttributeError) as e: + raise ImportError(f"Could not import model '{model_name}': {str(e)}") + + +def extract_data_from_npz_files(data_dir, column_name, num_samples=50, random_seed=42): + """ + Extracts specified column data from the given .npz files. + + Parameters: + - data_dir (str): The directory containing all the .npz files. + - column_name (str): The name of the column to extract from the DataFrame. + - num_samples (int): Number of samples to extract from each file. + - random_seed (int): Seed for random number generation. + + Returns: + - jnp.array: Stacked array of extracted data. + """ + + npz_files = glob.glob(f"{data_dir}/*.npz") + key = jax.random.PRNGKey(random_seed) + result_list = [] + + for npz_file in npz_files: + print(f"Loading file: {npz_file}") + + with np.load(npz_file, allow_pickle=True) as data: + data_dict = data['arr_0'].item() + if column_name not in data_dict: + raise ValueError(f"Column '{column_name}' not found in the data.") + + extracted_data = data_dict[column_name].reshape(-1,) + + if isinstance(extracted_data, np.ndarray): + extracted_data = jax.device_put(extracted_data) + + key, subkey = jax.random.split(key) + sample_indices = jax.random.choice(subkey, extracted_data.shape[0], shape=(num_samples,), replace=True) + + sampled_data = extracted_data[sample_indices] + result_list.append(sampled_data) + stacked_array = jnp.stack(result_list) + + return stacked_array \ No newline at end of file