Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add population analysis functionality #148

Open
wants to merge 11 commits into
base: 50-add-population-analysis-functionality
Choose a base branch
from
1 change: 1 addition & 0 deletions src/jimgw/population/__init__.py
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@

81 changes: 81 additions & 0 deletions src/jimgw/population/example_population.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,81 @@
import jax
import jax.numpy as jnp

from jimgw.jim import Jim
from flowMC.strategy.optimization import optimization_Adam
from jimgw.population.population_likelihood import PopulationLikelihood
from jimgw.population.utils import create_model
import argparse
from jimgw.prior import UniformPrior, CombinePrior
from jimgw.transforms import BoundToUnbound
from jimgw.population.transform import NullTransform

jax.config.update("jax_enable_x64", True)

def parse_args():
parser = argparse.ArgumentParser(description='Run population likelihood sampling.')
parser.add_argument('--pop_model', type=str, required=True, help='Population model to use.')
parser.add_argument('--data_dir', type=str, required=True, help='Directory containing the NPZ data files.')
return parser.parse_args()

def main():
args = parse_args()
model = create_model(args.pop_model)
pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model)

local_sampler_arg = {"step_size": 3e-3}

Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1)

""""
The following needs changing
"""
m_min_prior = UniformPrior(10.,80.,parameter_names = ["m_min"])
m_max_prior = UniformPrior(10.,80.,parameter_names = ["m_max"])
alpha_prior = UniformPrior(0.,10.,parameter_names = ["alpha"])

prior = CombinePrior([m_min_prior, m_max_prior, alpha_prior])

sample_transforms = [BoundToUnbound(name_mapping = [["m_min"], ["m_min_unbounded"]], original_lower_bound=10, original_upper_bound=80),
BoundToUnbound(name_mapping = [["m_max"], ["m_max_unbounded"]], original_lower_bound=10, original_upper_bound=80),
BoundToUnbound(name_mapping = [["alpha"], ["alpha_unbounded"]], original_lower_bound=0, original_upper_bound =10)]
name_mapping = (
["m_min", "m_max", "alpha"],
["m_min", "m_max", "alpha"]
)
likelihood_transforms = [NullTransform(name_mapping)]

n_epochs = 2
n_loop_training = 1
learning_rate = 1e-4


jim = Jim(
pop_likelihood,
prior,
sample_transforms,
likelihood_transforms ,
n_loop_training=n_loop_training,
n_loop_production=1,
n_local_steps=5,
n_global_steps=5,
n_chains=4,
n_epochs=n_epochs,
learning_rate=learning_rate,
n_max_examples=30,
n_flow_samples=100,
momentum=0.9,
batch_size=100,
use_global=True,
train_thinning=1,
output_thinning=1,
local_sampler_arg=local_sampler_arg,
strategies=[Adam_optimizer, "default"],
)

jim.sample(jax.random.PRNGKey(42))
jim.get_samples()
jim.print_summary()

if __name__ == "__main__":
main()
23 changes: 23 additions & 0 deletions src/jimgw/population/population_likelihood.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,23 @@
import jax
import jax.numpy as jnp
from jaxtyping import Float

from jimgw.base import LikelihoodBase
from jimgw.population.utils import extract_data_from_npz_files

class PopulationLikelihood(LikelihoodBase):
def __init__(self, data_dir, column_name, num_samples, model_class):
self.posteriors = extract_data_from_npz_files(data_dir, column_name, num_samples, random_seed=42)
self.population_model = model_class()

def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float:
model_output = self.population_model.evaluate(pop_params, self.posteriors)
log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1)))
return log_likelihood







35 changes: 35 additions & 0 deletions src/jimgw/population/population_model.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
from abc import ABC, abstractmethod
import jax.numpy as jnp
from jaxtyping import Float

class PopulationModelBase(ABC):
@abstractmethod
def __init__(self, *params):
self.params = params

@abstractmethod
def evaluate(self, pop_params: dict, data: dict) -> Float:
"""
Evaluate the likelihood for a given set of parameters.
"""
raise NotImplementedError


class TruncatedPowerLawModel(PopulationModelBase):
def __init__(self):
super().__init__()

def truncated_power_law(self, x, x_min, x_max, alpha):
valid_indices = (x >= x_min) & (x <= x_max)
C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha))
pdf = jnp.zeros_like(x)
pdf = jnp.where(valid_indices, C / (x ** alpha), pdf)
return pdf

def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float:
return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"])





21 changes: 21 additions & 0 deletions src/jimgw/population/transform.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
from typing import Tuple, List, Dict
from jimgw.transforms import NtoNTransform, Float


class NullTransform(NtoNTransform):
"""
Null transformation that does nothing to the input data.
"""

def __init__(self, name_mapping: Tuple[List[str], List[str]]):
super().__init__(name_mapping)

# Ensure that the input and output name mappings are the same length
if len(name_mapping[0]) != len(name_mapping[1]):
raise ValueError("Input and output name mappings must have the same length.")

# The transform function simply returns the input as-is
def null_transform(x: Dict[str, Float]) -> Dict[str, Float]:
return {key: x[key] for key in name_mapping[0]}

self.transform_func = null_transform
61 changes: 61 additions & 0 deletions src/jimgw/population/utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import importlib
import jax
import jax.numpy as jnp
import numpy as np
import pandas as pd
from jimgw.single_event.utils import Mc_eta_to_m1_m2
import glob

def create_model(model_name):
try:
module = importlib.import_module('population_model')

# Check if model_name is a string
if not isinstance(model_name, str):
raise ValueError("model_name must be a string")

model_class = getattr(module, model_name)
return model_class
except (ImportError, AttributeError) as e:
raise ImportError(f"Could not import model '{model_name}': {str(e)}")


def extract_data_from_npz_files(data_dir, column_name, num_samples=50, random_seed=42):
"""
Extracts specified column data from the given .npz files.

Parameters:
- data_dir (str): The directory containing all the .npz files.
- column_name (str): The name of the column to extract from the DataFrame.
- num_samples (int): Number of samples to extract from each file.
- random_seed (int): Seed for random number generation.

Returns:
- jnp.array: Stacked array of extracted data.
"""

npz_files = glob.glob(f"{data_dir}/*.npz")
key = jax.random.PRNGKey(random_seed)
result_list = []

for npz_file in npz_files:
print(f"Loading file: {npz_file}")

with np.load(npz_file, allow_pickle=True) as data:
data_dict = data['arr_0'].item()
if column_name not in data_dict:
raise ValueError(f"Column '{column_name}' not found in the data.")

extracted_data = data_dict[column_name].reshape(-1,)

if isinstance(extracted_data, np.ndarray):
extracted_data = jax.device_put(extracted_data)

key, subkey = jax.random.split(key)
sample_indices = jax.random.choice(subkey, extracted_data.shape[0], shape=(num_samples,), replace=True)

sampled_data = extracted_data[sample_indices]
result_list.append(sampled_data)
stacked_array = jnp.stack(result_list)

return stacked_array
Loading