diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..acb431f5 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -575,18 +575,34 @@ def y(x): ) key = jax.random.PRNGKey(0) - initial_position = [] - for _ in range(popsize): - flag = True - while flag: - key = jax.random.split(key)[1] - guess = prior.sample(key, 1) - for transform in sample_transforms: - guess = transform.forward(guess) - guess = jnp.array([i for i in guess.values()]).T[0] - flag = not jnp.all(jnp.isfinite(guess)) - initial_position.append(guess) - initial_position = jnp.array(initial_position) + + initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan + while not jax.tree.reduce( + jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position) + ).all(): + non_finite_index = jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + + key, subkey = jax.random.split(key) + guess = prior.sample(subkey, popsize) + for transform in sample_transforms: + guess = jax.vmap(transform.forward)(guess) + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] + common_length = min(len(finite_guess), len(non_finite_index)) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) + rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position )