diff --git a/src/jnotype/checks/__init__.py b/src/jnotype/checks/__init__.py index b9ec663..3562d2b 100644 --- a/src/jnotype/checks/__init__.py +++ b/src/jnotype/checks/__init__.py @@ -5,7 +5,13 @@ calculate_mcc, calculate_mutation_frequencies, calculate_number_of_mutations_histogram, + convert_genotypes_to_integers, + convert_integers_to_genotypes, + calculate_atoms_occurrence, + subsample_pytree, + simulate_summary_statistic, ) +from jnotype.checks._plots import rc_context, rcParams, plot_summary_statistic __all__ = [ "plot_histograms", @@ -13,4 +19,12 @@ "calculate_mcc", "calculate_mutation_frequencies", "calculate_number_of_mutations_histogram", + "rc_context", + "rcParams", + "plot_summary_statistic", + "convert_genotypes_to_integers", + "convert_integers_to_genotypes", + "calculate_atoms_occurrence", + "subsample_pytree", + "simulate_summary_statistic", ] diff --git a/src/jnotype/checks/_histograms.py b/src/jnotype/checks/_histograms.py index 6577570..d049d49 100644 --- a/src/jnotype/checks/_histograms.py +++ b/src/jnotype/checks/_histograms.py @@ -5,11 +5,15 @@ import matplotlib.pyplot as plt import numpy as np +import jax +import jax.numpy as jnp +from jaxtyping import Float, Array + def calculate_quantiles( - samples: np.ndarray, - quantiles: np.ndarray, -) -> np.ndarray: + samples: Float[Array, "n_samples dimension"], + quantiles: Float[Array, " n_quantiles"], +) -> Float[Array, "n_quantiles dimension"]: """Calculates quantiles. Args: @@ -20,15 +24,15 @@ def calculate_quantiles( quantile value for each dimension, shape (n_quantiles, dimension) """ - return np.quantile(samples, axis=0, q=quantiles) + return jnp.quantile(samples, axis=0, q=quantiles) def apply_histogram( - draws: np.ndarray, + draws: Float[Array, "n_datasets n_values"], bins: Union[int, Sequence[float], np.ndarray], density: bool, ) -> np.ndarray: - """Maps `np.histogram` over several vectors of values + """Maps `jnp.histogram` over several vectors of values to contruct several histograms. Args: @@ -39,11 +43,13 @@ def apply_histogram( Returns: histogram counts, shape (n_samples, bins) """ - _, bins = np.histogram(draws[0], bins=bins) + _, bins = jnp.histogram(draws[0], bins=bins, density=density) + + def f(sample): + """Auxiliary function used for jax.vmap""" + return jnp.histogram(sample, bins=bins, density=density)[0] - return np.asarray( - [np.histogram(vect, bins=bins, density=density)[0] for vect in draws] - ) + return jax.vmap(f)(draws) def plot_histograms( diff --git a/src/jnotype/checks/_plots.py b/src/jnotype/checks/_plots.py new file mode 100644 index 0000000..0b2f56e --- /dev/null +++ b/src/jnotype/checks/_plots.py @@ -0,0 +1,390 @@ +"""Plotting utilities for performing +graphical prior and posterior predictive checking.""" + +import copy +from contextlib import contextmanager +from typing import Literal, Optional, Sequence + +import matplotlib.pyplot as plt +import numpy as np +from jaxtyping import Float, Array + +from jnotype.checks._histograms import calculate_quantiles + + +def _wrap_array(y): + """A simple wrapper, working in a similar + manner as the Maybe functor: + + Optional[ArrayLike] -> Optional[NumPy Array] + """ + if y is not None: + return np.array(y) + else: + return None + + +# Global default plotting parameters, +# similarly as in Matplotlib +rcParams: dict = { + "color_data": "C0", + "color_simulations": "C1", + "quantiles": ((0.025, 0.975), (0.25, 0.75)), # tuple[tuple[float, float], ...] + "uncertainty_type": "quantiles", # Literal["quantiles", "trajectories", "none"] + "summary_type": "median", # Literal["mean", "median", "none"] + "trajectory_linewidth": 0.01, + "num_trajectories": 50, + "summary_linewidth": 1.0, + "summary_markersize": 3.0**2, + "summary_marker": ".", # Use "none" for no marker + "data_linewidth": 1.0, + "data_markersize": 3.0**2, + "data_marker": ".", +} + + +@contextmanager +def rc_context(rc: dict): + """Context manager maintained + in the same manner as in Matplotlib.""" + original = copy.deepcopy(rcParams) + try: + rcParams.update(rc) + yield + finally: + rcParams.update(original) + + +def _plot_quantiles( + ax: plt.Axes, + x_axis: Float[Array, " n_points"], + ys: Float[Array, "n_simulations n_points"], + color: Optional[str], + alpha: Optional[float], + quantiles: Optional[Sequence[tuple[float, float]]], +) -> None: + """Plots uncertainty in terms of quantiles + calculated separately for each coordinate.""" + # If quantiles are None, use the default value + if quantiles is None: + quantiles = rcParams["quantiles"] + # Ensure that quantiles is a list + quantiles = list(quantiles) # type: ignore + + # TODO(Pawel): Ensure that the quantiles are nested. + + # If alpha is not set, calculate a reasonable value + if alpha is None: + alpha = min(0.2, 1 / (1 + len(quantiles))) + + if color is None: + color = rcParams["color_simulations"] + + # Plot quantiles + for q_min, q_max in quantiles: + if q_min >= q_max: + raise ValueError(f"Quantile {q_min} >= {q_max}.") + y_quant = calculate_quantiles( + ys, + quantiles=np.array([q_min, q_max], dtype=float), # type: ignore + ) + ax.fill_between( + x_axis, + np.asarray(y_quant[0, :]), # type: ignore + np.asarray(y_quant[-1, :]), # type: ignore + color=color, + alpha=alpha, + edgecolor=None, + ) + + +def _plot_trajectories( + ax: plt.Axes, + x_axis: Float[Array, " n_points"], + y_simulated: Float[Array, "n_simulations n_points"], + color: Optional[str], + alpha: Optional[float], + num_trajectories: Optional[int], + trajectories_subsample_key: int, + trajectory_linewidth: Optional[float], +) -> None: + """Plots uncertainty in terms + of individual simulated samples.""" + if num_trajectories is None: + num_trajectories = rcParams["num_trajectories"] + assert num_trajectories is not None + + if num_trajectories <= 0: + raise ValueError("num_trajectories has to be at least 1") + if trajectory_linewidth is None: + trajectory_linewidth = rcParams["trajectory_linewidth"] + if color is None: + color = rcParams["color_simulations"] + if alpha is None: + alpha = min(0.1, 1 / (1 + num_trajectories)) + + num_simulations = y_simulated.shape[0] + + indices = np.arange(num_simulations) + if num_trajectories < num_simulations: + rng = np.random.default_rng(trajectories_subsample_key) + indices = rng.choice(indices, size=num_trajectories, replace=False) + + for index in indices: + ax.plot( + x_axis, + y_simulated[index], + linewidth=trajectory_linewidth, + color=color, + alpha=alpha, + ) + + +def _plot_main_summary( + ax: plt.Axes, + x_axis: Float[Array, " n_points"], + y_simulated: Float[Array, "n_simulations n_points"], + summary_type: Literal["mean", "median", "none", "default"], + summary_linewidth: Optional[float], + summary_markersize: Optional[float], + summary_marker: str, + color: Optional[str], +) -> None: + """Plots some "summary" such as mean or median of the + statistic from simulations.""" + if summary_type == "default": + summary_type = rcParams["summary_type"] + if summary_linewidth is None: + summary_linewidth = rcParams["summary_linewidth"] + if summary_markersize is None: + summary_markersize = rcParams["summary_markersize"] + if color is None: + color = rcParams["color_simulations"] + if summary_marker == "default": + summary_marker = rcParams["summary_marker"] + + y = None + if summary_type == "none": + return + elif summary_type == "median": + y = np.median(y_simulated, axis=0) + elif summary_type == "mean": + y = np.mean(y_simulated, axis=0) + else: + raise ValueError(f"Simulated summary type {summary_type} not known.") + + assert y is not None + assert y.shape == x_axis.shape + + # Now plot the summary statistic + ax.plot( + x_axis, + y, + c=color, + marker=summary_marker, + markersize=summary_markersize, + linewidth=summary_linewidth, + ) + + +def _plot_data( + ax: plt.Axes, + x_axis: Float[Array, " n_points"], + y: Float[Array, " n_points"], + data_linewidth: Optional[float], + data_markersize: Optional[float], + data_marker: str, + color: Optional[str], +) -> None: + """Plots the data.""" + if data_linewidth is None: + data_linewidth = rcParams["data_linewidth"] + if data_markersize is None: + data_markersize = rcParams["data_markersize"] + if color is None: + color = rcParams["color_data"] + if data_marker == "default": + data_marker = rcParams["data_marker"] + + if x_axis.shape != y.shape: + raise ValueError("x and y have different shapes") + + ax.plot( + x_axis, + y, + c=color, + marker=data_marker, + markersize=data_markersize, + linewidth=data_linewidth, + ) + + +def _plot_uncertainty( + ax: plt.Axes, + x_axis: Float[Array, " n_points"], + y_simulated: Float[Array, "n_simulations n_points"], + color_simulated: Optional[str], + uncertainty_type: Literal["default", "none", "trajectories", "quantiles"], + uncertainty_alpha: Optional[float], + quantiles: Optional[Sequence[tuple[float, float]]], + num_trajectories: Optional[int], + trajectory_linewidth: Optional[float], + trajectories_subsample_key: int, +) -> None: + """Plots uncertainty either in terms + of quantiles or trajectories.""" + # If we use the default settings, look for the right ones: + if uncertainty_type == "default": + uncertainty_type = rcParams["uncertainty_type"] + + # Now decide how to plot the uncertainty + if uncertainty_type == "none": + return # We don't have to plot anything + elif uncertainty_type == "quantiles": + _plot_quantiles( + ax=ax, + x_axis=x_axis, + ys=y_simulated, + color=color_simulated, + alpha=uncertainty_alpha, + quantiles=quantiles, + ) + elif uncertainty_type == "trajectories": + _plot_trajectories( + ax=ax, + x_axis=x_axis, + y_simulated=y_simulated, + color=color_simulated, + alpha=uncertainty_alpha, + num_trajectories=num_trajectories, + trajectories_subsample_key=trajectories_subsample_key, + trajectory_linewidth=trajectory_linewidth, + ) + else: + raise ValueError(f"Uncertainty {uncertainty_type} not recognized") + + +def plot_summary_statistic( + ax: plt.Axes, + x_axis: Optional[Float[Array, " n_points"]] = None, + y_data: Optional[Float[Array, " n_points"]] = None, + y_simulated: Optional[Float[Array, "n_simulations n_points"]] = None, + color_data: Optional[str] = None, + color_simulations: Optional[str] = None, + summary_type: Literal["none", "default", "mean", "median"] = "default", + summary_linewidth: Optional[float] = None, + summary_markersize: Optional[float] = None, + summary_marker: str = "default", + uncertainty_type: Literal[ + "none", "default", "trajectories", "quantiles" + ] = "default", + uncertainty_alpha: Optional[float] = None, + # Settings for plotting uncertainty + quantiles: Optional[Sequence[tuple[float, float]]] = None, + num_trajectories: Optional[int] = None, + trajectories_subsample_key: int = 42, + trajectory_linewidth: Optional[float] = None, + # Settings for plotting data + data_linewidth: Optional[float] = None, + data_markersize: Optional[float] = None, + data_marker: str = "default", + residuals: bool = False, + residuals_type: Literal[None, "mean", "median"] = None, +) -> None: + """Plots a summary statistic together with uncertainty. + + Args: + ax: axes on which the plot is done + x_axis: positions on the x_axis, shape (n_points,). + Leave as `None` to infer from the data. + y_data: summary statistic of the data, shape (n_points,), + where `n_points` is the dimensionality of the + summary statistic vector. Leave as `None` to not plot. + y_simulated: summary statistic corresponding to simulations, + shape (n_simulations, n_points) + """ + # Try to wrap all arraylike objects + x_axis, y_data, y_simulated = ( + _wrap_array(x_axis), # type: ignore + _wrap_array(y_data), # type: ignore + _wrap_array(y_simulated), # type: ignore + ) + + if x_axis is None: + if y_data is not None: + x_axis = np.arange(y_data.shape[0]) # type: ignore + elif y_simulated is not None: + x_axis = np.arange(y_simulated.shape[-1]) # type: ignore + else: + raise ValueError("No data to plot") + assert x_axis is not None + + n_points = x_axis.shape[0] + if y_data is not None: + if y_data.shape != (n_points,): + raise ValueError("Data has wrong shape.") + if y_simulated is not None: + if len(y_simulated.shape) != 2 or y_simulated.shape[-1] != n_points: + raise ValueError("Simulated data has wrong shape.") + + # Transform data + if residuals: + if y_simulated is None: + raise ValueError("For residual plot one has to provide simulated data.") + # Try to infer residuals_type from summary_type, if not provided + if residuals_type is None and summary_type in ["median", "mean"]: + residuals_type = summary_type # type: ignore + + if residuals_type is None: + raise ValueError("Residuals type could not be automatically inferred.") + elif residuals_type == "mean": + y_perfect = np.mean(y_simulated, axis=0) + elif residuals_type == "median": + y_perfect = np.mean(y_simulated, axis=0) + else: + raise ValueError(f"Residuals type {residuals_type} not known.") + + # Calculate the residuals + y_simulated = y_simulated - y_perfect[None, :] + if y_data is not None: + y_data = y_data - y_perfect + + # Plot simulated data + if y_simulated is not None: + # Start by plotting uncertainty + _plot_uncertainty( + ax=ax, + x_axis=x_axis, + y_simulated=y_simulated, + color_simulated=color_simulations, + uncertainty_type=uncertainty_type, + uncertainty_alpha=uncertainty_alpha, + quantiles=quantiles, + num_trajectories=num_trajectories, + trajectories_subsample_key=trajectories_subsample_key, + trajectory_linewidth=trajectory_linewidth, + ) + + # Now plot the main summary + _plot_main_summary( + ax=ax, + x_axis=x_axis, + y_simulated=y_simulated, + summary_type=summary_type, + summary_linewidth=summary_linewidth, + summary_markersize=summary_markersize, + summary_marker=summary_marker, + color=color_simulations, + ) + + if y_data is not None: + # Plot real data + _plot_data( + ax=ax, + x_axis=x_axis, + y=y_data, + data_linewidth=data_linewidth, + data_markersize=data_markersize, + data_marker=data_marker, + color=color_data, + ) diff --git a/src/jnotype/checks/_statistics.py b/src/jnotype/checks/_statistics.py index 15eb1d2..67f8994 100644 --- a/src/jnotype/checks/_statistics.py +++ b/src/jnotype/checks/_statistics.py @@ -1,6 +1,9 @@ """Convenient summary statistics, useful e.g. for posterior predictive checking.""" +from typing import Optional + +import jax import jax.numpy as jnp from jaxtyping import Array, Int, Float @@ -17,7 +20,7 @@ def calculate_mutation_frequencies(X: _DataSet) -> Float[Array, " n_genes"]: return jnp.mean(X, axis=0) -def calculate_number_of_mutations_histogram(X: _DataSet) -> Float[Array, " n_genes+1"]: +def calculate_number_of_mutations_histogram(X: _DataSet) -> Int[Array, " n_genes+1"]: """Creates an array counting the samples with a specific number of mutations. @@ -45,3 +48,133 @@ def calculate_mcc(X: _DataSet) -> Float[Array, "n_genes n_genes"]: (i.e., the mutation frequency is 0 or 1) """ return jnp.corrcoef(X, rowvar=False) + + +def _convert_binary_code_to_integer(x: Int[Array, " n_genes"]) -> Int[Array, " "]: + """A binary genotype represented as an array + is converted into the integer with corresponding + binary representation. + """ + n_genes = x.shape[0] + bit_positions = jnp.arange(n_genes - 1, -1, -1) + powers = jnp.power(2, bit_positions) + return jnp.sum(powers * x) + + +def convert_genotypes_to_integers(X: _DataSet) -> Int[Array, " n_samples"]: + """Each binary genotype is converted into the + integer with corresponding representation in + the binary numeral system. + """ + return jax.vmap(_convert_binary_code_to_integer)(X) + + +def convert_integers_to_genotypes( + integers: Int[Array, " n_samples"], n_genes: int +) -> _DataSet: + """Maps each integer to a binary genotype. + + Args: + Y (jnp.ndarray): Array of integers with shape (n_samples,). + n_genes (int): Number of genes (bits) in the genotype. + + Returns: + jnp.ndarray: Binary genotype array with shape (n_samples, n_genes). + """ + # Create an array of bit positions + bit_positions = jnp.arange(n_genes - 1, -1, -1) + # Right-shift and mask to get the bits + genotypes = (integers[:, None] >> bit_positions) & 1 + return genotypes + + +def calculate_atoms_occurrence(X: _DataSet) -> Int[Array, " 2**n_genes"]: + """For each unique genotype counts the number of matching samples. + + Note: + The returned array has exponentially large length, so that this + function should be avoided for large gene sets. + """ + indices = convert_genotypes_to_integers(X) + + n_genes = X.shape[1] + length = jnp.power(2, n_genes) + return jnp.bincount(indices, length=length) # type: ignore + + +def get_leading_axis_size(pytree) -> int: + """Infers the number of samples in a PyTree.""" + # Extract all leaf nodes from the PyTree + leaves = jax.tree_util.tree_leaves(pytree) + + if not leaves: + raise ValueError("The PyTree has no leaves.") + + # TODO(Pawel): Go through all the leaves and check + # if shapes agree + + # Assume the first leaf contains the leading axis + first_leaf = leaves[0] + + # Ensure the leaf has a shape attribute + if hasattr(first_leaf, "shape") and len(first_leaf.shape) > 0: + return first_leaf.shape[0] + else: + raise ValueError("The first leaf does not have a valid shape.") + + +def subsample_pytree( + key: jax.Array, + samples, + n_samples: Optional[int] = None, +): + """Subsamples a PyTree along the leading axis.""" + leading_size = get_leading_axis_size(samples) + + if n_samples is None: + n_samples = leading_size + + if n_samples > leading_size: + raise ValueError("n_samples cannot be larger than the leading axis size.") + + # Generate a permutation of indices and select the first n_samples + perm = jax.random.permutation(key, leading_size) + selected_indices = perm[:n_samples] + + def index_leaves(x): + """Function indexing each leaf""" + return x[selected_indices] + + # Apply the indexing function to all leaves + subsampled_pytree = jax.tree_util.tree_map(index_leaves, samples) + + return subsampled_pytree + + +def simulate_summary_statistic( + key: jax.Array, + simulator_fn, + statistic_fn, + samples, +): + """Simulates the summary statistics. + + Args: + key: JAX random key + simulator_fn: function with the signature + (RandomKey, Sample) -> DataSet + statistic_fn: function with the signature + Sample -> Statistic + samples: a PyTree with structure `Sample`, + which has a leading (0th) axis in each leaf + corresponding to the samples from the distribution + """ + n_samples = get_leading_axis_size(samples) + keys = jax.random.split(key, n_samples) + + def f(subkey, sample): + """Simulates a data set and calculates summary statistic.""" + y_sim = simulator_fn(subkey, sample) + return statistic_fn(y_sim) + + return jax.vmap(f)(keys, samples) diff --git a/tests/checks/test_statistics.py b/tests/checks/test_statistics.py new file mode 100644 index 0000000..365350e --- /dev/null +++ b/tests/checks/test_statistics.py @@ -0,0 +1,40 @@ +import jax +import jax.numpy as jnp +import jnotype.checks._statistics as st +import numpy.testing as npt + +import pytest + + +@pytest.mark.parametrize("n_genes", [3, 4]) +@pytest.mark.parametrize("n_samples", [2, 3, 5]) +def test_genotypes_integers_inverse1(n_genes: int, n_samples: int) -> None: + nums = jnp.arange(n_samples) + + genotypes = st.convert_integers_to_genotypes(nums, n_genes=n_genes) + nums_ = st.convert_genotypes_to_integers(genotypes) + npt.assert_allclose(nums, nums_) + + +@pytest.mark.parametrize("n_genes", [3, 4]) +@pytest.mark.parametrize("n_samples", [2, 3, 5]) +def test_genotypes_integers_inverse2(n_genes: int, n_samples: int) -> None: + key = jax.random.PRNGKey(n_genes * n_samples + 5) + genotypes = jax.random.bernoulli(key, p=0.5, shape=(n_samples, n_genes)) + + nums = st.convert_genotypes_to_integers(genotypes) + genotypes_ = st.convert_integers_to_genotypes(nums, n_genes=n_genes) + + npt.assert_allclose(genotypes, genotypes_) + + +@pytest.mark.parametrize("n_genes", [3, 4]) +def test_genotypes_integers_inverse3_boundary(n_genes: int, n_samples: int = 5) -> None: + genotypes1 = jnp.ones((n_samples, n_genes), dtype=int) + genotypes2 = jnp.zeros((n_samples, n_genes), dtype=int) + + nums1 = st.convert_genotypes_to_integers(genotypes1) + nums2 = st.convert_genotypes_to_integers(genotypes2) + + npt.assert_allclose(nums1, jnp.full(shape=(n_samples,), fill_value=2**n_genes - 1)) + npt.assert_allclose(nums2, jnp.zeros(shape=(n_samples,), dtype=int)) diff --git a/workflows/exclusivity/gbm_reproducibility.smk b/workflows/exclusivity/gbm_reproducibility.smk index 1139498..5264a24 100644 --- a/workflows/exclusivity/gbm_reproducibility.smk +++ b/workflows/exclusivity/gbm_reproducibility.smk @@ -27,8 +27,25 @@ from jnotype.checks import calculate_quantiles, calculate_mcc workdir: "generated/exclusivity/gbm_reproducibility" +# GENESETS = { +# "muex-0-3B": ['ABCC9', 'PIK3CA', 'RPL5', 'TRAT1'], +# } + + GENESETS = { + "misspecified": ["TP53", "CDKN2B", "NF1", "SPTA1"], + "muex-permutation-3A": ["EGFR", "GCSAML", "IDH1", "OTC"], "muex-0-3B": ['ABCC9', 'PIK3CA', 'RPL5', 'TRAT1'], + "muex-1-3C": ['PIK3C2G', 'PIK3CA', 'RPL5', 'TRAT1'], + "muex-2-3D": ['NF1', 'PIK3C2G', 'PIK3R1', 'TRAT1'], + "muex-3": ['ABCC9', 'PIK3C2G', 'PIK3CA', 'TRAT1'], + "muex-4": ['ABCC9', 'PIK3C2G', 'PIK3CA', 'SPTA1'], + "muex-5": ['ABCC9', 'KEL', 'PIK3C2G', 'PIK3CA'], + "muex-5": ['ABCC9', 'PIK3R1', 'RPL5', 'TRAT1'], + "muex-6": ['ABCC9', 'PIK3C2G', 'PIK3R1', 'TRAT1'], + "muex-7": ['PIK3C2G', 'PIK3R1', 'RPL5', 'TRAT1'], + "muex-8": ['ABCC9', 'PIK3C2G', 'PIK3CA', 'RPL5'], + "muex-9": ['ABCC9', 'KEL', 'PIK3C2G', 'RPL5'], } def get_prior_posterior_flag(name): @@ -100,7 +117,7 @@ rule plot_figure_comparison: stats = posterior["_component_independent"] color = "blue" - quantiles = calculate_quantiles(samples=stats, quantiles=[0.05, 0.25, 0.5, 0.75, 0.95]) + quantiles = calculate_quantiles(samples=stats, quantiles=np.array([0.05, 0.25, 0.5, 0.75, 0.95])) x_axis = jnp.arange(1, stats.shape[1] + 1) ax.set_ylim(0, quantiles.max() + 0.05)