From 88322e9a312fbfd84b5ffcb650cc0055424af5c1 Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 17 Sep 2024 08:49:54 +0800 Subject: [PATCH 1/2] Update likelihood.py --- src/jimgw/single_event/likelihood.py | 33 ++++++++++++++-------------- 1 file changed, 17 insertions(+), 16 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 00e6ce6b..0e439fd0 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -563,9 +563,9 @@ def maximize_likelihood( def y(x): named_params = dict(zip(parameter_names, x)) for transform in reversed(sample_transforms): - named_params = transform.backward(named_params) + named_params = jax.vmap(transform.backward)(named_params) for transform in likelihood_transforms: - named_params = transform.forward(named_params) + named_params = jax.vmap(transform.forward)(named_params) return -self.evaluate_original(named_params, {}) print("Starting the optimizer") @@ -575,18 +575,19 @@ def y(x): ) key = jax.random.PRNGKey(0) - initial_position = [] - for _ in range(popsize): - flag = True - while flag: - key = jax.random.split(key)[1] - guess = prior.sample(key, 1) - for transform in sample_transforms: - guess = transform.forward(guess) - guess = jnp.array([i for i in guess.values()]).T[0] - flag = not jnp.all(jnp.isfinite(guess)) - initial_position.append(guess) - initial_position = jnp.array(initial_position) + initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan + + while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): + non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + + key, subkey = jax.random.split(key) + guess = prior.sample(subkey, popsize) + for transform in sample_transforms: + guess = jax.vmap(transform.forward)(guess) + guess = jnp.array(jax.tree.leaves({key: guess[key] for key in parameter_names})).T + finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + common_length = min(len(finite_guess), len(non_finite_index)) + initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position ) @@ -595,9 +596,9 @@ def y(x): named_params = dict(zip(parameter_names, best_fit)) for transform in reversed(sample_transforms): - named_params = transform.backward(named_params) + named_params = jax.vmap(transform.backward)(named_params) for transform in likelihood_transforms: - named_params = transform.forward(named_params) + named_params = jax.vmap(transform.forward)(named_params) return named_params From 2acad453c547cb32c2e5a99f9e47f1d55c336e7c Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Tue, 17 Sep 2024 08:53:46 +0800 Subject: [PATCH 2/2] Reformat --- src/jimgw/single_event/likelihood.py | 28 ++++++++++++++++++++++------ 1 file changed, 22 insertions(+), 6 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index 0e439fd0..62ffe71e 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -576,18 +576,34 @@ def y(x): key = jax.random.PRNGKey(0) initial_position = jnp.zeros((popsize, prior.n_dim)) + jnp.nan - - while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all(): - non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0] + + while not jax.tree.reduce( + jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position) + ).all(): + non_finite_index = jnp.where( + jnp.any( + ~jax.tree.reduce( + jnp.logical_and, + jax.tree.map(lambda x: jnp.isfinite(x), initial_position), + ), + axis=1, + ) + )[0] key, subkey = jax.random.split(key) guess = prior.sample(subkey, popsize) for transform in sample_transforms: guess = jax.vmap(transform.forward)(guess) - guess = jnp.array(jax.tree.leaves({key: guess[key] for key in parameter_names})).T - finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0] + guess = jnp.array( + jax.tree.leaves({key: guess[key] for key in parameter_names}) + ).T + finite_guess = jnp.where( + jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1) + )[0] common_length = min(len(finite_guess), len(non_finite_index)) - initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length]) + initial_position = initial_position.at[ + non_finite_index[:common_length] + ].set(guess[:common_length]) rng_key, optimized_positions, summary = optimizer.optimize( jax.random.PRNGKey(12094), y, initial_position )