diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 0e439fd0..62ffe71e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -576,18 +576,34 @@ def y(x): key = jax.random.PRNGKey(0) initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + + while not jax.tree.reduce( + jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position) + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = prior.sample(subkey, popsize) for transform in sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position )