diff --git a/pymc4/variational/approximations.py b/pymc4/variational/approximations.py
index 800f6ed5..1c8b7ecf 100644
--- a/pymc4/variational/approximations.py
+++ b/pymc4/variational/approximations.py
@@ -10,7 +10,6 @@

 from pymc4 import flow
 from pymc4.coroutine_model import Model
-from pymc4.distributions.transforms import JacobianPreference
 from pymc4.inference.utils import initialize_sampling_state
 from pymc4.utils import NameParts
 from pymc4.variational import updates
@@ -42,11 +41,11 @@ def __init__(self, model: Optional[Model] = None, random_seed: Optional[int] = N

         self.order = ArrayOrdering(self.state.all_unobserved_values)
         self.unobserved_keys = self.state.all_unobserved_values.keys()
-        self.target_log_prob = self._build_logfn()
+        self.target_log_prob, self.deterministics_callback = self._build_logp_and_deterministic_fn()
         self.approx = self._build_posterior()

-    def _build_logfn(self):
-        """Build vectorized logp function."""
+    def _build_logp_and_deterministic_fn(self):
+        """Build vectorized logp and deterministic functions."""

         @tf.function(autograph=False)
         def logpfn(*values):
@@ -54,13 +53,26 @@ def logpfn(*values):
             _, st = flow.evaluate_meta_model(self.model, values=split_view)
             return st.collect_log_prob()

-        def vectorize_logp_function(logpfn):
-            def vectorized_logpfn(*q_samples):
-                return tf.vectorized_map(lambda samples: logpfn(*samples), q_samples)
+        @tf.function(autograph=False)
+        def deterministics_callback(q_samples):
+            st = flow.SamplingState.from_values(
+                q_samples, observed_values=self.state.observed_values
+            )
+            _, st = flow.evaluate_model_transformed(self.model, state=st)
+            for transformed_name in st.transformed_values:
+                untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
+                st.deterministics[untransformed_name] = st.untransformed_values.pop(
+                    untransformed_name
+                )
+            return st.deterministics

-            return vectorized_logpfn
+        def vectorize_function(function):
+            def vectorizedfn(*q_samples):
+                return tf.vectorized_map(lambda samples: function(*samples), q_samples)

-        return vectorize_logp_function(logpfn)
+            return vectorizedfn
+
+        return vectorize_function(logpfn), vectorize_function(deterministics_callback)

     def _build_posterior(self):
         raise NotImplementedError
@@ -69,17 +81,7 @@ def sample(self, n: int = 500) -> az.InferenceData:
         """Generate samples from posterior distribution."""
         samples = self.approx.sample(n)
         q_samples = self.order.split_samples(samples, n)
-
-        # TODO - Account for deterministics as well.
-        # For all transformed_variables, apply inverse of bijector to sampled values to match support in constraint space.
-        _, st = flow.evaluate_model(self.model)
-        for transformed_name in self.state.transformed_values:
-            untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
-            transform = st.distributions[untransformed_name].transform
-            if transform.JacobianPreference == JacobianPreference.Forward:
-                q_samples[untransformed_name] = transform.forward(q_samples[transformed_name])
-            else:
-                q_samples[untransformed_name] = transform.inverse(q_samples[transformed_name])
+        q_samples = dict(**q_samples, **self.deterministics_callback(q_samples))

         # Add a new axis so as n_chains=1 for InferenceData: handles shape issues
         trace = {k: v.numpy()[np.newaxis] for k, v in q_samples.items()}
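
For reviewers, a minimal sketch of how the change surfaces to users. This is a hypothetical example, not part of the diff: it assumes the `pm.fit` entry point and the `approximation` field of its return value, both defined elsewhere in this module, and the model/distribution names are made up.

    import numpy as np
    import pymc4 as pm

    observed = np.random.randn(100).astype("float32")

    @pm.model
    def toy_model():
        # `scale` is constrained positive, so ADVI optimizes the transformed
        # variable (e.g. "toy_model/__log_scale"); the constrained value is
        # now recovered through `deterministics_callback` instead of the
        # manual bijector loop this diff removes from `sample()`.
        scale = yield pm.HalfNormal("scale", 1.0)
        yield pm.Normal("x", 0.0, scale, observed=observed)

    advi = pm.fit(toy_model(), num_steps=5000)  # assumed signature
    trace = advi.approximation.sample(1000)     # az.InferenceData, n_chains=1

With the old code path, `sample()` inverted bijectors itself and skipped deterministics entirely (the removed TODO). Routing samples through `evaluate_model_transformed` puts both the constrained values and the model's deterministics into the returned trace via one vectorized callback.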