diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py
index 00e6ce6b..62ffe71e 100644
--- a/src/jimgw/single_event/likelihood.py
+++ b/src/jimgw/single_event/likelihood.py
@@ -563,9 +563,9 @@ def maximize_likelihood(
         def y(x):
             named_params = dict(zip(parameter_names, x))
             for transform in reversed(sample_transforms):
-                named_params = transform.backward(named_params)
+                named_params = jax.vmap(transform.backward)(named_params)
             for transform in likelihood_transforms:
-                named_params = transform.forward(named_params)
+                named_params = jax.vmap(transform.forward)(named_params)
             return -self.evaluate_original(named_params, {})
 
         print("Starting the optimizer")
@@ -575,18 +575,35 @@ def y(x):
         )
 
         key = jax.random.PRNGKey(0)
-        initial_position = []
-        for _ in range(popsize):
-            flag = True
-            while flag:
-                key = jax.random.split(key)[1]
-                guess = prior.sample(key, 1)
-                for transform in sample_transforms:
-                    guess = transform.forward(guess)
-                guess = jnp.array([i for i in guess.values()]).T[0]
-                flag = not jnp.all(jnp.isfinite(guess))
-            initial_position.append(guess)
-        initial_position = jnp.array(initial_position)
+        initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan
+
+        while not jax.tree.reduce(
+            jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)
+        ).all():
+            non_finite_index = jnp.where(
+                jnp.any(
+                    ~jax.tree.reduce(
+                        jnp.logical_and,
+                        jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
+                    ),
+                    axis=1,
+                )
+            )[0]
+
+            key, subkey = jax.random.split(key)
+            guess = prior.sample(subkey, popsize)
+            for transform in sample_transforms:
+                guess = jax.vmap(transform.forward)(guess)
+            guess = jnp.array(
+                jax.tree.leaves({key: guess[key] for key in parameter_names})
+            ).T
+            finite_guess = jnp.where(
+                jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
+            )[0]
+            common_length = min(len(finite_guess), len(non_finite_index))
+            initial_position = initial_position.at[
+                non_finite_index[:common_length]
+            ].set(guess[:common_length])
         rng_key, optimized_positions, summary = optimizer.optimize(
             jax.random.PRNGKey(12094), y, initial_position
         )
@@ -595,9 +612,9 @@ def y(x):
 
         named_params = dict(zip(parameter_names, best_fit))
         for transform in reversed(sample_transforms):
-            named_params = transform.backward(named_params)
+            named_params = jax.vmap(transform.backward)(named_params)
         for transform in likelihood_transforms:
-            named_params = transform.forward(named_params)
+            named_params = jax.vmap(transform.forward)(named_params)
         return named_params
 
 
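
Note on the pattern: since `prior.sample(subkey, popsize)` now yields batched parameter arrays, the per-sample `forward`/`backward` transforms are lifted with `jax.vmap` rather than called directly. A minimal sketch of that lifting, assuming a dict-in/dict-out transform interface; the `LogMassTransform` class and the parameter names below are hypothetical placeholders for illustration, not jimgw's actual transform classes:

```python
import jax
import jax.numpy as jnp


# Hypothetical stand-in for a transform whose forward/backward map a dict of
# scalar parameters to a dict of scalar parameters, one sample at a time.
class LogMassTransform:
    def forward(self, params):
        return {"log_M_c": jnp.log(params["M_c"]), "q": params["q"]}

    def backward(self, params):
        return {"M_c": jnp.exp(params["log_M_c"]), "q": params["q"]}


transform = LogMassTransform()

# A population of guesses: each dict value is a 1-D array of length popsize,
# mirroring what a batched prior.sample call would return.
popsize = 4
population = {
    "M_c": jnp.linspace(20.0, 40.0, popsize),
    "q": jnp.full(popsize, 0.8),
}

# jax.vmap lifts the per-sample (dict of scalars -> dict of scalars) method to
# act on the whole population at once, which is what the diff relies on.
forward_batch = jax.vmap(transform.forward)(population)
roundtrip = jax.vmap(transform.backward)(forward_batch)

assert jnp.allclose(roundtrip["M_c"], population["M_c"])
```

The batched initialization follows the same idea: the whole population is resampled at once and only the rows of `initial_position` that are still non-finite are overwritten, replacing the old one-sample-at-a-time rejection loop.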