diff --git a/docs/tutorials/model_training_and_eval.ipynb b/docs/tutorials/model_training_and_eval.ipynb index fc09a0a..14ebc35 100644 --- a/docs/tutorials/model_training_and_eval.ipynb +++ b/docs/tutorials/model_training_and_eval.ipynb @@ -2866,6 +2866,14 @@ }, "metadata": {}, "output_type": "display_data" + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "\u001b[1;34mwandb\u001b[0m: 🚀 View run \u001b[33mfinetuned_model\u001b[0m at: \u001b[34mhttps://wandb.ai/kemp/mouse_biccn/runs/it1js3u7\u001b[0m\n", + "\u001b[1;34mwandb\u001b[0m: Find logs at: \u001b[1;35mwandb/run-20241212_101919-it1js3u7/logs\u001b[0m\n" + ] } ], "source": [ diff --git a/src/crested/pl/hist/_locus_scoring.py b/src/crested/pl/hist/_locus_scoring.py index e7275eb..997d367 100644 --- a/src/crested/pl/hist/_locus_scoring.py +++ b/src/crested/pl/hist/_locus_scoring.py @@ -16,6 +16,9 @@ def locus_scoring( bigwig_values: np.ndarray | None = None, bigwig_midpoints: list[int] | None = None, filename: str | None = None, + grid: bool = True, + figsize: tuple[float, float] = (30,5), + highlight_positions: list[tuple[int, int]] | None = None, ): """ Plot the predictions as a line chart over the entire genomic input and optionally indicate the gene locus. @@ -44,6 +47,12 @@ def locus_scoring( A list of base pair positions corresponding to the bigwig_values. filename The filename to save the plot to. + grid + Add grid to plot. + figsize + Size of figure. + highlight_positions + A list of tuples specifying ranges to highlight on the plot. See Also -------- @@ -65,7 +74,7 @@ def locus_scoring( .. image:: ../../../../docs/_static/img/examples/hist_locus_scoring.png """ # Plotting predictions - plt.figure(figsize=(30, 10)) + plt.figure(figsize=figsize) # Top plot: Model predictions plt.subplot(2, 1, 1) @@ -78,13 +87,16 @@ def locus_scoring( label="Prediction Score", ) if gene_start is not None and gene_end is not None: - plt.axvspan(gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus") + plt.axvspan(gene_start, gene_end, color="red", alpha=0.2, label="Gene Locus") + if highlight_positions: + for start, end in highlight_positions: + plt.axvspan(start, end, color="green", alpha=0.3) plt.title(title) plt.xlabel("Genomic Position") plt.ylabel("Prediction Score") plt.ylim(bottom=0) plt.xticks(rotation=90) - plt.grid(True) + plt.grid(grid) plt.legend() if ylim: plt.ylim(ylim) @@ -101,13 +113,13 @@ def locus_scoring( ) if gene_start is not None and gene_end is not None: plt.axvspan( - gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus" + gene_start, gene_end, color="red", alpha=0.2, label="Gene Locus" ) plt.xlabel("Genomic Position") plt.ylabel("bigWig Values") plt.xticks(rotation=90) plt.ylim(bottom=0) - plt.grid(True) + plt.grid(grid) plt.legend() plt.tight_layout() diff --git a/src/crested/pl/scatter/_class_density.py b/src/crested/pl/scatter/_class_density.py index c17ad53..b6f1623 100644 --- a/src/crested/pl/scatter/_class_density.py +++ b/src/crested/pl/scatter/_class_density.py @@ -14,12 +14,13 @@ def class_density( adata: AnnData, - class_name: str, + class_name: str | None = None, model_names: list[str] | None = None, split: str | None = "test", log_transform: bool = False, exclude_zeros: bool = True, density_indication: bool = False, + alpha: float = 0.25, **kwargs, ) -> plt.Figure: """ @@ -30,7 +31,7 @@ def class_density( adata AnnData object containing the data in `X` and predictions in `layers`. class_name - Name of the class in `adata.obs_names`. + Name of the class in `adata.obs_names`. If None, plot is made for all the classes. model_names List of model names in `adata.layers`. If None, will create a plot per model in `adata.layers`. split @@ -38,9 +39,11 @@ def class_density( log_transform Whether to log-transform the data before plotting. Default is False. exclude_zeros - Whether to exclude zero values from the plot. Default is True. + Whether to exclude zero ground truth values from the plot. Default is True. density_indication Whether to indicate density in the scatter plot. Default is False. + alpha + Transparency of points in scatter plot. From 0 (transparent) to 1 (opaque). kwargs Additional arguments passed to :func:`~crested.pl.render_plot` to control the final plot output. Please see :func:`~crested.pl.render_plot` @@ -75,7 +78,7 @@ def _check_input_params(): "No split column found in anndata.var. Run `pp.train_val_test_split` first if 'split' is not None." ) - if class_name not in adata.obs_names: + if (class_name) and (class_name not in adata.obs_names): raise ValueError(f"Class {class_name} not found in adata.obs_names.") if split not in ["train", "val", "test", None]: raise ValueError("Split must be 'train', 'val', 'test', or None.") @@ -83,7 +86,7 @@ def _check_input_params(): _check_input_params() classes = list(adata.obs_names) - column_index = classes.index(class_name) + column_index = classes.index(class_name) if class_name else np.arange(0, len(classes)) if model_names is None: model_names = list(adata.layers.keys()) @@ -115,9 +118,14 @@ def _check_input_params(): n_models = len(predicted_values) - logger.info( - f"Plotting density scatter for class: {class_name}, models: {model_names}, split: {split}" - ) + if class_name: + logger.info( + f"Plotting density scatter for class: {class_name}, models: {model_names}, split: {split}" + ) + else: + logger.info( + f"Plotting density scatter for all targets and predictions, models: {model_names}, split: {split}" + ) fig, axes = plt.subplots(1, n_models, figsize=(8 * n_models, 8), sharey=True) if n_models == 1: @@ -130,10 +138,11 @@ def _check_input_params(): if density_indication: xy = np.vstack([x, y]) z = gaussian_kde(xy)(xy) - scatter = ax.scatter(x, y, c=z, s=50, edgecolor="k", alpha=0.25) + scatter = ax.scatter(x, y, c=z, s=50, edgecolor="k", alpha=alpha) + scatter.set_rasterized(True) # Rasterize only the scatter points plt.colorbar(scatter, ax=ax, label="Density") else: - scatter = ax.scatter(x, y, edgecolor="k", alpha=0.25) + scatter = ax.scatter(x, y, edgecolor="k", alpha=alpha) ax.annotate( f"Pearson: {pearson_corr:.2f}", @@ -164,6 +173,6 @@ def _check_input_params(): if "ylabel" not in kwargs: kwargs["ylabel"] = "Predictions" if "title" not in kwargs: - kwargs["title"] = f"{class_name}" + kwargs["title"] = f"{class_name}" if class_name else "Targets vs Predictions" return render_plot(fig, **kwargs) diff --git a/src/crested/tl/modisco/_tfmodisco.py b/src/crested/tl/modisco/_tfmodisco.py index cc57124..4952b18 100644 --- a/src/crested/tl/modisco/_tfmodisco.py +++ b/src/crested/tl/modisco/_tfmodisco.py @@ -545,7 +545,7 @@ def merge_patterns(pattern1: dict, pattern2: dict) -> dict: if ic_a > ic_b else pattern2["classes"][cell_type] ) - merged_classes[cell_type]['n_seqlets'] = max(n_seqlets_a, n_seqlets_b) # if patterns from the same class get merged, we keep the max seqlet count between the two of them since they are the same pattern + merged_classes[cell_type]['n_seqlets'] = n_seqlets_a + n_seqlets_b else: merged_classes[cell_type] = pattern1["classes"][cell_type]