Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Transform #2

Merged
merged 119 commits into from
Aug 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
119 commits
Select commit Hold shift + click to select a range
77e0247
Update Single_event_runManager.py
thomasckng Jul 23, 2024
86da005
Merge pull request #105 from kazewong/main
kazewong Jul 23, 2024
540b4db
Merge branch 'kazewong:jim-dev' into jim-dev
thomasckng Jul 23, 2024
8e6d789
Merge pull request #102 from thomasckng/jim-dev
kazewong Jul 23, 2024
75c74a3
scaffolding jim to handle naming.
kazewong Jul 24, 2024
c749dbd
Rename variables in jim
kazewong Jul 24, 2024
f45b41e
Starting tracking names in jim
kazewong Jul 24, 2024
dd6e0de
Add Transform class
kazewong Jul 24, 2024
f11f8c8
Add LogitToUniform Transform
kazewong Jul 24, 2024
c1192f2
Add logit distribution
kazewong Jul 24, 2024
b3ebc5e
Add propagate)name method in transform
kazewong Jul 24, 2024
95ef89b
Rename composite to Combine for combining priors
kazewong Jul 24, 2024
e6e45ef
scaffold prior test
kazewong Jul 24, 2024
d503a37
Add Sequential Transform
kazewong Jul 24, 2024
61b3b9e
Separate logit and scaling.
kazewong Jul 24, 2024
b47e6ae
Univeriate Transform seems working
kazewong Jul 24, 2024
e1cc408
Uniform prior now working.
kazewong Jul 24, 2024
d68cb36
Added inverse transform and uniform now perform correct
kazewong Jul 24, 2024
c16eef5
Add comments to current sequential transform prior class
kazewong Jul 24, 2024
8ab92a4
Removing inverse.
kazewong Jul 25, 2024
d8b2d1f
Add transformation function
kazewong Jul 25, 2024
ca1d6b6
Combine should be working now
kazewong Jul 25, 2024
2f3f412
Sine is an illegal transform since its Jacobian could be negative
kazewong Jul 25, 2024
4978cb1
Modify Uniform and add UniformSphere
thomasckng Jul 25, 2024
c1115bd
Add Sine and Cosine Prior
thomasckng Jul 26, 2024
01a6c1e
Revert "Modify Uniform and add UniformSphere"
thomasckng Jul 26, 2024
7ee4133
Add standard normal distribution
thomasckng Jul 26, 2024
6a35792
Add periodic uniform prior
thomasckng Jul 26, 2024
1bcf32c
Reformat
thomasckng Jul 26, 2024
17450be
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 26, 2024
5d98aeb
Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-p…
thomasckng Jul 26, 2024
87ee212
Minor text change
thomasckng Jul 26, 2024
cc1448e
Remove PeriodicUniform
thomasckng Jul 26, 2024
6070f13
Use self.sample_base
thomasckng Jul 26, 2024
8a301f2
Update prior.py
thomasckng Jul 26, 2024
d761f6b
Update prior.py to include powerLaw
xuyuon Jul 26, 2024
b9a7255
Update transforms.py
xuyuon Jul 26, 2024
2f6e12a
format and updating typing hint
kazewong Jul 25, 2024
194c565
Revert "Update transforms.py"
thomasckng Jul 26, 2024
5807f3c
Revert "Update prior.py to include powerLaw"
thomasckng Jul 26, 2024
98625c3
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 26, 2024
8256d0e
Comment out old prior
thomasckng Jul 26, 2024
b42b0f4
Reformat
thomasckng Jul 26, 2024
ba7dabb
Merge pull request #111 from thomasckng/sphere_prior
thomasckng Jul 26, 2024
a9629b2
Update Prior naming
kazewong Jul 26, 2024
fbba5af
Standize Transform naming
kazewong Jul 26, 2024
4a5d932
Fixing naming problem and add base_distribution tracer
kazewong Jul 26, 2024
3497b6d
Merge pull request #2 from kazewong/98-moving-naming-tracking-into-ji…
thomasckng Jul 26, 2024
27dc870
Updated powerlaw
xuyuon Jul 26, 2024
f7b883d
Updated powerlaw
xuyuon Jul 26, 2024
1665b1c
Updated powerlaw
xuyuon Jul 26, 2024
f51847d
Updated powerlaw
xuyuon Jul 26, 2024
f7876e6
Updated prior.py to include UniformComponenChirpMassPrior
xuyuon Jul 26, 2024
0b2c7cf
Fix priors
thomasckng Jul 27, 2024
d917f96
Add prior tests
thomasckng Jul 27, 2024
9390495
Merge branch 'prior-test' into 98-moving-naming-tracking-into-jim-cla…
thomasckng Jul 27, 2024
142fe54
Set constraint on powerlaw prior input
xuyuon Jul 29, 2024
fa6a6e2
Added test_power_law
xuyuon Jul 29, 2024
a8c0673
Updated ParetoTransform to avoid divide by zero
xuyuon Jul 29, 2024
6ce5b36
Revert update on ParetoTransform
xuyuon Jul 29, 2024
e9511b9
Updated test_prior.py
xuyuon Jul 29, 2024
54082ea
Updated test_prior.py
xuyuon Jul 29, 2024
9ad64a1
Updated test_prior.py
xuyuon Jul 29, 2024
ce437c4
Updated test_prior.py
xuyuon Jul 29, 2024
18d5184
Updated test_prior.py
xuyuon Jul 29, 2024
7c05d14
Updated test_prior.py
xuyuon Jul 29, 2024
f647bf0
Remove unnecessary test
thomasckng Jul 29, 2024
6171ce5
Reformat
thomasckng Jul 29, 2024
ab7447b
Change naming
thomasckng Jul 29, 2024
1b67056
Removed bilby
xuyuon Jul 29, 2024
b2dd606
Merge pull request #113 from xuyuon/98-moving-naming-tracking-into-ji…
kazewong Jul 29, 2024
9d59c4d
Move unit test to unit directory
kazewong Jul 29, 2024
b3110df
Update priors naming issue
kazewong Jul 29, 2024
024c5c3
base parameter names seem working
kazewong Jul 29, 2024
ac43d54
fix cosine naming error
kazewong Jul 29, 2024
151c37d
prior test now all pass
kazewong Jul 29, 2024
576ce7c
Fix likelihood sampling issue due to parameter name transformation
kazewong Jul 29, 2024
7625dbe
Instead of defining univariate and multivariate,
kazewong Jul 30, 2024
9ac722f
Adding some transform, should be working
kazewong Jul 30, 2024
935b314
Refactor into NtoN and NtoM transform
kazewong Jul 30, 2024
685670a
Fix bugs in ArcSine
kazewong Jul 30, 2024
39126f5
Add inverse mass transform
thomasckng Jul 30, 2024
46bd044
Update sequential prior class and test_GW15014.py.
kazewong Jul 31, 2024
a0161ee
Merge pull request #5 from kazewong/98-moving-naming-tracking-into-ji…
thomasckng Jul 31, 2024
8e1441b
correct sign errors and minior bugs
kazewong Jul 31, 2024
dfdfffa
Add mass transform
thomasckng Jul 31, 2024
9d87e58
Add simplex transform
thomasckng Jul 31, 2024
5f33346
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
52fa0eb
Change transform to take dictionary as input for transform_func
kazewong Jul 31, 2024
1e389bb
Add UniformComponentMassPrior
thomasckng Jul 31, 2024
0252392
prior system should work with dictionary function now
kazewong Jul 31, 2024
cb75bb9
update transform
kazewong Jul 31, 2024
fd32f20
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
561e628
Move subclassing structure
kazewong Jul 31, 2024
1aaf5ea
Remove transformation
thomasckng Jul 31, 2024
5bfdda1
Remove prior
thomasckng Jul 31, 2024
77a6ad1
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
2401df8
Add mass transform
thomasckng Jul 31, 2024
f932c77
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Jul 31, 2024
ff35a82
Add Bound transforming transform
kazewong Jul 31, 2024
47af9cf
no nans now, but the code can be consolidate
kazewong Aug 1, 2024
b5e06a6
Solve conflict
thomasckng Aug 1, 2024
b7d08d3
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
8e5b326
Add sky position transform
thomasckng Aug 1, 2024
10d51b2
Modify sky position transform
thomasckng Aug 1, 2024
8368d00
Change util func name
thomasckng Aug 1, 2024
b4f6052
Merge branch '98-moving-naming-tracking-into-jim-class-from-prior-cla…
thomasckng Aug 1, 2024
1cb0d11
Revert "Merge branch '98-moving-naming-tracking-into-jim-class-from-p…
thomasckng Aug 1, 2024
78e93c5
Merge
thomasckng Aug 1, 2024
080bd8b
Modify integration test
thomasckng Aug 1, 2024
4058d32
Reformat
thomasckng Aug 1, 2024
cdf771d
Add typecheck
thomasckng Aug 1, 2024
02c5650
minor typo
thomasckng Aug 1, 2024
ce7ac34
Rename sampler
thomasckng Aug 1, 2024
7d44aa4
Fix test
thomasckng Aug 1, 2024
e9288c8
Fix BoundToUnbound transform
thomasckng Aug 1, 2024
fe500a7
Use ifos list
thomasckng Aug 1, 2024
0d28520
Fix jim summary and get_samples
thomasckng Aug 1, 2024
f7e3fe8
Fix jim output functions
thomasckng Aug 1, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 1 addition & 5 deletions .github/workflows/python-package.yml
Original file line number Diff line number Diff line change
Expand Up @@ -3,11 +3,7 @@

name: Python package

on:
push:
branches: [ "main" ]
pull_request:
branches: [ "main" ]
on: [push, pull_request]

jobs:
build:
Expand Down
Empty file added docs/tutorials/naming_system.md
Empty file.
1 change: 0 additions & 1 deletion example/Single_event_runManager.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@

run = SingleEventRun(
seed=0,
path="test_data/GW150914/",
detectors=["H1", "L1"],
priors={
"M_c": {"name": "Uniform", "xmin": 10.0, "xmax": 80.0},
Expand Down
162 changes: 131 additions & 31 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,17 +8,51 @@

from jimgw.base import LikelihoodBase
from jimgw.prior import Prior
from jimgw.transforms import BijectiveTransform, NtoMTransform


class Jim(object):
"""
Master class for interfacing with flowMC

"""

def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):
self.Likelihood = likelihood
self.Prior = prior
likelihood: LikelihoodBase
prior: Prior

# Name of parameters to sample from
sample_transforms: list[BijectiveTransform]
likelihood_transforms: list[NtoMTransform]
parameter_names: list[str]
sampler: Sampler

def __init__(
self,
likelihood: LikelihoodBase,
prior: Prior,
sample_transforms: list[BijectiveTransform] = [],
likelihood_transforms: list[NtoMTransform] = [],
**kwargs,
):
self.likelihood = likelihood
self.prior = prior

self.sample_transforms = sample_transforms
self.likelihood_transforms = likelihood_transforms
self.parameter_names = prior.parameter_names

if len(sample_transforms) == 0:
print(
"No sample transforms provided. Using prior parameters as sampling parameters"
)
else:
print("Using sample transforms")
for transform in sample_transforms:
self.parameter_names = transform.propagate_name(self.parameter_names)

if len(likelihood_transforms) == 0:
print(
"No likelihood transforms provided. Using prior parameters as likelihood parameters"
)

seed = kwargs.get("seed", 0)

Expand All @@ -33,30 +67,50 @@ def __init__(self, likelihood: LikelihoodBase, prior: Prior, **kwargs):

rng_key, subkey = jax.random.split(rng_key)
model = MaskedCouplingRQSpline(
self.Prior.n_dim, num_layers, hidden_size, num_bins, subkey
self.prior.n_dim, num_layers, hidden_size, num_bins, subkey
)

self.Sampler = Sampler(
self.Prior.n_dim,
self.sampler = Sampler(
self.prior.n_dim,
rng_key,
None, # type: ignore
local_sampler,
model,
**kwargs,
)

def add_name(self, x: Float[Array, " n_dim"]) -> dict[str, Float]:
"""
Turn an array into a dictionary

Parameters
----------
x : Array
An array of parameters. Shape (n_dim,).
"""

return dict(zip(self.parameter_names, x))

def posterior(self, params: Float[Array, " n_dim"], data: dict):
prior_params = self.Prior.add_name(params.T)
prior = self.Prior.log_prob(prior_params)
named_params = self.add_name(params)
transform_jacobian = 0.0
for transform in self.sample_transforms:
named_params, jacobian = transform.inverse(named_params)
transform_jacobian += jacobian
prior = self.prior.log_prob(named_params) + transform_jacobian
for transform in self.likelihood_transforms:
named_params = transform.forward(named_params)
return (
self.Likelihood.evaluate(self.Prior.transform(prior_params), data) + prior
self.likelihood.evaluate(named_params, data) + prior
)

def sample(self, key: PRNGKeyArray, initial_guess: Array = jnp.array([])):
if initial_guess.size == 0:
initial_guess_named = self.Prior.sample(key, self.Sampler.n_chains)
initial_guess_named = self.prior.sample(key, self.sampler.n_chains)
for transform in self.sample_transforms:
initial_guess_named = jax.vmap(transform.forward)(initial_guess_named)
initial_guess = jnp.stack([i for i in initial_guess_named.values()]).T
self.Sampler.sample(initial_guess, None) # type: ignore
self.sampler.sample(initial_guess, None) # type: ignore

def maximize_likelihood(
self,
Expand All @@ -67,7 +121,7 @@ def maximize_likelihood(
):
key = jax.random.PRNGKey(seed)
set_nwalkers = set_nwalkers
initial_guess = self.Prior.sample(key, set_nwalkers)
initial_guess = self.prior.sample(key, set_nwalkers)

def negative_posterior(x: Float[Array, " n_dim"]):
return -self.posterior(x, None) # type: ignore since flowMC does not have typing info, yet
Expand All @@ -78,41 +132,67 @@ def negative_posterior(x: Float[Array, " n_dim"]):
print("Done compiling")

print("Starting the optimizer")
optimizer = EvolutionaryOptimizer(self.Prior.n_dim, verbose=True)
optimizer = EvolutionaryOptimizer(self.prior.n_dim, verbose=True)
_ = optimizer.optimize(negative_posterior, bounds, n_loops=n_loops)
best_fit = optimizer.get_result()[0]
return best_fit

def print_summary(self, transform: bool = True):
def print_summary(self):
"""
Generate summary of the run

"""

train_summary = self.Sampler.get_sampler_state(training=True)
production_summary = self.Sampler.get_sampler_state(training=False)
train_summary = self.sampler.get_sampler_state(training=True)
production_summary = self.sampler.get_sampler_state(training=False)

training_chain = train_summary["chains"].reshape(-1, self.Prior.n_dim).T
training_chain = self.Prior.add_name(training_chain)
if transform:
training_chain = self.Prior.transform(training_chain)
training_chain = train_summary["chains"].reshape(-1, len(self.parameter_names)).T
if self.sample_transforms:
transformed_chain = {}
named_sample = self.add_name(training_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in training_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
training_chain = transformed_chain
else:
training_chain = self.add_name(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
training_loss = train_summary["loss_vals"]

production_chain = production_summary["chains"].reshape(-1, self.Prior.n_dim).T
production_chain = self.Prior.add_name(production_chain)
if transform:
production_chain = self.Prior.transform(production_chain)
production_chain = production_summary["chains"].reshape(-1, len(self.parameter_names)).T
if self.sample_transforms:
transformed_chain = {}
named_sample = self.add_name(production_chain[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in production_chain[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
production_chain = transformed_chain
else:
production_chain = self.add_name(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]

print("Training summary")
print("=" * 10)
for key, value in training_chain.items():
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
print(
f"Log probability: {training_log_prob.mean():.3f} +/- {training_log_prob.std():.3f}"
)
Expand All @@ -129,7 +209,7 @@ def print_summary(self, transform: bool = True):
print("Production summary")
print("=" * 10)
for key, value in production_chain.items():
print(f"{key}: {value.mean():.3f} +/- {value.std():.3f}")
print(f"{key}: {jnp.array(value).mean():.3f} +/- {jnp.array(value).std():.3f}")
print(
f"Log probability: {production_log_prob.mean():.3f} +/- {production_log_prob.std():.3f}"
)
Expand All @@ -156,12 +236,32 @@ def get_samples(self, training: bool = False) -> dict:

"""
if training:
chains = self.Sampler.get_sampler_state(training=True)["chains"]
chains = self.sampler.get_sampler_state(training=True)["chains"]
else:
chains = self.sampler.get_sampler_state(training=False)["chains"]

# Need rewrite to output chains instead of flattened samples
chains = chains.reshape(-1, len(self.parameter_names)).T
if self.sample_transforms:
transformed_chain = {}
named_sample = self.add_name(chains[0])
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key] = [value]
for sample in chains[1:]:
named_sample = self.add_name(sample)
for transform in self.sample_transforms:
named_sample = transform.backward(named_sample)
for key, value in named_sample.items():
transformed_chain[key].append(value)
output = transformed_chain
else:
chains = self.Sampler.get_sampler_state(training=False)["chains"]
output = self.add_name(chains)

chains = self.Prior.transform(self.Prior.add_name(chains.transpose(2, 0, 1)))
return chains
for key in output.keys():
output[key] = jnp.array(output[key])
return output

def plot(self):
pass
Loading
Loading