diff --git a/examples/frequentist_fitting-JAX.py b/examples/frequentist_fitting-JAX.py
deleted file mode 100644
index 09a8891..0000000
--- a/examples/frequentist_fitting-JAX.py
+++ /dev/null
@@ -1,606 +0,0 @@
-# ---
-# jupyter:
-#   jupytext:
-#     text_representation:
-#       extension: .py
-#       format_name: percent
-#       format_version: '1.3'
-#     jupytext_version: 1.16.1
-#   kernelspec:
-#     display_name: Python 3 (ipykernel)
-#     language: python
-#     name: python3
-# ---
-
-# %%
-import covvfit._frequentist as freq
-import covvfit._frequentist_jax as fj
-import covvfit._preprocess_abundances as prec
-import covvfit.plotting._timeseries as plot_ts
-import jax
-import jax.nn as nn
-import jax.numpy as jnp
-import matplotlib.pyplot as plt
-import matplotlib.ticker as ticker
-import numpy as np
-import numpyro
-import numpyro.distributions as dist
-import pandas as pd
-import pymc as pm
-from numpyro.infer import MCMC, NUTS
-from scipy.special import expit
-
-variants_full = [
-    "B.1.1.7",
-    "B.1.351",
-    "P.1",
-    "B.1.617.2",
-    "BA.1",
-    "BA.2",
-    "BA.4",
-    "BA.5",
-    "BA.2.75",
-    "BQ.1.1",
-    "XBB.1.5",
-    "XBB.1.9",
-    "XBB.1.16",
-    "XBB.2.3",
-    "EG.5",
-    "BA.2.86",
-    "JN.1",
-]
-
-variants = ["XBB.1.5", "XBB.1.9", "XBB.1.16", "XBB.2.3", "EG.5", "BA.2.86", "JN.1"]
-
-variants_other = [i for i in variants_full if i not in variants]
-
-
-cities = [
-    "Lugano (TI)",
-    "Zürich (ZH)",
-    "Chur (GR)",
-    "Altenrhein (SG)",
-    "Laupen (BE)",
-    "Genève (GE)",
-    "Basel (BS)",
-    "Porrentruy (JU)",
-    "Lausanne (VD)",
-    "Bern (BE)",
-    "Luzern (LU)",
-    "Solothurn (SO)",
-    "Neuchâtel (NE)",
-    "Schwyz (SZ)",
-]
-
-colors_covsp = {
-    "B.1.1.7": "#D16666",
-    "B.1.351": "#FF6666",
-    "P.1": "#FFB3B3",
-    "B.1.617.1": "#66C265",
-    "B.1.617.2": "#66A366",
-    "BA.1": "#A366A3",
-    "BA.2": "#CFAFCF",
-    "BA.4": "#8a66ff",
-    "BA.5": "#585eff",
-    "BA.2.75": "#008fe0",
-    "BQ.1.1": "#ac00e0",
-    "XBB.1.9": "#bb6a33",
-    "XBB.1.5": "#ff5656",
-    "XBB.1.16": "#e99b30",
-    "XBB.2.3": "#f5e424",
-    "EG.5": "#b4e80b",
-    "BA.2.86": "#FF20E0",
-    "JN.1": "#00e9ff",  # improv
-    "undetermined": "#969696",
-}
-
-
-# %%
-DATA_PATH = "../data/robust_deconv2_noisy14.csv"
-
-data = prec.load_data(DATA_PATH)
-variants2 = ["other"] + variants
-data2 = prec.preprocess_df(data, cities, variants_full, date_min="2023-04-01")
-data2["other"] = data2[variants_other].sum(axis=1)
-data2[variants2] = data2[variants2].div(data2[variants2].sum(axis=1), axis=0)
-
-ts_lst, ys_lst = prec.make_data_list(data2, cities, variants2)
-ts_lst, ys_lst2 = prec.make_data_list(data2, cities, variants)  # TODO: To be removed?
-
-data = [fj.CityData(ts=ts, ys=ys.T, n=1) for ts, ys in zip(ts_lst, ys_lst)]
-
-n_cities = len(data)
-n_variants = len(variants2)
-
-
-# %%
-loss_loglike = fj.construct_total_loss(data, average_loss=False)
-
-
-def loss_prior(x: jnp.ndarray, mu: float = 0.15, sigma: float = 0.1) -> float:
-    # Shape and rate of gamma distribution
-    alpha = jnp.square(mu / sigma)
-    beta = mu / jnp.square(sigma)
-
-    # Return -log_prob(growths)
-    g = fj.get_relative_growths(x, n_variants=n_variants)
-    return -jnp.sum(dist.Gamma(alpha, beta).log_prob(g))
-
-
-def optim_to_param(y):
-    midpoints = fj.get_relative_midpoints(y, n_variants=n_variants)
-    unconstrained_rates = fj.get_relative_growths(y, n_variants=n_variants)
-    rates = nn.softplus(unconstrained_rates)
-    return fj.construct_theta(relative_growths=rates, relative_midpoints=midpoints)
-
-
-def loss_total(y):
-    x = optim_to_param(y)
-    return loss_loglike(x) + loss_prior(x)
-
-
-theta0 = fj.construct_theta0(n_cities=n_cities, n_variants=n_variants)
-
-solution = fj.jax_multistart_minimize(
-    loss_total,
-    theta0,
-    random_seed=1,
-    n_starts=20,
-)
-
-print(fj.get_relative_growths(optim_to_param(solution.best.x), n_variants=n_variants))
-print(solution.best.fun)
-
-
-# %%
-def model(mu: float = 0.15, sigma: float = 0.1):
-    midpoints = numpyro.sample(
-        "midpoints",
-        dist.Normal(
-            jnp.zeros_like(fj.get_relative_midpoints(theta0, n_variants=n_variants)),
-            100,
-        ),
-    )
-
-    alpha = jnp.square(mu / sigma)
-    beta = mu / jnp.square(sigma)
-    growths = numpyro.sample(
-        "growths",
-        dist.Gamma(
-            alpha
-            * jnp.ones_like(fj.get_relative_growths(theta0, n_variants=n_variants)),
-            beta,
-        ),
-    )
-
-    # growths = numpyro.sample("growths", dist.TruncatedNormal(mu * jnp.ones_like(fj.get_relative_growths(theta0, n_variants=n_variants)), sigma, low=0.01, high=1.0))
-
-    x = fj.construct_theta(relative_growths=growths, relative_midpoints=midpoints)
-    numpyro.factor("loglikelihood", -loss_loglike(x))
-
-
-mcmc = MCMC(NUTS(model), num_warmup=1_000, num_samples=1_000, num_chains=4)
-mcmc.run(jax.random.PRNGKey(0))
-
-# %%
-mcmc.print_summary()
-
-# %%
-
-rng = np.random.default_rng(42)
-
-mu = 0.15
-sigma = 0.1
-
-alpha = jnp.square(mu / sigma)
-beta = mu / jnp.square(sigma)
-
-samples = rng.gamma(alpha, 1 / beta, size=10_000)
-
-print(np.mean(samples))
-print(np.std(samples))
-
-plt.hist(samples, bins=np.linspace(0, 0.5, 50), color="salmon", alpha=0.4, density=True)
-plt.hist(
-    mcmc.get_samples()["growths"][:, -1],
-    bins=np.linspace(0, 0.5, 50),
-    color="darkblue",
-    alpha=0.4,
-    density=True,
-)
-
-
-np.mean(samples < 0.01)
-
-# %%
-## This model takes into account the complement of the variants to be monitored, and sets its fitness to zero
-## However, due to the pm.math.concatenate operation, we cannot use it for finding the hessian
-
-
-def create_model_fixed2(
-    ts_lst,
-    ys_lst,
-    n=1.0,
-    coords={
-        "cities": [],
-        "variants": [],
-    },
-    n_pred=60,
-):
-    """function to create a fixed effect model with varying intercepts and one rate vector"""
-    with pm.Model(coords=coords) as model:
-        midpoint_var = pm.Normal(
-            "midpoint", mu=0.0, sigma=300.0, dims=["cities", "variants"]
-        )
-        rate_var = pm.Gamma("rate", mu=0.15, sigma=0.1, dims="variants")
-
-        # Kaan's trick to avoid overflows
-        def softmax(x, rates, midpoints):
-            E = rates[:, None] * x + midpoints[:, None]
-            E_max = E.max(axis=0)
-            un_norm = pm.math.exp(E - E_max)
-            return un_norm / (pm.math.sum(un_norm, axis=0))
-
-        ys_smooth = [
-            softmax(
-                ts_lst[i],
-                pm.math.concatenate([[0], rate_var]),
-                pm.math.concatenate([[0], midpoint_var[i, :]]),
-            )
-            for i, city in enumerate(coords["cities"])
-        ]
-
-        # make Multinom/n likelihood
-        def log_likelihood(y, p, n):
-            # return n*pm.math.sum(y * pm.math.log(p), axis=0) + n*(1-pm.math.sum(y, axis=0))*pm.math.log(1-pm.math.sum(p, axis=0))
-            return n * pm.math.sum(y * pm.math.log(p), axis=0)
-
-        [
-            pm.DensityDist(
-                f"ys_noisy_{city}",
-                ys_smooth[i],
-                n,
-                logp=log_likelihood,
-                observed=ys_lst[i],
-            )
-            for i, city in enumerate(coords["cities"])
-        ]
-
-        return model
-
-
-# %%
-with create_model_fixed2(
-    ts_lst,
-    ys_lst,
-    coords={
-        "cities": cities,
-        "variants": variants,
-    },
-):
-    model_map_fixed = pm.find_MAP(maxeval=50000, seed=12313)
-
-
-# %%
-print(model_map_fixed["rate"])
-
-# %%
-## This model takes into account the complement of the variants to be monitored, and sets its fitness to zero
-## It has some numerical instabilities that make it not suitable for finding the MAP or MLE, but I use it for the Hessian
-
-
-def create_model_fixed3(
-    ts_lst,
-    ys_lst,
-    n=1.0,
-    coords={
-        "cities": [],
-        "variants": [],
-    },
-    n_pred=60,
-):
-    """function to create a fixed effect model with varying intercepts and one rate vector"""
-    with pm.Model(coords=coords) as model:
-        midpoint_var = pm.Normal(
-            "midpoint", mu=0.0, sigma=1500.0, dims=["cities", "variants"]
-        )
-        rate_var = pm.Gamma("rate", mu=0.15, sigma=0.1, dims="variants")
-
-        # Kaan's trick to avoid overflows
-        def softmax_1(x, rates, midpoints):
-            E = rates[:, None] * x + midpoints[:, None]
-            E_max = E.max(axis=0)
-            un_norm = pm.math.exp(E - E_max)
-            return un_norm / (pm.math.exp(-E_max) + pm.math.sum(un_norm, axis=0))
-
-        ys_smooth = [
-            softmax_1(ts_lst[i], rate_var, midpoint_var[i, :])
-            for i, city in enumerate(coords["cities"])
-        ]
-
-        # make Multinom/n likelihood
-        def log_likelihood(y, p, n):
-            return n * pm.math.sum(y * pm.math.log(p), axis=0) + n * (
-                1 - pm.math.sum(y, axis=0)
-            ) * pm.math.log(1 - pm.math.sum(p, axis=0))
-
-        # return n*pm.math.sum(y * pm.math.log(p), axis=0)
-
-        [
-            pm.DensityDist(
-                f"ys_noisy_{city}",
-                ys_smooth[i],
-                n,
-                logp=log_likelihood,
-                observed=ys_lst[i],
-            )
-            for i, city in enumerate(coords["cities"])
-        ]
-
-        return model
-
-
-# %%
-with create_model_fixed3(
-    ts_lst,
-    ys_lst2,
-    coords={
-        "cities": cities,
-        "variants": variants,
-    },
-):
-    model_hessian_fixed = pm.find_hessian(model_map_fixed)
-
-# %%
-y_fit_lst = freq.fitted_values(ts_lst, model_map_fixed, cities)
-ts_pred_lst, y_pred_lst = freq.pred_values(
-    [i.max() - 1 for i in ts_lst], model_map_fixed, cities, horizon=60
-)
-pearson_r_lst, overdisp_list, overdisp_fixed = freq.compute_overdispersion(
-    ys_lst2, y_fit_lst, cities
-)
-(
-    fitness_diff,
-    fitness_diff_se,
-    fitness_diff_lower,
-    fitness_diff_upper,
-) = freq.make_fitness_confints(
-    model_map_fixed["rate"], model_hessian_fixed, overdisp_fixed, g=7.0
-)
-
-# %% [markdown]
-# ## Plot
-
-# %%
-fig, axes_tot = plt.subplots(5, 3, figsize=(22, 15))
-# colors = default_cmap = plt.cm.get_cmap('tab10').colors
-colors = [colors_covsp[var] for var in variants]
-# axes=[axes_tot]
-axes = axes_tot.flatten()
-p_variants = len(variants)
-p_params = model_hessian_fixed.shape[0]
-model_hessian_inv = np.linalg.inv(model_hessian_fixed)
-
-for k, city in enumerate(cities):
-    ax = axes[k + 1]
-    y_fit = y_fit_lst[k]
-    ts = ts_lst[k]
-    ts_pred = ts_pred_lst[k]
-    y_pred = y_pred_lst[k]
-    ys = ys_lst2[k]
-    hessian_indices = np.concatenate(
-        [
-            np.arange(p_variants) + k * p_variants,
-            np.arange(model_hessian_fixed.shape[0] - p_variants, p_params),
-        ]
-    )
-    tmp_hessian = model_hessian_inv[hessian_indices, :][:, hessian_indices]
-    y_fit_logit = np.log(y_fit) - np.log(1 - y_fit)
-    logit_se = np.array(
-        [
-            freq.project_se(
-                model_map_fixed["rate"],
-                model_map_fixed["midpoint"][k, :],
-                t,
-                tmp_hessian,
-                overdisp_list[k],
-            )
-            for t in ts
-        ]
-    ).T
-    y_pred_logit = np.log(y_pred) - np.log(1 - y_pred)
-    logit_se_pred = np.array(
-        [
-            freq.project_se(
-                model_map_fixed["rate"],
-                model_map_fixed["midpoint"][k, :],
-                t,
-                tmp_hessian,
-                overdisp_list[k],
-            )
-            for t in ts_pred
-        ]
-    ).T
-
-    for i, variant in enumerate(variants):
-        # grid
-        ax.vlines(
-            x=(
-                pd.date_range(start="2021-11-01", end="2024-02-01", freq="MS")
-                - pd.to_datetime("2023-01-01")
-            ).days,
-            ymin=-0.05,
-            ymax=1.05,
-            color="grey",
-            alpha=0.02,
-        )
-        ax.hlines(
-            y=[0, 0.25, 0.5, 0.75, 1],
-            xmin=(pd.to_datetime("2021-10-10") - pd.to_datetime("2023-01-01")).days,
-            xmax=(pd.to_datetime("2024-02-20") - pd.to_datetime("2023-01-01")).days,
-            color="grey",
-            alpha=0.02,
-        )
-        ax.fill_between(x=ts_pred, y1=0, y2=1, color="grey", alpha=0.01)
-
-        # plot fitted
-        sorted_indices = np.argsort(ts)
-        ax.plot(
-            ts[sorted_indices],
-            y_fit[i, :][sorted_indices],
-            color=colors[i],
-            label="fit",
-        )
-        # plot pred
-        ax.plot(ts_pred, y_pred[i, :], color=colors[i], linestyle="--", label="predict")
-        # plot confints
-        ax.fill_between(
-            ts[sorted_indices],
-            expit(
-                y_fit_logit[i, :][sorted_indices]
-                - 1.96 * logit_se[i, :][sorted_indices]
-            ),
-            expit(
-                y_fit_logit[i, :][sorted_indices]
-                + 1.96 * logit_se[i, :][sorted_indices]
-            ),
-            color=colors[i],
-            alpha=0.2,
-            label="Confidence band",
-        )
-        ax.fill_between(
-            ts_pred,
-            expit(y_pred_logit[i, :] - 1.96 * logit_se_pred[i, :]),
-            expit(y_pred_logit[i, :] + 1.96 * logit_se_pred[i, :]),
-            color=colors[i],
-            alpha=0.2,
-            label="Confidence band",
-        )
-
-        # plot empirical
-        ax.scatter(ts, ys[i, :], label="observed", alpha=0.5, color=colors[i], s=4)
-
-    ax.set_ylim((-0.05, 1.05))
-    ax.set_xticks(
-        (
-            pd.date_range(start="2021-11-01", end="2023-12-01", freq="MS")
-            - pd.to_datetime("2023-01-01")
-        ).days
-    )
-    date_formatter = ticker.FuncFormatter(plot_ts.num_to_date)
-    ax.xaxis.set_major_formatter(date_formatter)
-    tick_positions = [0, 0.5, 1]
-    tick_labels = ["0%", "50%", "100%"]
-    ax.set_yticks(tick_positions)
-    ax.set_yticklabels(tick_labels)
-    ax.set_ylabel("relative abundances")
-    ax.set_xlim(
-        (
-            pd.to_datetime(["2023-03-15", "2024-01-05"])
-            - pd.to_datetime("2023-01-01")
-        ).days
-    )
-    ax.set_title(f"{city}")
-
-## Plot estimates
-
-ax = axes[0]
-
-(
-    fitness_diff,
-    fitness_diff_se,
-    fitness_diff_lower,
-    fitness_diff_upper,
-) = freq.make_fitness_confints(
-    model_map_fixed["rate"], model_hessian_fixed, overdisp_fixed, g=7.0
-)
-
-fitness_diff = fitness_diff * 100
-fitness_diff_lower = fitness_diff_lower * 100
-fitness_diff_upper = fitness_diff_upper * 100
-
-# Get the indices for the upper triangle, starting at the diagonal (k=0)
-upper_triangle_indices = np.triu_indices_from(fitness_diff, k=0)
-
-# Assign np.nan to the upper triangle including the diagonal
-fitness_diff[upper_triangle_indices] = np.nan
-fitness_diff_lower[upper_triangle_indices] = np.nan
-fitness_diff_upper[upper_triangle_indices] = np.nan
-
-fitness_diff[:-2, :] = np.nan
-fitness_diff_lower[:-2, :] = np.nan
-fitness_diff_upper[:-2, :] = np.nan
-
-# Calculate the error (distance from the point to the error bar limit)
-error = np.array(
-    [
-        fitness_diff - fitness_diff_lower,  # Lower error
-        fitness_diff_upper - fitness_diff,  # Upper error
-    ]
-)
-
-# Define the width of the offset
-offset_width = 0.1
-num_sets = fitness_diff.shape[0]
-# num_sets = 2
-mid = (num_sets - 1) / 2
-
-# grid
-ax.vlines(
-    x=np.arange(len(variants) - 1),
-    ymin=np.nanmin(fitness_diff_lower),
-    ymax=np.nanmax(fitness_diff_upper),
-    color="grey",
-    alpha=0.2,
-)
-ax.hlines(
-    y=np.arange(-25, 126, step=25),
-    xmin=-0.5,
-    xmax=len(variants) - 2 + 0.5,
-    color="grey",
-    alpha=0.2,
-)
-
-# Plot each set of points with error bars
-for i, y_vals in enumerate(fitness_diff):
-    # Calculate offset for each set
-    offset = (i - mid) * offset_width
-    # Create an array of x positions for this set
-    # x_positions = np.arange(len(variants)) + offset
-    x_positions = np.arange(len(variants)) + offset - 0.25
-    # We need to transpose the error array to match the shape of y_vals
-    ax.errorbar(
-        x_positions,
-        y_vals,
-        yerr=error[:, i, :],
-        fmt="o",
-        label=variants[i],
-        color=colors_covsp[variants[i]],
-    )
-
-# Set the x-ticks to be at the middle of the groups of points
-ax.set_xticks(np.arange(len(variants) - 1))
-ax.set_xticklabels(variants[:-1])
-
-# Add some labels and a legend
-ax.set_xlabel("Variants")
-ax.set_ylabel("% weekly growth advantage")
-ax.set_title("growth advantages")
-
-
-fig.tight_layout()
-fig.legend(
-    handles=plot_ts.make_legend(colors, variants),
-    loc="lower center",
-    ncol=9,
-    bbox_to_anchor=(0.5, -0.04),
-    frameon=False,
-)
-
-
-plt.savefig("growth_rates20231108.pdf", bbox_inches="tight")
-
-plt.show()
-
-
-# %%
diff --git a/examples/frequentist_notebook_jax.py b/examples/frequentist_notebook_jax.py
index 260ca68..f94604c 100644
--- a/examples/frequentist_notebook_jax.py
+++ b/examples/frequentist_notebook_jax.py
@@ -1,16 +1,16 @@
 # ---
 # jupyter:
 #   jupytext:
-#     formats: ipynb,py
+#     formats: ipynb,py:light
 #     text_representation:
 #       extension: .py
 #       format_name: light
 #       format_version: '1.5'
-#     jupytext_version: 1.16.4
+#     jupytext_version: 1.16.1
 #   kernelspec:
-#     display_name: jax
+#     display_name: Python 3 (ipykernel)
 #     language: python
-#     name: jax
+#     name: python3
 # ---

 # +
@@ -18,7 +18,6 @@ import jax.numpy as jnp
 import pandas as pd

-# import pymc as pm
 import numpy as np

@@ -34,17 +33,30 @@ import covvfit._frequentist as freq
 import covvfit._preprocess_abundances as prec
 import covvfit.plotting._timeseries as plot_ts

-import covvfit._frequentist_jax as fj
+
+from covvfit import quasimultinomial as qm

 # -

 # # Load and preprocess data

-DATA_PATH = "../../LolliPop/lollipop_test_noisy/deconvolved.csv"
+# +
+DATA_PATH = "../new_data/deconvolved.csv"
+VAR_DATES_PATH = "../new_data/var_dates.yaml"
+
 data = pd.read_csv(DATA_PATH, sep="\t")
 data.head()
+
+# Load the YAML file
+with open(VAR_DATES_PATH, "r") as file:
+    var_dates_data = yaml.safe_load(file)
+
+# Access the var_dates data
+var_dates = var_dates_data["var_dates"]
+# -
+
 data_wide = data.pivot_table(
     index=["date", "location"], columns="variant", values="proportion", fill_value=0
 ).reset_index()
@@ -57,20 +69,6 @@ max_date = pd.to_datetime(data_wide["time"]).max()
 delta_time = pd.Timedelta(days=240)
 start_date = max_date - delta_time

-# -
-
-
-# +
-# Path to the YAML file
-var_dates_yaml = "../../LolliPop/lollipop_test_noisy/var_dates.yaml"
-
-# Load the YAML file
-with open(var_dates_yaml, "r") as file:
-    var_dates_data = yaml.safe_load(file)
-
-# Access the var_dates data
-var_dates = var_dates_data["var_dates"]

 # +
@@ -119,35 +117,37 @@ def match_date(start_date):
 # +
 # %%time

-data = []
-for t, y in zip(ts_lst_scaled, ys_lst):
-    data.append(fj.CityData(ts=t, ys=y.T, n=1))
-
-# no priors
-loss = fj.construct_total_loss(data)
-# initial parameters
-theta0 = fj.construct_theta0(n_cities=len(cities), n_variants=len(variants2))
-#
-solution = fj.jax_multistart_minimize(
-    loss,
-    theta0,
-    n_starts=10
+# no priors
+loss = qm.construct_total_loss(
+    ys=[
+        y.T for y in ys_lst
+    ],  # Recall that the input should be (n_timepoints, n_variants)
+    ts=ts_lst_scaled,
+    average_loss=False,  # Do not average the loss over the data points, so that the covariance matrix shrinks with more and more data added
 )
+# initial parameters
+theta0 = qm.construct_theta0(n_cities=len(cities), n_variants=len(variants2))
+
+# Run the optimization routine
+solution = qm.jax_multistart_minimize(loss, theta0, n_starts=10)

 # -
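A minimal sketch of how the optimum above can be inspected; it assumes the notebook's `qm`, `solution` and `variants2` objects and only uses accessors that already exist in the module (`get_relative_growths`, `get_relative_midpoints`):

# Unpack the optimized theta vector into interpretable pieces.
relative_growths = qm.get_relative_growths(solution.x, n_variants=len(variants2))
relative_midpoints = qm.get_relative_midpoints(solution.x, n_variants=len(variants2))

print(relative_growths)    # shape (n_variants - 1,): growth advantage of each variant vs the reference
print(relative_midpoints)  # shape (n_cities, n_variants - 1): per-city offsets
print(solution.fun)        # loss value at the optimum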

 # ## Make fitted values and confidence intervals

 # +
 ## compute fitted values
-y_fit_lst = fj.fitted_values(ts_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2))
+y_fit_lst = qm.fitted_values(
+    ts_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2)
+)

 ## compute covariance matrix
-covariance = fj.get_covariance(loss, solution.x)
+covariance = qm.get_covariance(loss, solution.x)

 ## compute overdispersion
+# TODO(Pawel, David): Port compute_overdispersion to `qm`
 pearson_r_lst, overdisp_list, overdisp_fixed = freq.compute_overdispersion(
     ys_lst2, y_fit_lst, cities
 )
@@ -156,18 +156,24 @@ def match_date(start_date):
 covariance_scaled = overdisp_fixed * covariance

 ## compute standard errors and confidence intervals of the estimates
-standard_errors_estimates = fj.get_standard_errors(covariance_scaled)
-confints_estimates = fj.get_confidence_intervals(solution.x, standard_errors_estimates)
+standard_errors_estimates = qm.get_standard_errors(covariance_scaled)
+confints_estimates = qm.get_confidence_intervals(solution.x, standard_errors_estimates)

 ## compute confidence intervals of the fitted values on the logit scale and back transform
-y_fit_lst_confint = fj.get_confidence_bands_logit(solution.x, len(variants2), ts_lst_scaled, covariance_scaled)
+y_fit_lst_confint = qm.get_confidence_bands_logit(
+    solution.x, len(variants2), ts_lst_scaled, covariance_scaled
+)

 ## compute predicted values and confidence bands
 horizon = 60
 ts_pred_lst = [jnp.arange(horizon + 1) + tt.max() for tt in ts_lst]
 ts_pred_lst_scaled = [(x - t_min) / (t_max - t_min) for x in ts_pred_lst]
-y_pred_lst = fj.fitted_values(ts_pred_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2))
-y_pred_lst_confint = fj.get_confidence_bands_logit(solution.x, len(variants2), ts_pred_lst_scaled, covariance_scaled)
+y_pred_lst = qm.fitted_values(
+    ts_pred_lst_scaled, theta=solution.x, cities=cities, n_variants=len(variants2)
+)
+y_pred_lst_confint = qm.get_confidence_bands_logit(
+    solution.x, len(variants2), ts_pred_lst_scaled, covariance_scaled
+)

 # -

@@ -192,10 +198,10 @@ def match_date(start_date):
     # plot fitted and predicted values
     plot_fit(ax, ts_lst[i], y_fit_lst[i], variants, colors)
     plot_fit(ax, ts_pred_lst[i], y_pred_lst[i], variants, colors, linetype="--")
-    
+
     # # plot 1-fitted and predicted values
     plot_complement(ax, ts_lst[i], y_fit_lst[i], variants)
-# plot_complement(ax, ts_pred_lst[i], y_pred_lst[i], variants, linetype="--")
+    # plot_complement(ax, ts_pred_lst[i], y_pred_lst[i], variants, linetype="--")
     # plot raw deconvolved values
     plot_data(ax, ts_lst[i], ys_lst2[i], variants, colors)
     # make confidence bands and plot them
@@ -205,21 +211,22 @@ def match_date(start_date):
         ts_lst[i],
         {"lower": conf_bands[0], "upper": conf_bands[1]},
         variants,
-        colors
+        colors,
     )
-    
+
     pred_bands = y_pred_lst_confint[i]
     plot_confidence_bands(
         ax,
         ts_pred_lst[i],
         {"lower": pred_bands[0], "upper": pred_bands[1]},
         variants,
-        colors
+        colors,
     )
     # format axes and title
     def format_date(x, pos):
         return plot_ts.num_to_date(x, date_min=start_date)
+
     date_formatter = ticker.FuncFormatter(format_date)
     ax.xaxis.set_major_formatter(date_formatter)
     tick_positions = [0, 0.5, 1]
@@ -231,3 +238,4 @@ def format_date(x, pos):

 fig.tight_layout()
 fig.show()
+# -
diff --git a/src/covvfit/__init__.py b/src/covvfit/__init__.py
index 18c9b01..4e0d375 100644
--- a/src/covvfit/__init__.py
+++ b/src/covvfit/__init__.py
@@ -2,16 +2,19 @@
 try:
     import covvfit._frequentist as freq
+
+    warnings.warn("The `freq` submodule is deprecated.")
 except Exception as e:
     warnings.warn(
         f"It is not possible to use `freq` subpackage due to missing dependencies. Exception raised: {e}"
     )
     freq = None

+
 try:
-    import covvfit._frequentist_jax as freq_jax
+    import covvfit._quasimultinomial as quasimultinomial
 except Exception as e:
     warnings.warn(
-        f"It is not possible to use `freq_jax` subpackage due to missing dependencies. Exception raised: {e}"
+        f"It is not possible to use `quasimultinomial` subpackage due to missing dependencies. Exception raised: {e}"
     )
     freq_jax = None

@@ -37,7 +40,7 @@
     "load_data",
     "VERSION",
     "freq",
-    "freq_jax",
+    "quasimultinomial",
     "plot",
     "simulation",
 ]
diff --git a/src/covvfit/_frequentist_jax.py b/src/covvfit/_quasimultinomial.py
similarity index 50%
rename from src/covvfit/_frequentist_jax.py
rename to src/covvfit/_quasimultinomial.py
index 4266cb0..64dde13 100644
--- a/src/covvfit/_frequentist_jax.py
+++ b/src/covvfit/_quasimultinomial.py
@@ -1,11 +1,13 @@
 """Frequentist fitting functions powered by JAX."""

 import dataclasses
-from typing import Callable, List, NamedTuple, Optional, Sequence
+from typing import Callable, NamedTuple

 import jax
 import jax.numpy as jnp
 import numpy as np
+import numpyro
+import numpyro.distributions as distrib
 from jaxtyping import Array, Float
 from scipy import optimize

@@ -63,61 +65,12 @@ def loss(
     return -jnp.sum(n * y * logp, axis=-1)


-class CityData(NamedTuple):
-    ts: Float[Array, " timepoints"]
-    ys: Float[Array, "timepoints variants"]
-    n: _Float
-
-
 _ThetaType = Float[Array, "(cities+1)*(variants-1)"]


-def add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]:
-    return jnp.concatenate([jnp.zeros_like(vec)[0:1], vec])
-
-
-def construct_total_loss(
-    cities: Sequence[CityData],
-    average_loss: bool = False,
-) -> Callable[[_ThetaType], _Float]:
-    cities = tuple(cities)
-    n_variants = cities[0].ys.shape[-1]
-    for city in cities:
-        assert (
-            city.ys.shape[-1] == n_variants
-        ), "All cities must have the same number of variants"
-
-    if average_loss:  # trick for numerical stability (average loss doesnt blow up)
-        n_points_total = 1.0 * sum(city.ts.shape[0] for city in cities)
-    else:
-        n_points_total = 1.0
-
-    def total_loss(theta: _ThetaType) -> _Float:
-        rel_growths = get_relative_growths(theta, n_variants=n_variants)
-        rel_midpoints = get_relative_midpoints(theta, n_variants=n_variants)
-
-        growths = add_first_variant(rel_growths)
-        return (
-            jnp.sum(
-                jnp.asarray(
-                    [
-                        loss(
-                            y=city.ys,
-                            n=city.n,
-                            logp=calculate_logps(
-                                ts=city.ts,
-                                midpoints=add_first_variant(midp),
-                                growths=growths,
-                            ),
-                        ).sum()
-                        for midp, city in zip(rel_midpoints, cities)
-                    ]
-                )
-            )
-            / n_points_total
-        )
-
-    return total_loss
+def _add_first_variant(vec: Float[Array, " variants-1"]) -> Float[Array, " variants"]:
+    """Prepends 0 to the beginning of the vector."""
+    return jnp.concatenate((jnp.zeros(1, dtype=vec.dtype), vec))


 def construct_theta(
@@ -160,7 +113,7 @@ def convert(confidence: float) -> float:

 def get_covariance(
     loss_fn: Callable[[_ThetaType], _Float],
     theta: _ThetaType,
-) -> Float[Array, "(n_params n_params)"]:
+) -> Float[Array, "n_params n_params"]:
     """Calculates the covariance matrix of the parameters.

     Args:
@@ -169,6 +122,12 @@ def get_covariance(
     Returns:
         The covariance matrix, which is the inverse of the Hessian matrix.
+
+
+    Note:
+        `loss_fn` should *not* be averaged over the data points: otherwise,
+        the covariance matrix won't shrink even when a very big data
+        set is used
     """
     hessian_matrix = jax.hessian(loss_fn)(theta)
     covariance_matrix = jnp.linalg.inv(hessian_matrix)
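A small illustration of the Note above, with a toy quadratic loss standing in for the real quasilikelihood (the toy loss is made up for the example; only `get_covariance` and `get_standard_errors` come from this module, and we assume, as in the notebook, that `get_standard_errors` without a Jacobian returns the square roots of the covariance diagonal). Because the un-averaged loss grows with the number of data points, its Hessian grows too, so the covariance and the standard errors shrink:

import jax.numpy as jnp
from covvfit import quasimultinomial as qm

def make_toy_loss(n_points: int):
    # Stand-in for an un-averaged loss: every data point adds one quadratic term.
    def loss(theta):
        return 0.5 * n_points * jnp.sum(theta**2)
    return loss

theta_hat = jnp.zeros(3)
for n_points in (10, 1_000):
    cov = qm.get_covariance(make_toy_loss(n_points), theta_hat)
    print(n_points, qm.get_standard_errors(cov))  # standard errors shrink roughly as 1/sqrt(n_points)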
@@ -178,7 +137,7 @@ def get_standard_errors(
     covariance: Float[Array, "n_inputs n_inputs"],
-    jacobian: Optional[Float[Array, "*output_shape n_inputs"]] = None,
+    jacobian: Float[Array, "*output_shape n_inputs"] | None = None,
 ) -> Float[Array, " *output_shape"]:
     """Delta method to calculate standard errors of a function
     from `n_inputs` to `output_shape`.
@@ -234,7 +193,7 @@ def get_confidence_intervals(


 def fitted_values(
-    times: List[Float[Array, " timepoints"]],
+    times: list[Float[Array, " timepoints"]],
     theta: _ThetaType,
     cities: list,
     n_variants: int,
@@ -290,10 +249,10 @@ def logit_predictions_with_fixed_args(

 def get_confidence_bands_logit(
     solution_x: Float[Array, " (cities+1)*(variants-1)"],
     variants_count: int,
-    ts_lst_scaled: List[Float[Array, " timepoints"]],
+    ts_lst_scaled: list[Float[Array, " timepoints"]],
     covariance_scaled: Float[Array, "n_params n_params"],
     confidence_level: float = 0.95,
-) -> List[tuple]:
+) -> list[tuple]:
     """Computes confidence intervals for logit predictions using the Delta method,
     back-transforms them to the linear scale
@@ -353,7 +312,7 @@ def get_relative_advantages(theta, n_variants: int):
     # over the 0th variant
     rel_growths = get_relative_growths(theta, n_variants=n_variants)

-    growths = jnp.concatenate((jnp.zeros(1, dtype=rel_growths.dtype), rel_growths))
+    growths = _add_first_variant(rel_growths)
     diffs = growths[None, :] - growths[:, None]
     return diffs

@@ -362,10 +321,10 @@ def get_softmax_predictions(
     theta: _ThetaType, n_variants: int, city_index: int, ts: Float[Array, " timepoints"]
 ) -> Float[Array, "timepoints variants"]:
     rel_growths = get_relative_growths(theta, n_variants=n_variants)
-    growths = add_first_variant(rel_growths)
+    growths = _add_first_variant(rel_growths)

     rel_midpoints = get_relative_midpoints(theta, n_variants=n_variants)
-    midpoints = add_first_variant(rel_midpoints[city_index])
+    midpoints = _add_first_variant(rel_midpoints[city_index])

     y_linear = calculate_linear(
         ts=ts,
@@ -446,3 +405,305 @@ def loss_grad_fun(theta):
         fun=solutions[optimal_index].fun,
         runs=solutions,
     )
+
+
+class _ProblemData(NamedTuple):
+    """Internal representation of the data used
+    to efficiently construct the quasilikelihood.
+
+    Attrs:
+        ts: array of shape (cities, timepoints)
+            which is padded with 0 for days where
+            there is no measurement for a particular city
+        ys: array of shape (cities, timepoints, variants)
+            which is padded with the vector (1/variants, ..., 1/variants)
+            for timepoints where there is no measurement for a particular city
+        mask: array of shape (cities, timepoints) with 0 when there is
+            no measurement for a particular city and 1 otherwise
+        n_quasimul: quasimultinomial number of trials for each city
+        overdispersion: overdispersion factor for each city
+    """
+
+    n_cities: int
+    n_variants: int
+    ts: Float[Array, "cities timepoints"]
+    ys: Float[Array, "cities timepoints variants"]
+    mask: Float[Array, "cities timepoints"]
+    n_quasimul: Float[Array, " cities"]
+    overdispersion: Float[Array, " cities"]
+
+
+def _validate_and_pad(
+    ys: list[jax.Array],
+    ts: list[jax.Array],
+    ns_quasimul: Float[Array, " cities"] | list[float] | float = 1.0,
+    overdispersion: Float[Array, " cities"] | list[float] | float = 1.0,
+) -> _ProblemData:
+    """Validation function, parsing the input provided in
+    the format convenient for the user to the internal
+    representation compatible with JAX."""
+    # Get the number of cities
+    n_cities = len(ys)
+    if len(ts) != n_cities:
+        raise ValueError(f"Number of cities not consistent: {len(ys)} != {len(ts)}.")
+
+    # Create arrays representing `n` and `overdispersion`
+    if hasattr(ns_quasimul, "__len__"):
+        if len(ns_quasimul) != n_cities:
+            raise ValueError(
+                f"Provided `ns_quasimul` has length {len(ns_quasimul)} rather than {n_cities}."
+            )
+    if hasattr(overdispersion, "__len__"):
+        if len(overdispersion) != n_cities:
+            raise ValueError(
+                f"Provided `overdispersion` has length {len(overdispersion)} rather than {n_cities}."
+            )
+
+    out_n = jnp.asarray(ns_quasimul) * jnp.ones(n_cities, dtype=float)
+    out_overdispersion = jnp.asarray(overdispersion) * jnp.ones_like(out_n)
+
+    # Get the number of variants
+    n_variants = ys[0].shape[-1]
+    for i, y in enumerate(ys):
+        if y.ndim != 2:
+            raise ValueError(f"City {i} has {y.ndim} dimensions, rather than 2.")
+        if y.shape[-1] != n_variants:
+            raise ValueError(
+                f"City {i} has {y.shape[-1]} variants rather than {n_variants}."
+            )
+
+    # Ensure that the number of timepoints is consistent
+    max_timepoints = 0
+    for i, (t, y) in enumerate(zip(ts, ys)):
+        if t.ndim != 1:
+            raise ValueError(
+                f"City {i} has time axis with dimension {t.ndim}, rather than 1."
+            )
+        if t.shape[0] != y.shape[0]:
+            raise ValueError(
+                f"City {i} has timepoints mismatch: {t.shape[0]} != {y.shape[0]}."
+            )
+
+        max_timepoints = max(max_timepoints, t.shape[0])
+
+    # Now create the arrays representing the data
+    out_ts = jnp.zeros((n_cities, max_timepoints))  # Pad with zeros
+    out_mask = jnp.zeros((n_cities, max_timepoints))  # Pad with zeros
+    out_ys = jnp.full(
+        shape=(n_cities, max_timepoints, n_variants), fill_value=1.0 / n_variants
+    )  # Pad with constant vectors
+
+    for i, (t, y) in enumerate(zip(ts, ys)):
+        n_timepoints = t.shape[0]
+
+        out_ts = out_ts.at[i, :n_timepoints].set(t)
+        out_ys = out_ys.at[i, :n_timepoints, :].set(y)
+        out_mask = out_mask.at[i, :n_timepoints].set(1)
+
+    return _ProblemData(
+        n_cities=n_cities,
+        n_variants=n_variants,
+        ts=out_ts,
+        ys=out_ys,
+        mask=out_mask,
+        n_quasimul=out_n,
+        overdispersion=out_overdispersion,
+    )
+
+
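A quick sketch of the padding behaviour documented in `_ProblemData`, using two tiny made-up cities with different numbers of timepoints (the helper is private, so this is meant only as an illustration of the internal representation):

import jax.numpy as jnp
import covvfit._quasimultinomial as qm

# City 0 has two timepoints, city 1 has three; both track two variants.
ts = [jnp.array([0.0, 1.0]), jnp.array([0.0, 1.0, 2.0])]
ys = [
    jnp.array([[0.8, 0.2], [0.6, 0.4]]),
    jnp.array([[0.9, 0.1], [0.7, 0.3], [0.5, 0.5]]),
]

data = qm._validate_and_pad(ys=ys, ts=ts)
print(data.ts.shape)  # (2, 3): city 0 is padded up to the longest city
print(data.mask)      # [[1. 1. 0.] [1. 1. 1.]]: the padded timepoint is masked out
print(data.ys[0, 2])  # [0.5 0.5]: padded proportions are uniform over variants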
+ ) + + max_timepoints = t.shape[0] + + # Now create the arrays representing the data + out_ts = jnp.zeros((n_cities, max_timepoints)) # Pad with zeros + out_mask = jnp.zeros((n_cities, max_timepoints)) # Pad with zeros + out_ys = jnp.full( + shape=(n_cities, max_timepoints, n_variants), fill_value=1.0 / n_variants + ) # Pad with constant vectors + + for i, (t, y) in enumerate(zip(ts, ys)): + n_timepoints = t.shape[0] + + out_ts = out_ts.at[i, :n_timepoints].set(t) + out_ys = out_ys.at[i, :n_timepoints, :].set(y) + out_mask = out_mask.at[i, :n_timepoints].set(1) + + return _ProblemData( + n_cities=n_cities, + n_variants=n_variants, + ts=out_ts, + ys=out_ys, + mask=out_mask, + n_quasimul=out_n, + overdispersion=out_overdispersion, + ) + + +def _quasiloglikelihood_single_city( + relative_growths: Float[Array, " variants-1"], + relative_offsets: Float[Array, " variants-1"], + ts: Float[Array, " timepoints"], + ys: Float[Array, "timepoints variants"], + mask: Float[Array, " timepoints"], + n_quasimul: float, + overdispersion: float, +) -> float: + weight = n_quasimul / overdispersion + logps = calculate_logps( + ts=ts, + midpoints=_add_first_variant(relative_offsets), + growths=_add_first_variant(relative_growths), + ) + return jnp.sum(mask[:, None] * weight * ys * logps) + + +_RelativeGrowthsAndOffsetsFunction = Callable[ + [Float[Array, " variants-1"], Float[Array, " variants-1"]], _Float +] + + +def _generate_quasiloglikelihood_function( + data: _ProblemData, +) -> _RelativeGrowthsAndOffsetsFunction: + """Creates the quasilikelihood function with signature: + + def quasiloglikelihood( + relative_growths: array of shape (variants-1,) + relative_offsets: array of shape (cities, variants-1) + ) -> float + """ + + def quasiloglikelihood( + relative_growths: Float[Array, " variants-1"], + relative_offsets: Float[Array, "cities variants-1"], + ) -> _Float: + # Broadcast the array, to use the same relative growths + # for each city + _new_shape = (data.n_cities, relative_growths.shape[-1]) + tiled_growths = jnp.broadcast_to(relative_growths, _new_shape) + + logps = jax.vmap(_quasiloglikelihood_single_city)( + relative_growths=tiled_growths, + relative_offsets=relative_offsets, + ts=data.ts, + ys=data.ys, + mask=data.mask, + n_quasimul=data.n_quasimul, + overdispersion=data.overdispersion, + ) + return jnp.sum(logps) + + return quasiloglikelihood + + +def construct_model( + ys: list[jax.Array], + ts: list[jax.Array], + ns: Float[Array, " cities"] | list[float] | float = 1.0, + overdispersion: Float[Array, " cities"] | list[float] | float = 1.0, + sigma_growth: float = 10.0, + sigma_offset: float = 1000.0, +) -> Callable: + """Builds a NumPyro model suitable for sampling from the quasiposterior. + + Args: + ys: list of variant proportions for each city. + The ith entry should be an array + of shape (n_timepoints[i], n_variants) + ts: list of timepoints. The ith entry should be an array + of shape (n_timepoints[i],) + Note: `ts` should be appropriately normalized + ns: controls the overdispersion of each city by means of + quasimultinomial sample size + overdispersion: controls the overdispersion factor as in the + quasilikelihood approach + sigma_growth: controls the standard deviation of the prior + on the relative growths + sigma_offset: controls the standard deviation of the prior + on the relative offsets + + Note: + The "loglikelihood" is effectively rescaled by `ns/overdispersion` + factor. Hence, using both `ns` and `overdispersion` should generally + be avoided. 
+ """ + data = _validate_and_pad( + ys=ys, + ts=ts, + ns_quasimul=ns, + overdispersion=overdispersion, + ) + + quasi_ll_fn = _generate_quasiloglikelihood_function(data) + + def model(): + # Sample growth differences. Note that we sample from the N(0, 1) + # distribution and then resample, for numerical stability + _scaled_rel_growths = numpyro.sample( + "_scaled_relative_growths", + distrib.Normal().expand((data.n_variants - 1,)), + ) + rel_growths = numpyro.deterministic( + "relative_growths", + sigma_growth * _scaled_rel_growths, + ) + + # Sample offsets. We use scaling the same scaling trick as above + _scaled_rel_offsets = numpyro.sample( + "_scaled_relative_offsets", + distrib.Normal().expand((data.n_cities, data.n_variants - 1)), + ) + rel_offsets = numpyro.deterministic( + "relative_offsets", + _scaled_rel_offsets * sigma_offset, + ) + + numpyro.factor("quasiloglikelihood", quasi_ll_fn(rel_growths, rel_offsets)) + + return model + + +def construct_total_loss( + ys: list[jax.Array], + ts: list[jax.Array], + ns: list[float] | float = 1.0, + overdispersion: list[float] | float = 1.0, + accept_theta: bool = True, + average_loss: bool = False, +) -> Callable[[_ThetaType], _Float] | _RelativeGrowthsAndOffsetsFunction: + """Constructs the loss function, suitable e.g., for optimization. + + Args: + ys: list of variant proportions for each city. + The ith entry should be an array + of shape (n_timepoints[i], n_variants) + ts: list of timepoints. The ith entry should be an array + of shape (n_timepoints[i],) + Note: `ts` should be appropriately normalized + ns: controls the overdispersion of each city by means of + quasimultinomial sample size + overdispersion: controls the overdispersion factor as in the + quasilikelihood approach + accept_theta: whether the returned loss function should accept the + `theta` vector (suitable for optimization) + or should be parameterized by the relative growths + and relative offsets, as in + ``` + def loss( + relative_growths: array of shape (variants-1,) + relative_offsets: array of shape (cities, variants-1) + ) -> float + ``` + average_loss: whether the loss should be divided by the + total number of points. By default it is false, as the loss + is used to calculate confidence intervals. Setting it to true + can improve the convergence of the optimization procedure + + Note: + The "loglikelihood" is effectively rescaled by `ns/overdispersion` + factor. Hence, using both `ns` and `overdispersion` should generally + be avoided. 
+ """ + + data: _ProblemData = _validate_and_pad( + ys=ys, + ts=ts, + ns_quasimul=ns, + overdispersion=overdispersion, + ) + + if average_loss: + scaling = jnp.sum(data.mask, dtype=float) + else: + scaling = 1.0 + + # Get the quasilikelihood function + quasi_ll_fn = _generate_quasiloglikelihood_function(data) + + # Define the loss function parameterized + # with relative growths and offsets + def _loss_fn(relative_growths, relative_offsets): + return -quasi_ll_fn(relative_growths, relative_offsets) / scaling + + # Define the loss function in terms of the theta variable + def _loss_fn_theta(theta): + rel_growths = get_relative_growths(theta, n_variants=data.n_variants) + rel_offsets = get_relative_midpoints(theta, n_variants=data.n_variants) + return _loss_fn(relative_growths=rel_growths, relative_offsets=rel_offsets) + + if accept_theta: + return _loss_fn_theta + else: + return _loss_fn diff --git a/tests/test_frequentist_jax.py b/tests/test_quasimultinomial.py similarity index 77% rename from tests/test_frequentist_jax.py rename to tests/test_quasimultinomial.py index 04c0ac3..edb4f9d 100644 --- a/tests/test_frequentist_jax.py +++ b/tests/test_quasimultinomial.py @@ -1,4 +1,4 @@ -import covvfit._frequentist_jax as fj +import covvfit._quasimultinomial as qm import jax import numpy.testing as npt import pytest @@ -13,17 +13,17 @@ def test_parameter_conversions_1(seed: int, n_cities: int, n_variants: int) -> N growth_rel = jax.random.uniform(key1, shape=(n_variants - 1,)) midpoint_rel = jax.random.uniform(key2, shape=(n_cities, n_variants - 1)) - theta = fj.construct_theta( + theta = qm.construct_theta( relative_growths=growth_rel, relative_midpoints=midpoint_rel, ) npt.assert_allclose( - fj.get_relative_growths(theta, n_variants=n_variants), + qm.get_relative_growths(theta, n_variants=n_variants), growth_rel, ) npt.assert_allclose( - fj.get_relative_midpoints(theta, n_variants=n_variants), + qm.get_relative_midpoints(theta, n_variants=n_variants), midpoint_rel, ) @@ -36,11 +36,11 @@ def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> N jax.random.PRNGKey(seed), shape=(n_cities * (n_variants - 1) + n_variants - 1,) ) - growth_rel = fj.get_relative_growths(theta, n_variants=n_variants) - midpoint_rel = fj.get_relative_midpoints(theta, n_variants=n_variants) + growth_rel = qm.get_relative_growths(theta, n_variants=n_variants) + midpoint_rel = qm.get_relative_midpoints(theta, n_variants=n_variants) npt.assert_allclose( - fj.construct_theta( + qm.construct_theta( relative_growths=growth_rel, relative_midpoints=midpoint_rel, ), diff --git a/workflows/bootstrap_simulation.smk b/workflows/bootstrap_simulation.smk index 25f4489..ee96d90 100644 --- a/workflows/bootstrap_simulation.smk +++ b/workflows/bootstrap_simulation.smk @@ -195,7 +195,8 @@ rule fit_to_bootstrapped_sample: boostrap_index = int(wildcards.bootstrap) key = jax.random.PRNGKey(boostrap_index) - data = [] + ts_all = [] + ys_all = [] for city_index in range(settings.n_cities): subkey = jax.random.fold_in(key, city_index) @@ -211,10 +212,12 @@ rule fit_to_bootstrapped_sample: ts_obs = ts_obs[index] ys_obs = ys_obs[index, :] - data.append(fj.CityData(ts=ts_obs, ys=ys_obs, n=1)) # Note that we use n = 1, as we do not know the true value + ts_all.append(ts_obs) + ys_all.append(ys_obs) + # Note that we don't use any priors: we optimize just the (quasi-)likelihood - loss = fj.construct_total_loss(data) + loss = fj.construct_total_loss(ys=ys_all, ts=ts_all, accept_theta=True) theta0 = 
diff --git a/tests/test_frequentist_jax.py b/tests/test_quasimultinomial.py
similarity index 77%
rename from tests/test_frequentist_jax.py
rename to tests/test_quasimultinomial.py
index 04c0ac3..edb4f9d 100644
--- a/tests/test_frequentist_jax.py
+++ b/tests/test_quasimultinomial.py
@@ -1,4 +1,4 @@
-import covvfit._frequentist_jax as fj
+import covvfit._quasimultinomial as qm
 import jax
 import numpy.testing as npt
 import pytest
@@ -13,17 +13,17 @@ def test_parameter_conversions_1(seed: int, n_cities: int, n_variants: int) -> N
     growth_rel = jax.random.uniform(key1, shape=(n_variants - 1,))
     midpoint_rel = jax.random.uniform(key2, shape=(n_cities, n_variants - 1))

-    theta = fj.construct_theta(
+    theta = qm.construct_theta(
         relative_growths=growth_rel,
         relative_midpoints=midpoint_rel,
     )

     npt.assert_allclose(
-        fj.get_relative_growths(theta, n_variants=n_variants),
+        qm.get_relative_growths(theta, n_variants=n_variants),
         growth_rel,
     )
     npt.assert_allclose(
-        fj.get_relative_midpoints(theta, n_variants=n_variants),
+        qm.get_relative_midpoints(theta, n_variants=n_variants),
         midpoint_rel,
     )

@@ -36,11 +36,11 @@ def test_parameter_conversions_2(seed: int, n_cities: int, n_variants: int) -> N
         jax.random.PRNGKey(seed), shape=(n_cities * (n_variants - 1) + n_variants - 1,)
     )

-    growth_rel = fj.get_relative_growths(theta, n_variants=n_variants)
-    midpoint_rel = fj.get_relative_midpoints(theta, n_variants=n_variants)
+    growth_rel = qm.get_relative_growths(theta, n_variants=n_variants)
+    midpoint_rel = qm.get_relative_midpoints(theta, n_variants=n_variants)

     npt.assert_allclose(
-        fj.construct_theta(
+        qm.construct_theta(
             relative_growths=growth_rel,
             relative_midpoints=midpoint_rel,
         ),
diff --git a/workflows/bootstrap_simulation.smk b/workflows/bootstrap_simulation.smk
index 25f4489..ee96d90 100644
--- a/workflows/bootstrap_simulation.smk
+++ b/workflows/bootstrap_simulation.smk
@@ -195,7 +195,8 @@ rule fit_to_bootstrapped_sample:
         boostrap_index = int(wildcards.bootstrap)
         key = jax.random.PRNGKey(boostrap_index)

-        data = []
+        ts_all = []
+        ys_all = []
         for city_index in range(settings.n_cities):
             subkey = jax.random.fold_in(key, city_index)

@@ -211,10 +212,12 @@ rule fit_to_bootstrapped_sample:
             ts_obs = ts_obs[index]
             ys_obs = ys_obs[index, :]

-            data.append(fj.CityData(ts=ts_obs, ys=ys_obs, n=1))  # Note that we use n = 1, as we do not know the true value
+            ts_all.append(ts_obs)
+            ys_all.append(ys_obs)
+
         # Note that we don't use any priors: we optimize just the (quasi-)likelihood
-        loss = fj.construct_total_loss(data)
+        loss = fj.construct_total_loss(ys=ys_all, ts=ts_all, accept_theta=True)
         theta0 = fj.construct_theta0(n_cities=settings.n_cities, n_variants=settings.n_variants)

         solution = fj.jax_multistart_minimize(