Skip to content

Commit

Permalink
runs through sbc stuff, still need to do 1:1 plots
Browse files Browse the repository at this point in the history
  • Loading branch information
beckynevin committed Jan 29, 2024
1 parent ee0a7dd commit 6e9eb60
Showing 1 changed file with 8 additions and 9 deletions.
17 changes: 8 additions & 9 deletions src/scripts/evaluate.py
Original file line number Diff line number Diff line change
Expand Up @@ -235,9 +235,12 @@ def calculate_coverage_fraction(self,
all_samples = np.empty((len(ys), samples_per_inference, np.shape(thetas)[1]))
count_array = []
# make this for loop into a progress bar:
for i in tqdm(range(len(ys)), desc='Processing observations', unit='obs'):
for i in tqdm(range(len(ys)), desc='Sampling from the posterior for each obs', unit='obs'):
#for i in range(len(ys)):
# sample from the trained posterior n_sample times for each observation
samples = posterior.sample(sample_shape=(samples_per_inference,), x=ys[i]).cpu()
samples = posterior.sample(sample_shape=(samples_per_inference,),
x=ys[i],
show_progress_bars=False).cpu()

'''
# plot posterior samples
Expand Down Expand Up @@ -394,11 +397,7 @@ def run_all_sbc(self,
save=save,
path=path)






def parameter_1_to_1_plots(samples,):
'''
We've already saved samples, let's compare the inferred (and associated error bar) parameters from each of the data points we used for the SBC analysis.
'''
Expand Down Expand Up @@ -428,11 +427,11 @@ def run_all_sbc(self,
yerr_minus[idx] = 0
yerr_plus[idx] = 0

plt.errorbar(np.array(thetas[:,0]),
plt.errorbar(np.array(thetas[:,i]),
percentile_50_m,
yerr = [yerr_minus, yerr_plus],
linestyle = 'None',
color = m_color,
color = color_list[i],
capsize = 5)
plt.scatter(np.array(thetas[:,0]), percentile_50_m, color = m_color)
plt.plot(percentile_50_m, percentile_50_m, color = 'k')
Expand Down

0 comments on commit 6e9eb60

Please sign in to comment.