update
mraveri committed Nov 14, 2024
1 parent 39d7223 commit 9e296ae
Showing 1 changed file with 19 additions and 1 deletion.
tensiometer/synthetic_probability/synthetic_probability.py (20 changes: 19 additions & 1 deletion)
@@ -494,8 +494,12 @@ def _init_trainable_bijector(self, trainable_bijector, trainable_bijector_path=N
         # load from file:
         if trainable_bijector_path is not None:
             if self.trainable_transformation is not None:
+                if self.feedback > 1:
+                    print(' - loading trainable bijector from file:', trainable_bijector_path)
                 self.trainable_transformation = self.trainable_transformation.load(trainable_bijector_path,
                                                                                     **kwargs)
+            else:
+                raise ValueError('Cannot load a bijector from file if the trainable bijector is not a TrainableTransformation')
 
         # initialize bijector:
         if self.trainable_transformation is not None:
@@ -1029,11 +1033,23 @@ def MCSamples(self, size, logLikes=True, **kwargs):
         :param size: number of samples
         :param logLikes: logical, whether to include log-likelihoods or not.
         """
+        # sample:
         samples = self.sample(size)
+        finite_filter = tf.reduce_all(tf.math.is_finite(samples), axis=-1)
         if logLikes:
             loglikes = -self.log_probability(samples)
+            finite_filter = tf.math.logical_and(finite_filter, tf.math.is_finite(loglikes))
         else:
             loglikes = None
+        # filter out non-finite values:
+        if not np.all(finite_filter):
+            samples = samples[finite_filter]
+            if loglikes is not None:
+                loglikes = loglikes[finite_filter]
+            # feedback:
+            if self.feedback > 0:
+                print(' - found non-finite values, filtering out {0} samples'.format(size - len(samples)))
+        # create MCSamples object:
         mc_samples = MCSamples(
             samples=samples.numpy(),
             loglikes=loglikes.numpy(),
@@ -2353,6 +2369,8 @@ def training_plot(self, logs=None, file_path=None, ipython_plotting=False):
                 file_path=_temp_name,
                 ipython_plotting=ipython_plotting,
                 title='Training flow '+str(_i))
+            if ipython_plotting:
+                plt.show()
         #
         return None

@@ -2574,7 +2592,7 @@ def average_flow_from_chain(chain, num_flows=1, cache_dir=None, root_name='sprob
         if use_mpi and size > 1:
             pass
         else:
-            if feedback > 1:
+            if feedback > 0:
                print('Loading flow', i, 'from cache', flush=use_mpi)
         flow = FlowCallback.load(chain, _outroot, **kwargs)
         flows.append(flow)
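
Note on the MCSamples hunk above: before the MCSamples object is built, draws with non-finite coordinates (and, when log-likelihoods are requested, non-finite log-likelihoods) are now dropped. Below is a minimal standalone sketch of that filtering pattern, using plain TensorFlow and NumPy on toy tensors; the data values are hypothetical and this is not tensiometer's actual implementation.

import numpy as np
import tensorflow as tf

# Toy stand-ins for flow samples and their log-likelihoods (hypothetical
# values only): one row has an infinite coordinate, one a NaN log-likelihood.
samples = tf.constant([[0.1, 0.2], [np.inf, 0.3], [0.4, 0.5], [0.6, 0.7]])
loglikes = tf.constant([1.0, 2.0, np.nan, 4.0])

# keep rows whose parameter values are all finite:
finite_filter = tf.reduce_all(tf.math.is_finite(samples), axis=-1)
# ... and whose log-likelihood is finite as well:
finite_filter = tf.math.logical_and(finite_filter, tf.math.is_finite(loglikes))

# apply the same boolean mask to both tensors so they stay aligned:
if not np.all(finite_filter.numpy()):
    samples = tf.boolean_mask(samples, finite_filter)
    loglikes = tf.boolean_mask(loglikes, finite_filter)

print(samples.numpy())   # only the two fully finite rows remain
print(loglikes.numpy())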
