Skip to content

Commit

Permalink
Numpyro quasiposterior for the quasimultinomial model (#32)
Browse files Browse the repository at this point in the history
* numpyro quasiposterior

* Attempt fixing _isscalar() function

* Expand the notebook to show how to use quasiposterior

* Remove unsupported notebook.

---------

Co-authored-by: dr-david <[email protected]>
  • Loading branch information
pawel-czyz and dr-david authored Nov 22, 2024
1 parent f739fea commit c3fe7e0
Show file tree
Hide file tree
Showing 2 changed files with 270 additions and 1 deletion.
263 changes: 263 additions & 0 deletions examples/frequentist_notebook_jax.py
Original file line number Diff line number Diff line change
Expand Up @@ -288,3 +288,266 @@ def format_date(x, pos):


figure_spec.map(plot_city, range(len(cities)))
# -

# ## Quasiposterior modelling
#
# Above we fitted the model using the maximum quasilikelihood approach, and then constructed approximate confidence intervals basing on the assumed covariance matrix structure and adjusting it by the estimated overdispersion factor.
# There exists also another method of quantifying uncertainty, which is based on generalized Bayesian paradigm, where the likelihood is replaced by the quasilikelihood.
#
# These methods of quantifying uncertainty do not have to be necessarily compatible and may reveal that the quasiposterior on growth advantage estimates is e.g., not symmetric.
#
# In fact, we attempt to use separate overdispersion for each city. Let's compare both approaches.

# +
import arviz as az
from numpyro.infer import MCMC, NUTS
from functools import partial


def sample_from_model(share_overdispersion: bool):
if share_overdispersion:
_overdispersion = overdispersion_tuple.overall
else:
_overdispersion = overdispersion_tuple.cities

model = qm.construct_model(
ys=ys_effective,
ts=ts_lst_scaled,
overdispersion=_overdispersion,
sigma_offset=100.0,
)

mcmc = MCMC(NUTS(model), num_chains=4, num_samples=2000, num_warmup=2000)
mcmc.run(jax.random.PRNGKey(42))
return mcmc


mcmc_shared = sample_from_model(share_overdispersion=True)
mcmc_indivi = sample_from_model(share_overdispersion=False)
# -

# Before we proceed with the analysis of the quasiposteriors, let's see if we can trust the obtained samples.
#
# **Shared overdispersion**

idata = az.from_numpyro(mcmc_shared)
az.summary(idata, filter_vars="regex", var_names="^r.*")

az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
plt.tight_layout()
plt.show()

# **Individual overdispersion parameters**

idata = az.from_numpyro(mcmc_indivi)
az.summary(idata, filter_vars="regex", var_names="^r.*")

az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
plt.tight_layout()
plt.show()

# If we do not see sampling problems, we can try to understand the quasiposterior distributions.
#
# Let's compare both quasiposteriors additionally with the confidence intervals.

# +
from subplots_from_axsize import subplots_from_axsize


def plot_posterior(ax, i, mcmc):
max_quasilikelihood = qm.get_relative_growths(
theta_star, n_variants=n_variants_effective
)
lower = qm.get_relative_growths(
confints_estimates[0], n_variants=n_variants_effective
)
upper = qm.get_relative_growths(
confints_estimates[1], n_variants=n_variants_effective
)

# Plot maximum quasilikelihood and confidence interval bands
ax.axvline(max_quasilikelihood[i], c="k")
ax.axvspan(lower[i], upper[i], alpha=0.3, facecolor="k", edgecolor=None)

# Plot quasiposterior samples using a histogram
samples = mcmc.get_samples()["relative_growths"][:, i]
ax.hist(samples, bins=40, color="maroon")

# Plot the credible interval calculated using quantiles
credibility = 0.95
_a = (1 - credibility) / 2.0
ax.axvline(jnp.quantile(samples, q=_a), c="maroon", linestyle=":")
ax.axvline(jnp.quantile(samples, q=1.0 - _a), c="maroon", linestyle=":")

# Apply some styling
ax.spines[["left", "right", "top"]].set_visible(False)
ax.set_yticks([])


fig, axs = subplots_from_axsize(
ncols=n_variants_effective - 1,
axsize=(2, 0.8),
nrows=2,
sharex="col",
hspace=0.25,
dpi=400,
)

for i in range(n_variants_effective - 1):
plot_posterior(axs[0, i], i, mcmc_shared)
plot_posterior(axs[1, i], i, mcmc_indivi)

axs[0, 0].set_ylabel("Shared")
axs[1, 0].set_ylabel("Individual")

for i, variant in enumerate(variants_effective[1:]):
axs[0, i].set_title(f"Advantage of {variant}")


# -

# We see two things:
#
# - Quasiposterior employing shared overdispersion gives similar results to the ones obtained with confidence intervals.
# - When we use individual overdispersion factors (one per city), we see a discrepancy.
#
# Let's compare the predictive plots between two quasiposteriors and the confidence bands obtained earlier.


# +
def plot_predictions(
ax,
i: int,
*,
fitted_line,
fitted_lower,
fitted_upper,
predicted_line,
predicted_lower,
predicted_upper,
) -> None:
def remove_0th(arr):
"""We don't plot the artificial 0th variant 'other'."""
return arr[:, 1:]

# Plot fits in observed and unobserved time intervals.
plot_ts.plot_fit(ax, ts_lst[i], remove_0th(fitted_line[i]), colors=colors)
plot_ts.plot_fit(
ax, ts_pred_lst[i], remove_0th(predicted_line[i]), colors=colors, linestyle="--"
)
plot_ts.plot_confidence_bands(
ax,
ts_lst[i],
(remove_0th(fitted_lower[i]), remove_0th(fitted_upper[i])),
colors=colors,
)
plot_ts.plot_confidence_bands(
ax,
ts_pred_lst[i],
(remove_0th(predicted_lower[i]), remove_0th(predicted_upper[i])),
colors=colors,
)

# Plot the data points
plot_ts.plot_data(ax, ts_lst[i], remove_0th(ys_effective[i]), colors=colors)

# Plot the complements
plot_ts.plot_complement(ax, ts_lst[i], remove_0th(fitted_line[i]), alpha=0.3)
plot_ts.plot_complement(
ax, ts_pred_lst[i], remove_0th(predicted_line[i]), linestyle="--", alpha=0.3
)

# format axes and title
def format_date(x, pos):
return plot_ts.num_to_date(x, date_min=start_date)

date_formatter = ticker.FuncFormatter(format_date)
ax.xaxis.set_major_formatter(date_formatter)
tick_positions = [0, 0.5, 1]
tick_labels = ["0%", "50%", "100%"]
ax.set_yticks(tick_positions)
ax.set_yticklabels(tick_labels)


fig, axs = subplots_from_axsize(
ncols=3,
axsize=(2, 0.8),
nrows=len(cities),
sharex=True,
sharey=True,
hspace=0.4,
dpi=400,
)

for i, city in enumerate(cities):
axs[i, 0].set_ylabel(city)

for ax, name in zip(
axs[0, :], ["Confidence", "Credible (shared)", "Credible (individual)"]
):
ax.set_title(name)


# Plot the quasilikelihood fits
for i, ax in enumerate(axs[:, 0]):
plot_predictions(
ax,
i,
fitted_line=ys_fitted,
fitted_lower=[y.lower for y in ys_fitted_confint],
fitted_upper=[y.upper for y in ys_fitted_confint],
predicted_line=ys_pred,
predicted_lower=[y.lower for y in ys_pred_confint],
predicted_upper=[y.upper for y in ys_pred_confint],
)


# Plot the quasiposterior with shared MCMC


def obtain_predictions(mcmc, _a=0.05):
def get_fit(sample):
theta = qm.construct_theta(
relative_growths=sample["relative_growths"],
relative_midpoints=sample["relative_offsets"],
)

y_fit = qm.fitted_values(
ts_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
)
y_pre = qm.fitted_values(
ts_pred_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
)
return y_fit, y_pre

def get_line(ys):
return jnp.mean(ys, axis=0)

def get_lower(ys):
return jnp.quantile(ys, q=_a / 2, axis=0)

def get_upper(ys):
return jnp.quantile(ys, q=1 - _a / 2, axis=0)

# Apply some thinning for computational speedup
samples = jax.tree.map(lambda x: x[::10, ...], mcmc.get_samples())

fits, preds = jax.vmap(get_fit)(samples)
return dict(
fitted_line=jax.tree.map(get_line, fits),
fitted_lower=jax.tree.map(get_lower, fits),
fitted_upper=jax.tree.map(get_upper, fits),
predicted_line=jax.tree.map(get_line, preds),
predicted_lower=jax.tree.map(get_lower, preds),
predicted_upper=jax.tree.map(get_upper, preds),
)


for i, ax in enumerate(axs[:, 1]):
plot_predictions(ax, i, **obtain_predictions(mcmc_shared))

# Plot individual overdispersions
for i, ax in enumerate(axs[:, 2]):
plot_predictions(ax, i, **obtain_predictions(mcmc_indivi))
# -
8 changes: 7 additions & 1 deletion src/covvfit/_padding.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,13 @@


def _is_scalar(value) -> bool:
return not hasattr(value, "__len__")
try:
length = len(value)
if length != 0:
return False
return True
except TypeError:
return True


def create_padded_array(
Expand Down

0 comments on commit c3fe7e0

Please sign in to comment.