Skip to content

Commit

Permalink
Update jim.py
Browse files Browse the repository at this point in the history
  • Loading branch information
thomasckng committed Sep 17, 2024
1 parent 3d3b20a commit 5cf1668
Showing 1 changed file with 0 additions and 26 deletions.
26 changes: 0 additions & 26 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@
from flowMC.nfmodel.rqSpline import MaskedCouplingRQSpline
from flowMC.proposal.MALA import MALA
from flowMC.Sampler import Sampler
from flowMC.utils.EvolutionaryOptimizer import EvolutionaryOptimizer
from jaxtyping import Array, Float, PRNGKeyArray

from jimgw.base import LikelihoodBase
Expand Down Expand Up @@ -119,31 +118,6 @@ def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
self,
bounds: Float[Array, " n_dim 2"],
set_nwalkers: int = 100,
n_loops: int = 2000,
seed=92348,
):
key = jax.random.PRNGKey(seed)
set_nwalkers = set_nwalkers
initial_guess = self.prior.sample(key, set_nwalkers)

def negative_posterior(x: Float[Array, " n_dim"]):
return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet

negative_posterior = jax.jit(jax.vmap(negative_posterior))
print("Compiling likelihood function")
negative_posterior(initial_guess)
print("Done compiling")

print("Starting the optimizer")
optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True)
_ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return best_fit

def print_summary(self, transform: bool = True):
"""
Generate summary of the run
Expand Down

0 comments on commit 5cf1668

Please sign in to comment.