From dfed9ebb3feab6eaa5ea4c00069b8c5916de58e6 Mon Sep 17 00:00:00 2001 From: mraveri Date: Thu, 26 Sep 2024 05:30:30 -0700 Subject: [PATCH] update --- tensiometer/mcmc_tension/flow.py | 8 ++++---- .../synthetic_probability/synthetic_probability.py | 1 + 2 files changed, 5 insertions(+), 4 deletions(-) diff --git a/tensiometer/mcmc_tension/flow.py b/tensiometer/mcmc_tension/flow.py index a9121c3..329868e 100644 --- a/tensiometer/mcmc_tension/flow.py +++ b/tensiometer/mcmc_tension/flow.py @@ -38,9 +38,9 @@ def estimate_shift(flow, prior_flow=None, tol=0.05, max_iter=1000, step=100000): counter = max_iter # define threshold for tension calculation: - _thres = flow.log_probability(flow.cast(np.zeros(flow.num_params))) + _thres = flow.log_probability(flow.cast([np.zeros(flow.num_params)]))[0] if prior_flow is not None: - _thres = _thres - prior_flow.log_probability(prior_flow.cast(np.zeros(prior_flow.num_params))) + _thres = _thres - prior_flow.log_probability(prior_flow.cast([np.zeros(prior_flow.num_params)]))[0] _num_filtered = 0 _num_samples = 0 @@ -89,9 +89,9 @@ def estimate_shift_from_samples(flow, prior_flow=None): """ # define threshold for tension calculation: - _thres = flow.log_probability(flow.cast(np.zeros(flow.num_params))) + _thres = flow.log_probability(flow.cast([np.zeros(flow.num_params)]))[0] if prior_flow is not None: - _thres = _thres - prior_flow.log_probability(prior_flow.cast(np.zeros(prior_flow.num_params))) + _thres = _thres - prior_flow.log_probability(prior_flow.cast([np.zeros(prior_flow.num_params)]))[0] # calculate probability on the samples: _s_prob = flow.log_probability(flow.cast(flow.chain_samples)) diff --git a/tensiometer/synthetic_probability/synthetic_probability.py b/tensiometer/synthetic_probability/synthetic_probability.py index 559245c..1146f12 100644 --- a/tensiometer/synthetic_probability/synthetic_probability.py +++ b/tensiometer/synthetic_probability/synthetic_probability.py @@ -2256,6 +2256,7 @@ def __init__(self, flows, **kwargs): 'param_names', 'param_labels', 'parameter_ranges', + 'periodic_params', 'chain_samples', 'chain_loglikes', 'has_loglikes',