Skip to content

Commit

Permalink
cdf plot
Browse files Browse the repository at this point in the history
cdf plot
  • Loading branch information
mjhajharia committed Jul 27, 2022
1 parent 974b8e5 commit ef1ef4e
Show file tree
Hide file tree
Showing 5 changed files with 48 additions and 5 deletions.
Binary file added figures/simplex/ess_cdf.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
File renamed without changes
13 changes: 10 additions & 3 deletions utils/ess.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@ def get_ess_leapfrog_ratio(
var_name,
var_dim,
n_repeat=100,
plot_type='density'
):
x = []
idata = sample(
Expand All @@ -38,7 +39,13 @@ def get_ess_leapfrog_ratio(
ess = np.loadtxt(open(f'/mnt/sdceph/users/mjhajaria/sampling_results/{transform_category}/{transform}/{evaluating_model}/ess_{param_map[tuple(list(params.values())[0])]}_{n_repeat}.csv'),delimiter = ",")
leapfrog = np.average(idata.sample_stats['n_steps'].sum(axis=1).values.reshape(-1, 4), axis=1)
x=np.divide(ess, leapfrog)
kde = gaussian_kde(x)
dist_space = np.linspace(min(x), max(x), 1000)
return dist_space, kde(dist_space)
if plot_type=='density':
kde = gaussian_kde(x)
dist_space = np.linspace(min(x), max(x), 1000)
return dist_space, kde(dist_space)
if plot_type=='cdf':
count, bins_count = np.histogram(x, bins=10)
pdf = count / sum(count)
cdf = np.cumsum(pdf)
return cdf, bins_count[1:]

36 changes: 36 additions & 0 deletions utils/ess_cdf.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,36 @@
import os
os.chdir('..')
import sys
sys.path.insert(1, 'utils')
from ess import get_ess_leapfrog_ratio

import pickle
import numpy as np
import matplotlib.pyplot as plt
import arviz as az



parameters = [{'alpha':[0.1]*10, 'N':10}, {'alpha':[0.1]*100, 'N':100}, {'alpha': [0.1]*1000, 'N': 1000},
{'alpha':[1]*10, 'N':10}, {'alpha':[1]*100, 'N':100}, {'alpha': [1]*1000, 'N': 1000},
{'alpha':[10]*10, 'N':10}, {'alpha':[10]*100, 'N':100}, {'alpha': [1]*1000, 'N': 1000}]

transforms = ['stickbreaking', 'softmax', 'softmax-augmented', 'stan']
transform_category='simplex'
evaluating_model='dirichlet_symmetric'

var_name='x'
var_dim=0

plt.rcParams["figure.figsize"] = (20,10)
fig, axes = plt.subplots(3,3)
for ax, params in zip(axes.flatten() if len(parameters)>1 else [axes], parameters):
for transform in transforms:
x, y = get_ess_leapfrog_ratio(transform_category, transform, evaluating_model, params, var_name, var_dim, n_repeat=100, plot_type='cdf')
ax.plot(x,y, label=transform)
ax.set_title(f'alpha = {params["alpha"][0]}, N = {params["N"]}')
ax.axes.yaxis.set_ticklabels([])
fig.supxlabel('ESS/Leapfrog')
fig.supylabel('Cumulative Density')
plt.legend()
plt.savefig('figures/simplex/ess_cdf.png', dpi=300)
4 changes: 2 additions & 2 deletions utils/ess_plot_script.py → utils/ess_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,11 +26,11 @@
fig, axes = plt.subplots(3,3)
for ax, params in zip(axes.flatten() if len(parameters)>1 else [axes], parameters):
for transform in transforms:
x, y = get_ess_leapfrog_ratio(transform_category, transform, evaluating_model, params, var_name, var_dim, n_repeat=100)
x, y = get_ess_leapfrog_ratio(transform_category, transform, evaluating_model, params, var_name, var_dim, n_repeat=100, plot_type='density')
ax.plot(x,y, label=transform)
ax.set_title(f'alpha = {params["alpha"][0]}, N = {params["N"]}')
ax.axes.yaxis.set_ticklabels([])
fig.supxlabel('ESS/Leapfrog')
fig.supylabel('Density')
plt.legend()
plt.savefig('figures/simplex/ess.png', dpi=300)
plt.savefig('figures/simplex/ess_density.png', dpi=300)

0 comments on commit ef1ef4e

Please sign in to comment.