From aa1274ea87f826ca79abf28fe6ee99d867da28ff Mon Sep 17 00:00:00 2001 From: Thomas Ng Date: Mon, 16 Sep 2024 12:19:12 +0800 Subject: [PATCH] Update likelihood.py --- src/jimgw/single_event/likelihood.py | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/jimgw/single_event/likelihood.py b/src/jimgw/single_event/likelihood.py index f8692af4..e8b4e8af 100644 --- a/src/jimgw/single_event/likelihood.py +++ b/src/jimgw/single_event/likelihood.py @@ -558,14 +558,14 @@ def maximize_likelihood( ): parameter_names = prior.parameter_names for transform in sample_transforms: - parameter_names = transform.propagate_name(parameter_names) + parameter_names = jax.vmap(transform.propagate_name)(parameter_names) def y(x): named_params = dict(zip(parameter_names, x)) for transform in reversed(sample_transforms): - named_params = transform.backward(named_params) + named_params = jax.vmap(transform.backward)(named_params) for transform in likelihood_transforms: - named_params = transform.forward(named_params) + named_params = jax.vmap(transform.forward)(named_params) return -self.evaluate_original(named_params, {}) print("Starting the optimizer") @@ -605,9 +605,9 @@ def y(x): named_params = dict(zip(parameter_names, best_fit)) for transform in reversed(sample_transforms): - named_params = transform.backward(named_params) + named_params = jax.vmap(transform.backward)(named_params) for transform in likelihood_transforms: - named_params = transform.forward(named_params) + named_params = jax.vmap(transform.forward)(named_params) return named_params