Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reparameterization of extrinsic parameter for better sampling efficiency (clean) #161

Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
32 commits
Select commit Hold shift + click to select a range
f33a782
Adding transform from geocentric arrival time to detector arrival time
tsunhopang Aug 12, 2024
3505394
Adding transform from distance to SNR weighted distance
tsunhopang Aug 12, 2024
df75ceb
updating the typing for object attributes
tsunhopang Aug 13, 2024
9f2f52b
Adding geocentric phase to detector phase
tsunhopang Aug 13, 2024
b62970f
Adding ZeroLikelihood for testing purpose
tsunhopang Aug 13, 2024
4ea3322
Adding the missing mode 2pi for phasing transform
tsunhopang Aug 13, 2024
7a4bae0
Test wip
tsunhopang Aug 13, 2024
d5f86e5
Phase renaming
tsunhopang Aug 13, 2024
0a2e68c
wip
Aug 13, 2024
b96512c
Push conditional bijective transform
kazewong Aug 14, 2024
526e33c
Switch to using conditional transform
Aug 14, 2024
dbf3f30
Switch to using conditional transform
Aug 14, 2024
a375361
Fixing jacobian handling
Aug 14, 2024
d79af97
Both arrival phase and time transform are fully vectorized
Aug 16, 2024
bcbcbe2
Shifting distance transform to conditional
Aug 16, 2024
8dab27b
update example
Aug 16, 2024
fd33882
Fixing the single sided unbound transform
Aug 16, 2024
03e76dc
Update extrinsic test
Aug 17, 2024
8fe4b5f
bugfix for single sided transform
tsunhopang Aug 19, 2024
a19b556
Update test
Aug 19, 2024
6d2cd97
update distance transform
tsunhopang Aug 19, 2024
6993dd9
Update test
Aug 19, 2024
ff65fcf
Update arrival time transform
tsunhopang Aug 19, 2024
b98d783
Update test
Aug 19, 2024
e399a5e
Fix typo
tsunhopang Aug 19, 2024
583b759
Fix typo
tsunhopang Aug 19, 2024
d25ae42
Merge branch 'jim-dev' into extrinsic_parameter_sampling_improvement_…
tsunhopang Sep 18, 2024
3a5cbf7
remove duplicated SpinToCartesianSpinTransform
tsunhopang Sep 18, 2024
c25048f
Minor fix
tsunhopang Sep 18, 2024
4822962
replace vectorize with vmap
kazewong Sep 20, 2024
617f3e7
Try fixing github precommit
kazewong Sep 20, 2024
766c36e
Try fixing github precommit
kazewong Sep 20, 2024
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .github/workflows/pre-commit.yml
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,4 @@ jobs:
python -m pip install pytest
if [ -f requirements.txt ]; then pip install -r requirements.txt; fi
python -m pip install .
- uses: pre-commit/[email protected].0
- uses: pre-commit/[email protected].1
41 changes: 30 additions & 11 deletions src/jimgw/jim.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,19 +104,38 @@ def posterior(self, params: Float[Array, " n_dim"], data: dict):

def sample(self, key: PRNGKeyArray, initial_position: Array = jnp.array([])):
if initial_position.size == 0:
initial_position = jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan

while not jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)).all():
non_finite_index = jnp.where(jnp.any(~jax.tree.reduce(jnp.logical_and, jax.tree.map(lambda x: jnp.isfinite(x), initial_position)),axis=1))[0]
initial_position = (
jnp.zeros((self.sampler.n_chains, self.prior.n_dim)) + jnp.nan
)

while not jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
).all():
non_finite_index = jnp.where(
jnp.any(
~jax.tree.reduce(
jnp.logical_and,
jax.tree.map(lambda x: jnp.isfinite(x), initial_position),
),
axis=1,
)
)[0]

key, subkey = jax.random.split(key)
guess = self.prior.sample(subkey, self.sampler.n_chains)
for transform in self.sample_transforms:
guess = jax.vmap(transform.forward)(guess)
guess = jnp.array(jax.tree.leaves({key: guess[key] for key in self.parameter_names})).T
finite_guess = jnp.where(jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess),axis=1))[0]
guess = jnp.array(
jax.tree.leaves({key: guess[key] for key in self.parameter_names})
).T
finite_guess = jnp.where(
jnp.all(jax.tree.map(lambda x: jnp.isfinite(x), guess), axis=1)
)[0]
common_length = min(len(finite_guess), len(non_finite_index))
initial_position = initial_position.at[non_finite_index[:common_length]].set(guess[:common_length])
initial_position = initial_position.at[
non_finite_index[:common_length]
].set(guess[:common_length])
self.sampler.sample(initial_position, None) # type: ignore

def maximize_likelihood(
Expand Down Expand Up @@ -157,7 +176,7 @@ def print_summary(self, transform: bool = True):
training_chain = self.add_name(training_chain)
if transform:
for sample_transform in reversed(self.sample_transforms):
training_chain = sample_transform.backward(training_chain)
training_chain = jax.vmap(sample_transform.backward)(training_chain)
training_log_prob = train_summary["log_prob"]
training_local_acceptance = train_summary["local_accs"]
training_global_acceptance = train_summary["global_accs"]
Expand All @@ -167,7 +186,7 @@ def print_summary(self, transform: bool = True):
production_chain = self.add_name(production_chain)
if transform:
for sample_transform in reversed(self.sample_transforms):
production_chain = sample_transform.backward(production_chain)
production_chain = jax.vmap(sample_transform.backward)(production_chain)
production_log_prob = production_summary["log_prob"]
production_local_acceptance = production_summary["local_accs"]
production_global_acceptance = production_summary["global_accs"]
Expand Down Expand Up @@ -223,10 +242,10 @@ def get_samples(self, training: bool = False) -> dict:
else:
chains = self.sampler.get_sampler_state(training=False)["chains"]

chains = chains.transpose(2, 0, 1)
chains = chains.reshape(-1, self.prior.n_dim)
chains = self.add_name(chains)
for sample_transform in reversed(self.sample_transforms):
chains = sample_transform.backward(chains)
chains = jax.vmap(sample_transform.backward)(chains)
return chains

def plot(self):
Expand Down
9 changes: 9 additions & 0 deletions src/jimgw/single_event/likelihood.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,15 @@ def __init__(self, detectors: list[Detector], waveform: Waveform) -> None:
self.waveform = waveform


class ZeroLikelihood(LikelihoodBase):

def __init__(self):
pass

def evaluate(self, params: dict[str, Float], data: dict) -> Float:
return 0.0


class TransientLikelihoodFD(SingleEventLiklihood):
def __init__(
self,
Expand Down
Loading
Loading