diff --git a/src/jimgw/population/example_population.py b/src/jimgw/population/example_population.py index b6ea942b..6baa4c09 100644 --- a/src/jimgw/population/example_population.py +++ b/src/jimgw/population/example_population.py @@ -4,7 +4,7 @@ from jimgw.jim import Jim from flowMC.strategy.optimization import optimization_Adam from jimgw.population.population_likelihood import PopulationLikelihood -from jimgw.population.utils import create_model, extract_data_from_npz_files +from jimgw.population.utils import create_model import argparse from jimgw.prior import UniformPrior, CombinePrior from jimgw.transforms import BoundToUnbound @@ -23,10 +23,7 @@ def main(): model = create_model(args.pop_model) pop_likelihood = PopulationLikelihood(args.data_dir, "m_1", 5000, model) - mass_matrix = jnp.eye(model.get_pop_params_dimension()) - mass_matrix = mass_matrix.at[1, 1].set(1e-3) - mass_matrix = mass_matrix.at[5, 5].set(1e-3) - local_sampler_arg = {"step_size": mass_matrix * 3e-3} + local_sampler_arg = {"step_size": 3e-3} Adam_optimizer = optimization_Adam(n_steps=5, learning_rate=0.01, noise_level=1) diff --git a/src/jimgw/population/population_model.py b/src/jimgw/population/population_model.py index 9f8b404d..71ebad60 100644 --- a/src/jimgw/population/population_model.py +++ b/src/jimgw/population/population_model.py @@ -29,8 +29,6 @@ def truncated_power_law(self, x, x_min, x_max, alpha): def evaluate(self, pop_params: dict[str, Float], data: dict) -> Float: return self.truncated_power_law(data, pop_params["m_min"], pop_params["m_max"],pop_params["alpha"]) - def get_pop_params_dimension(): - return 3