Skip to content

Commit

Permalink
Changed dtype_hint to tf.float64 for numerical stability
Browse files Browse the repository at this point in the history
  • Loading branch information
Sayam753 committed Jul 27, 2020
1 parent 109eda8 commit 5137396
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pymc4/variational/approximations.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,7 @@ class MeanField(Approximation):
def _build_posterior(self):
flattened_shape = self.order.size
dtype = dtype_util.common_dtype(
self.state.all_unobserved_values.values(), dtype_hint=tf.float32
self.state.all_unobserved_values.values(), dtype_hint=tf.float64
)
loc = tf.Variable(tf.random.normal([flattened_shape], dtype=dtype), name="mu")
cov_param = tfp.util.TransformedVariable(
Expand Down

0 comments on commit 5137396

Please sign in to comment.