Skip to content

Commit

Permalink
Added deterministics callback
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayam753 committed Jul 17, 2020
1 parent ae4edf1 commit c80d3a7
Showing 1 changed file with 22 additions and 20 deletions.
42 changes: 22 additions & 20 deletions pymc4/variational/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,6 @@

from pymc4 import flow
from pymc4.coroutine_model import Model
from pymc4.distributions.transforms import JacobianPreference
from pymc4.inference.utils import initialize_sampling_state
from pymc4.utils import NameParts
from pymc4.variational import updates
Expand Down Expand Up @@ -42,25 +41,38 @@ def __init__(self, model: Optional[Model] = None, random_seed: Optional[int] = N

self.order = ArrayOrdering(self.state.all_unobserved_values)
self.unobserved_keys = self.state.all_unobserved_values.keys()
self.target_log_prob = self._build_logfn()
self.target_log_prob, self.deterministics_callback = self._build_logp_and_deterministic_fn()
self.approx = self._build_posterior()

def _build_logfn(self):
"""Build vectorized logp function."""
def _build_logp_and_deterministic_fn(self):
"""Build vectorized logp and deterministic functions."""

@tf.function(autograph=False)
def logpfn(*values):
split_view = self.order.split(values[0])
_, st = flow.evaluate_meta_model(self.model, values=split_view)
return st.collect_log_prob()

def vectorize_logp_function(logpfn):
def vectorized_logpfn(*q_samples):
return tf.vectorized_map(lambda samples: logpfn(*samples), q_samples)
@tf.function(autograph=False)
def deterministics_callback(q_samples):
st = flow.SamplingState.from_values(
q_samples, observed_values=self.state.observed_values
)
_, st = flow.evaluate_model_transformed(self.model, state=st)
for transformed_name in st.transformed_values:
untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
st.deterministics[untransformed_name] = st.untransformed_values.pop(
untransformed_name
)
return st.deterministics

return vectorized_logpfn
def vectorize_function(function):
def vectorizedfn(*q_samples):
return tf.vectorized_map(lambda samples: function(*samples), q_samples)

return vectorize_logp_function(logpfn)
return vectorizedfn

return vectorize_function(logpfn), vectorize_function(deterministics_callback)

def _build_posterior(self):
raise NotImplementedError
Expand All @@ -69,17 +81,7 @@ def sample(self, n: int = 500) -> az.InferenceData:
"""Generate samples from posterior distribution."""
samples = self.approx.sample(n)
q_samples = self.order.split_samples(samples, n)

# TODO - Account for deterministics as well.
# For all transformed_variables, apply inverse of bijector to sampled values to match support in constraint space.
_, st = flow.evaluate_model(self.model)
for transformed_name in self.state.transformed_values:
untransformed_name = NameParts.from_name(transformed_name).full_untransformed_name
transform = st.distributions[untransformed_name].transform
if transform.JacobianPreference == JacobianPreference.Forward:
q_samples[untransformed_name] = transform.forward(q_samples[transformed_name])
else:
q_samples[untransformed_name] = transform.inverse(q_samples[transformed_name])
q_samples = dict(**q_samples, **self.deterministics_callback(q_samples))

# Add a new axis so as n_chains=1 for InferenceData: handles shape issues
trace = {k: v.numpy()[np.newaxis] for k, v in q_samples.items()}
Expand Down

0 comments on commit c80d3a7

Please sign in to comment.