Skip to content

Commit

Permalink
Merge pull request #108 from kazewong/98-moving-naming-tracking-into-…
Browse files Browse the repository at this point in the history
…jim-class-from-prior-class

98 moving naming tracking into jim class from prior class
  • Loading branch information
kazewong authored Sep 2, 2024
2 parents a46af6b + 93ea485 commit d2c0416
Show file tree
Hide file tree
Showing 20 changed files with 2,357 additions and 709 deletions.
6 changes: 1 addition & 5 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

name: Python package

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
on: [push, pull_request]

jobs:
build:
Expand Down
Empty file added docs/tutorials/naming_system.md
Empty file.
127 changes: 96 additions & 31 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,51 @@

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


class Jim(object):
"""
Master class for interfacing with flowMC
"""

def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior
likelihood: LikelihoodBase
prior: Prior

# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
self,
likelihood: LikelihoodBase,
prior: Prior,
sample_transforms: list[BijectiveTransform] = [],
likelihood_transforms: list[NtoMTransform] = [],
**kwargs,
):
self.likelihood = likelihood
self.prior = prior

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

Expand All @@ -33,30 +67,56 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.Prior.n_dim,
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
local_sampler,
model,
**kwargs,
)

def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior_params = self.Prior.add_name(params.T)
prior = self.Prior.log_prob(prior_params)
return (
self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior
)
def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
"""
Turn an array into a dictionary
def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
Parameters
----------
x : Array
An array of parameters. Shape (n_dim,).
"""

return dict(zip(self.parameter_names, x))

def posterior(self, params: Float[Array, " n_dim"], data: dict):
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in reversed(self.sample_transforms):
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
return self.likelihood.evaluate(named_params, data) + prior

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_guess = []
for _ in range(self.sampler.n_chains):
flag = True
while flag:
key = jax.random.split(key)[1]
guess = self.prior.sample(key, 1)
for transform in self.sample_transforms:
guess = transform.forward(guess)
guess = jnp.array([i for i in guess.values()]).T[0]
flag = not jnp.all(jnp.isfinite(guess))
initial_guess.append(guess)
initial_position = jnp.array(initial_guess)
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
self,
Expand All @@ -67,7 +127,7 @@ def maximize_likelihood(
):
key = jax.random.PRNGKey(seed)
set_nwalkers = set_nwalkers
initial_guess = self.Prior.sample(key, set_nwalkers)
initial_guess = self.prior.sample(key, set_nwalkers)

def negative_posterior(x: Float[Array, " n_dim"]):
return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet
Expand All @@ -78,7 +138,7 @@ def negative_posterior(x: Float[Array, " n_dim"]):
print("Done compiling")

print("Starting the optimizer")
optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True)
optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True)
_ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return best_fit
Expand All @@ -89,22 +149,24 @@ def print_summary(self, transform: bool = True):
"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T
training_chain = self.Prior.add_name(training_chain)
training_chain = train_summary["chains"].reshape(-1, self.prior.n_dim).T
training_chain = self.add_name(training_chain)
if transform:
training_chain = self.Prior.transform(training_chain)
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T
production_chain = self.Prior.add_name(production_chain)
production_chain = production_summary["chains"].reshape(-1, self.prior.n_dim).T
production_chain = self.add_name(production_chain)
if transform:
production_chain = self.Prior.transform(production_chain)
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -156,11 +218,14 @@ def get_samples(self, training: bool = False) -> dict:
"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1)))
chains = chains.transpose(2, 0, 1)
chains = self.add_name(chains)
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
return chains

def plot(self):
Expand Down
Loading

0 comments on commit d2c0416

Please sign in to comment.