Skip to content

Commit

Permalink
Update the example script. Pop analysis now works with jim
Browse files Browse the repository at this point in the history
  • Loading branch information
CharmaineWONG2 committed Sep 9, 2024
1 parent 800f127 commit 29b4236
Show file tree
Hide file tree
Showing 4 changed files with 19 additions and 56 deletions.
20 changes: 5 additions & 15 deletions src/jimgw/population/example_population.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,22 +20,10 @@ def parse_args():

def main():
args = parse_args()
mass1_array = extract_data_from_npz_files(args.data_dir,"m_1", num_samples=5000, random_seed=42)

"""
need changes for the pop_likelihood
"""
# def pop_likelihood(pop_params ,data):
# model = create_model(args.pop_model)
# likelihood = PopulationLikelihood(mass1_array, model, pop_params)
# log_likelihood = likelihood.evaluate(mass1_array, pop_params)
# return log_likelihood

model = create_model(args.pop_model)
pop_params = ["m_min",1,2]
pop_likelihood = PopulationLikelihood(mass1_array, model, pop_params)
pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model)

mass_matrix = jnp.eye(11)
mass_matrix = jnp.eye(model.get_pop_params_dimension())
mass_matrix = mass_matrix.at[1, 1].set(1e-3)
mass_matrix = mass_matrix.at[5, 5].set(1e-3)
local_sampler_arg = {"step_size": mass_matrix * 3e-3}
Expand All @@ -48,7 +36,9 @@ def main():
m_min_prior = UniformPrior(10.,80.,parameter_names = ["m_min"])
m_max_prior = UniformPrior(10.,80.,parameter_names = ["m_max"])
alpha_prior = UniformPrior(0.,10.,parameter_names = ["alpha"])

prior = CombinePrior([m_min_prior, m_max_prior, alpha_prior])

sample_transforms = [BoundToUnbound(name_mapping = [["m_min"], ["m_min_unbounded"]], original_lower_bound=10, original_upper_bound=80),
BoundToUnbound(name_mapping = [["m_max"], ["m_max_unbounded"]], original_lower_bound=10, original_upper_bound=80),
BoundToUnbound(name_mapping = [["alpha"], ["alpha_unbounded"]], original_lower_bound=0, original_upper_bound =10)]
Expand Down Expand Up @@ -87,7 +77,7 @@ def main():
)

jim.sample(jax.random.PRNGKey(42))
samples =jim.get_samples()
jim.get_samples()
jim.print_summary()

if __name__ == "__main__":
Expand Down
11 changes: 6 additions & 5 deletions src/jimgw/population/population_likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,14 +3,15 @@
from jaxtyping import Float

from jimgw.base import LikelihoodBase
from jimgw.population.utils import extract_data_from_npz_files

class PopulationLikelihood(LikelihoodBase):
def __init__(self, mass_array, model_class, pop_params):
self.mass_array = mass_array
self.population_model = model_class(*pop_params)
def __init__(self, data_dir, column_name, num_samples, model_class):
self.posteriors = extract_data_from_npz_files(data_dir, column_name, num_samples, random_seed=42)
self.population_model = model_class()

def evaluate(self, pop_params: dict[str, Float],posteriors: dict) -> Float:
model_output = self.population_model.evaluate(pop_params, posteriors)
def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float:
model_output = self.population_model.evaluate(pop_params, self.posteriors)
log_likelihood = jnp.sum(jnp.log(jnp.mean(model_output, axis=1)))
return log_likelihood

Expand Down
42 changes: 8 additions & 34 deletions src/jimgw/population/population_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,55 +8,29 @@ def __init__(self, *params):
self.params = params

@abstractmethod
def evaluate(self,data: dict, pop_params: dict) -> Float:
def evaluate(self, pop_params: dict, data: dict) -> Float:
"""
Evaluate the likelihood for a given set of parameters.
"""
raise NotImplementedError

# class TruncatedPowerLawModel(PopulationModelBase):
# def __init__(self, *params):
# super().__init__(*params)

# def truncated_power_law(self, x, x_min, x_max, alpha):
# valid_indices = (x >= x_min) & (x <= x_max)
# C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha))

# # Ensure x is treated properly and avoid non-concrete indexing
# pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x
# pdf = jnp.where(valid_indices, C / (x ** alpha), pdf)

# return pdf

# def evaluate(self,data: dict, pop_params: dict) -> Float:
# """Evaluate the truncated power law model with dynamic parameters."""
# x_min = pop_params[0]
# x_max = pop_params[1]
# alpha = pop_params[2]

# return self.truncated_power_law(data, x_min, x_max, alpha)

class TruncatedPowerLawModel(PopulationModelBase):
def __init__(self, *params):
super().__init__(*params)
def __init__(self):
super().__init__()

def truncated_power_law(self, x, x_min, x_max, alpha):
valid_indices = (x >= x_min) & (x <= x_max)
C = (1 - alpha) / (x_max**(1 - alpha) - x_min**(1 - alpha))

# Ensure x is treated properly and avoid non-concrete indexing
pdf = jnp.zeros_like(x) # Initialize pdf to the same shape as x
pdf = jnp.zeros_like(x)
pdf = jnp.where(valid_indices, C / (x ** alpha), pdf)

return pdf

def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float:
"""Evaluate the truncated power law model with dynamic parameters."""
print("pop_parmas",pop_params)
m_min = pop_params["m_min"]
m_max = pop_params["m_max"]
alpha = pop_params["alpha"]
return self.truncated_power_law(data, m_min, m_max, alpha)
return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"])

def get_pop_params_dimension():
return 3



Expand Down
2 changes: 0 additions & 2 deletions src/jimgw/population/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,8 +47,6 @@ def extract_data_from_npz_files(data_dir, column_name, num_samples=50, random_se
raise ValueError(f"Column '{column_name}' not found in the data.")

extracted_data = data_dict[column_name].reshape(-1,)
print(extracted_data)
print(extracted_data.shape)

if isinstance(extracted_data, np.ndarray):
extracted_data = jax.device_put(extracted_data)
Expand Down

0 comments on commit 29b4236

Please sign in to comment.