Skip to content

Commit

Permalink
Reformat
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasckng committed Sep 17, 2024
1 parent 88322e9 commit 2acad45
Showing 1 changed file with 22 additions and 6 deletions.
28 changes: 22 additions & 6 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -576,18 +576,34 @@ def y(x):

key = jax.random.PRNGKey(0)
initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan

while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0]

while not jax.tree.reduce(
jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)
).all():
non_finite_index = jnp.where(
jnp.any(
~jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
),
axis=1,
)
)[0]

key, subkey = jax.random.split(key)
guess = prior.sample(subkey, popsize)
for transform in sample_transforms:
guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(jax.tree.leaves({key: guess[key] for key in parameter_names})).T
finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0]
guess = jnp.array(
jax.tree.leaves({key: guess[key] for key in parameter_names})
).T
finite_guess = jnp.where(
jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
initial_position = initial_position.at[
non_finite_index[:common_length]
].set(guess[:common_length])
rng_key, optimized_positions, summary = optimizer.optimize(
jax.random.PRNGKey(12094), y, initial_position
)
Expand Down

0 comments on commit 2acad45

Please sign in to comment.