diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index da60c05..a7eaa35 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -288,3 +288,266 @@ def format_date(x, pos):
 figure_spec.map(plot_city, range(len(cities)))
+# -
+
+# ## Quasiposterior modelling
+#
+# Above we fitted the model using the maximum quasilikelihood approach and then
+# constructed approximate confidence intervals based on the assumed covariance
+# matrix structure, adjusted by the estimated overdispersion factor.
+# There is also another way of quantifying uncertainty, based on the generalized
+# Bayesian paradigm, in which the likelihood is replaced by the quasilikelihood.
+#
+# These two ways of quantifying uncertainty do not have to be compatible and may
+# reveal, for example, that the quasiposterior on the growth advantage estimates
+# is not symmetric.
+#
+# Additionally, we attempt to use a separate overdispersion factor for each
+# city. Let's compare both approaches.
+
+# +
+import arviz as az
+from numpyro.infer import MCMC, NUTS
+
+
+def sample_from_model(share_overdispersion: bool):
+    # Use either one overdispersion factor shared by all cities
+    # or a separate factor for each city.
+    if share_overdispersion:
+        _overdispersion = overdispersion_tuple.overall
+    else:
+        _overdispersion = overdispersion_tuple.cities
+
+    model = qm.construct_model(
+        ys=ys_effective,
+        ts=ts_lst_scaled,
+        overdispersion=_overdispersion,
+        sigma_offset=100.0,
+    )
+
+    # Run four chains so that convergence diagnostics (R-hat) are meaningful.
+    mcmc = MCMC(NUTS(model), num_chains=4, num_samples=2000, num_warmup=2000)
+    mcmc.run(jax.random.PRNGKey(42))
+    return mcmc
+
+
+mcmc_shared = sample_from_model(share_overdispersion=True)
+mcmc_indivi = sample_from_model(share_overdispersion=False)
+# -
+
+# Before we proceed with the analysis of the quasiposteriors, let's check
+# whether we can trust the obtained samples.
+#
+# **Shared overdispersion**
+
+idata = az.from_numpyro(mcmc_shared)
+az.summary(idata, filter_vars="regex", var_names="^r.*")
+
+az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
+plt.tight_layout()
+plt.show()
+
+# **Individual overdispersion parameters**
+
+idata = az.from_numpyro(mcmc_indivi)
+az.summary(idata, filter_vars="regex", var_names="^r.*")
+
+az.plot_trace(idata, filter_vars="regex", var_names="^r.*")
+plt.tight_layout()
+plt.show()
+
+# If we do not see any sampling problems, we can try to understand the
+# quasiposterior distributions.
+#
+# Let's additionally compare both quasiposteriors with the confidence intervals.
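+#
+# Before the visual comparison, here is a quick numeric sketch (reusing
+# `mcmc_shared`, `confints_estimates`, and the `qm` helpers defined above;
+# the 95% level is chosen to match the earlier confidence intervals): we print
+# quantile-based credible intervals next to the confidence intervals.
+
+# +
+growth_samples = mcmc_shared.get_samples()["relative_growths"]
+credible_lower = jnp.quantile(growth_samples, q=0.025, axis=0)
+credible_upper = jnp.quantile(growth_samples, q=0.975, axis=0)
+confidence_lower = qm.get_relative_growths(
+    confints_estimates[0], n_variants=n_variants_effective
+)
+confidence_upper = qm.get_relative_growths(
+    confints_estimates[1], n_variants=n_variants_effective
+)
+for i, variant in enumerate(variants_effective[1:]):
+    print(
+        f"{variant}: credible [{credible_lower[i]:.3f}, {credible_upper[i]:.3f}] "
+        f"vs confidence [{confidence_lower[i]:.3f}, {confidence_upper[i]:.3f}]"
+    )
+# -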
+
+# +
+from subplots_from_axsize import subplots_from_axsize
+
+
+def plot_posterior(ax, i, mcmc):
+    max_quasilikelihood = qm.get_relative_growths(
+        theta_star, n_variants=n_variants_effective
+    )
+    lower = qm.get_relative_growths(
+        confints_estimates[0], n_variants=n_variants_effective
+    )
+    upper = qm.get_relative_growths(
+        confints_estimates[1], n_variants=n_variants_effective
+    )
+
+    # Plot the maximum quasilikelihood estimate and the confidence interval band
+    ax.axvline(max_quasilikelihood[i], c="k")
+    ax.axvspan(lower[i], upper[i], alpha=0.3, facecolor="k", edgecolor=None)
+
+    # Plot the quasiposterior samples using a histogram
+    samples = mcmc.get_samples()["relative_growths"][:, i]
+    ax.hist(samples, bins=40, color="maroon")
+
+    # Plot the credible interval calculated using quantiles
+    credibility = 0.95
+    _a = (1 - credibility) / 2.0
+    ax.axvline(jnp.quantile(samples, q=_a), c="maroon", linestyle=":")
+    ax.axvline(jnp.quantile(samples, q=1.0 - _a), c="maroon", linestyle=":")
+
+    # Apply some styling
+    ax.spines[["left", "right", "top"]].set_visible(False)
+    ax.set_yticks([])
+
+
+fig, axs = subplots_from_axsize(
+    ncols=n_variants_effective - 1,
+    axsize=(2, 0.8),
+    nrows=2,
+    sharex="col",
+    hspace=0.25,
+    dpi=400,
+)
+
+for i in range(n_variants_effective - 1):
+    plot_posterior(axs[0, i], i, mcmc_shared)
+    plot_posterior(axs[1, i], i, mcmc_indivi)
+
+axs[0, 0].set_ylabel("Shared")
+axs[1, 0].set_ylabel("Individual")
+
+for i, variant in enumerate(variants_effective[1:]):
+    axs[0, i].set_title(f"Advantage of {variant}")
+
+
+# -
+
+# We see two things:
+#
+# - The quasiposterior employing the shared overdispersion factor gives results
+#   similar to those obtained with the confidence intervals.
+# - When we use individual overdispersion factors (one per city), we see a
+#   discrepancy.
+#
+# Let's compare the predictive plots from the two quasiposteriors with the
+# confidence bands obtained earlier.
+
+
+# +
+def plot_predictions(
+    ax,
+    i: int,
+    *,
+    fitted_line,
+    fitted_lower,
+    fitted_upper,
+    predicted_line,
+    predicted_lower,
+    predicted_upper,
+) -> None:
+    def remove_0th(arr):
+        """We don't plot the artificial 0th variant 'other'."""
+        return arr[:, 1:]
+
+    # Plot fits in the observed and unobserved time intervals.
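+    # Solid lines cover the observed window (`ts_lst`); dashed lines extend the
+    # same fit over the prediction window (`ts_pred_lst`).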
+    plot_ts.plot_fit(ax, ts_lst[i], remove_0th(fitted_line[i]), colors=colors)
+    plot_ts.plot_fit(
+        ax, ts_pred_lst[i], remove_0th(predicted_line[i]), colors=colors, linestyle="--"
+    )
+    plot_ts.plot_confidence_bands(
+        ax,
+        ts_lst[i],
+        (remove_0th(fitted_lower[i]), remove_0th(fitted_upper[i])),
+        colors=colors,
+    )
+    plot_ts.plot_confidence_bands(
+        ax,
+        ts_pred_lst[i],
+        (remove_0th(predicted_lower[i]), remove_0th(predicted_upper[i])),
+        colors=colors,
+    )
+
+    # Plot the data points
+    plot_ts.plot_data(ax, ts_lst[i], remove_0th(ys_effective[i]), colors=colors)
+
+    # Plot the complements
+    plot_ts.plot_complement(ax, ts_lst[i], remove_0th(fitted_line[i]), alpha=0.3)
+    plot_ts.plot_complement(
+        ax, ts_pred_lst[i], remove_0th(predicted_line[i]), linestyle="--", alpha=0.3
+    )
+
+    # Format the axes and ticks
+    def format_date(x, pos):
+        return plot_ts.num_to_date(x, date_min=start_date)
+
+    date_formatter = ticker.FuncFormatter(format_date)
+    ax.xaxis.set_major_formatter(date_formatter)
+    tick_positions = [0, 0.5, 1]
+    tick_labels = ["0%", "50%", "100%"]
+    ax.set_yticks(tick_positions)
+    ax.set_yticklabels(tick_labels)
+
+
+fig, axs = subplots_from_axsize(
+    ncols=3,
+    axsize=(2, 0.8),
+    nrows=len(cities),
+    sharex=True,
+    sharey=True,
+    hspace=0.4,
+    dpi=400,
+)
+
+for i, city in enumerate(cities):
+    axs[i, 0].set_ylabel(city)
+
+for ax, name in zip(
+    axs[0, :], ["Confidence", "Credible (shared)", "Credible (individual)"]
+):
+    ax.set_title(name)
+
+
+# Plot the maximum quasilikelihood fits with confidence bands
+for i, ax in enumerate(axs[:, 0]):
+    plot_predictions(
+        ax,
+        i,
+        fitted_line=ys_fitted,
+        fitted_lower=[y.lower for y in ys_fitted_confint],
+        fitted_upper=[y.upper for y in ys_fitted_confint],
+        predicted_line=ys_pred,
+        predicted_lower=[y.lower for y in ys_pred_confint],
+        predicted_upper=[y.upper for y in ys_pred_confint],
+    )
+
+
+# Plot the quasiposterior with the shared overdispersion factor
+
+
+def obtain_predictions(mcmc, _a=0.05):
+    def get_fit(sample):
+        theta = qm.construct_theta(
+            relative_growths=sample["relative_growths"],
+            relative_midpoints=sample["relative_offsets"],
+        )
+
+        y_fit = qm.fitted_values(
+            ts_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
+        )
+        y_pre = qm.fitted_values(
+            ts_pred_lst_scaled, theta, cities=cities, n_variants=n_variants_effective
+        )
+        return y_fit, y_pre
+
+    # The posterior mean gives the central line; quantiles give the credible bands.
+    def get_line(ys):
+        return jnp.mean(ys, axis=0)
+
+    def get_lower(ys):
+        return jnp.quantile(ys, q=_a / 2, axis=0)
+
+    def get_upper(ys):
+        return jnp.quantile(ys, q=1 - _a / 2, axis=0)
+
+    # Apply some thinning for computational speedup
+    samples = jax.tree.map(lambda x: x[::10, ...], mcmc.get_samples())
+
+    fits, preds = jax.vmap(get_fit)(samples)
+    return dict(
+        fitted_line=jax.tree.map(get_line, fits),
+        fitted_lower=jax.tree.map(get_lower, fits),
+        fitted_upper=jax.tree.map(get_upper, fits),
+        predicted_line=jax.tree.map(get_line, preds),
+        predicted_lower=jax.tree.map(get_lower, preds),
+        predicted_upper=jax.tree.map(get_upper, preds),
+    )
+
+
+for i, ax in enumerate(axs[:, 1]):
+    plot_predictions(ax, i, **obtain_predictions(mcmc_shared))
+
+# Plot the quasiposterior with individual overdispersion factors
+for i, ax in enumerate(axs[:, 2]):
+    plot_predictions(ax, i, **obtain_predictions(mcmc_indivi))
+# -
diff --git a/src/covvfit/_padding.py b/src/covvfit/_padding.py
index fac390b..081ed99 100644
--- a/src/covvfit/_padding.py
+++ b/src/covvfit/_padding.py
@@ -7,7 +7,12 @@ def _is_scalar(value) -> bool:
-    return not hasattr(value, "__len__")
+    # Objects without a length, as well as zero-length sequences, count as scalars.
+    try:
+        length = len(value)
+    except TypeError:
+        return True
+    return length == 0
 
 
 def create_padded_array(
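
For reference, an illustrative sketch of the rewritten `_is_scalar` semantics (not part of the diff): objects without `__len__`, as well as zero-length sequences, are treated as scalars.

    _is_scalar(3.0)     # True: floats have no __len__
    _is_scalar([])      # True: zero-length sequence
    _is_scalar([1, 2])  # False
    _is_scalar("ab")    # False: strings have a length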