Skip to content

Commit

Permalink
plotting updates
Browse files Browse the repository at this point in the history
  • Loading branch information
nkempynck committed Dec 13, 2024
1 parent 4cced02 commit 5023544
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 17 deletions.
8 changes: 8 additions & 0 deletions docs/tutorials/model_training_and_eval.ipynb
Original file line number Diff line number Diff line change
Expand Up @@ -2866,6 +2866,14 @@
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"\u001b[1;34mwandb\u001b[0m: 🚀 View run \u001b[33mfinetuned_model\u001b[0m at: \u001b[34mhttps://wandb.ai/kemp/mouse_biccn/runs/it1js3u7\u001b[0m\n",
"\u001b[1;34mwandb\u001b[0m: Find logs at: \u001b[1;35mwandb/run-20241212_101919-it1js3u7/logs\u001b[0m\n"
]
}
],
"source": [
Expand Down
22 changes: 17 additions & 5 deletions src/crested/pl/hist/_locus_scoring.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,9 @@ def locus_scoring(
bigwig_values: np.ndarray | None = None,
bigwig_midpoints: list[int] | None = None,
filename: str | None = None,
grid: bool = True,
figsize: tuple[float, float] = (30,5),
highlight_positions: list[tuple[int, int]] | None = None,
):
"""
Plot the predictions as a line chart over the entire genomic input and optionally indicate the gene locus.
Expand Down Expand Up @@ -44,6 +47,12 @@ def locus_scoring(
A list of base pair positions corresponding to the bigwig_values.
filename
The filename to save the plot to.
grid
Add grid to plot.
figsize
Size of figure.
highlight_positions
A list of tuples specifying ranges to highlight on the plot.
See Also
--------
Expand All @@ -65,7 +74,7 @@ def locus_scoring(
.. image:: ../../../../docs/_static/img/examples/hist_locus_scoring.png
"""
# Plotting predictions
plt.figure(figsize=(30, 10))
plt.figure(figsize=figsize)

# Top plot: Model predictions
plt.subplot(2, 1, 1)
Expand All @@ -78,13 +87,16 @@ def locus_scoring(
label="Prediction Score",
)
if gene_start is not None and gene_end is not None:
plt.axvspan(gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus")
plt.axvspan(gene_start, gene_end, color="red", alpha=0.2, label="Gene Locus")
if highlight_positions:
for start, end in highlight_positions:
plt.axvspan(start, end, color="green", alpha=0.3)
plt.title(title)
plt.xlabel("Genomic Position")
plt.ylabel("Prediction Score")
plt.ylim(bottom=0)
plt.xticks(rotation=90)
plt.grid(True)
plt.grid(grid)
plt.legend()
if ylim:
plt.ylim(ylim)
Expand All @@ -101,13 +113,13 @@ def locus_scoring(
)
if gene_start is not None and gene_end is not None:
plt.axvspan(
gene_start, gene_end, color="red", alpha=0.3, label="Gene Locus"
gene_start, gene_end, color="red", alpha=0.2, label="Gene Locus"
)
plt.xlabel("Genomic Position")
plt.ylabel("bigWig Values")
plt.xticks(rotation=90)
plt.ylim(bottom=0)
plt.grid(True)
plt.grid(grid)
plt.legend()

plt.tight_layout()
Expand Down
31 changes: 20 additions & 11 deletions src/crested/pl/scatter/_class_density.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,12 +14,13 @@

def class_density(
adata: AnnData,
class_name: str,
class_name: str | None = None,
model_names: list[str] | None = None,
split: str | None = "test",
log_transform: bool = False,
exclude_zeros: bool = True,
density_indication: bool = False,
alpha: float = 0.25,
**kwargs,
) -> plt.Figure:
"""
Expand All @@ -30,17 +31,19 @@ def class_density(
adata
AnnData object containing the data in `X` and predictions in `layers`.
class_name
Name of the class in `adata.obs_names`.
Name of the class in `adata.obs_names`. If None, plot is made for all the classes.
model_names
List of model names in `adata.layers`. If None, will create a plot per model in `adata.layers`.
split
'train', 'val', 'test' subset or None. If None, will use all targets. If not None, expects a "split" column in adata.var.
log_transform
Whether to log-transform the data before plotting. Default is False.
exclude_zeros
Whether to exclude zero values from the plot. Default is True.
Whether to exclude zero ground truth values from the plot. Default is True.
density_indication
Whether to indicate density in the scatter plot. Default is False.
alpha
Transparency of points in scatter plot. From 0 (transparent) to 1 (opaque).
kwargs
Additional arguments passed to :func:`~crested.pl.render_plot` to
control the final plot output. Please see :func:`~crested.pl.render_plot`
Expand Down Expand Up @@ -75,15 +78,15 @@ def _check_input_params():
"No split column found in anndata.var. Run `pp.train_val_test_split` first if 'split' is not None."
)

if class_name not in adata.obs_names:
if (class_name) and (class_name not in adata.obs_names):
raise ValueError(f"Class {class_name} not found in adata.obs_names.")
if split not in ["train", "val", "test", None]:
raise ValueError("Split must be 'train', 'val', 'test', or None.")

_check_input_params()

classes = list(adata.obs_names)
column_index = classes.index(class_name)
column_index = classes.index(class_name) if class_name else np.arange(0, len(classes))
if model_names is None:
model_names = list(adata.layers.keys())

Expand Down Expand Up @@ -115,9 +118,14 @@ def _check_input_params():

n_models = len(predicted_values)

logger.info(
f"Plotting density scatter for class: {class_name}, models: {model_names}, split: {split}"
)
if class_name:
logger.info(
f"Plotting density scatter for class: {class_name}, models: {model_names}, split: {split}"
)
else:
logger.info(
f"Plotting density scatter for all targets and predictions, models: {model_names}, split: {split}"
)

fig, axes = plt.subplots(1, n_models, figsize=(8 * n_models, 8), sharey=True)
if n_models == 1:
Expand All @@ -130,10 +138,11 @@ def _check_input_params():
if density_indication:
xy = np.vstack([x, y])
z = gaussian_kde(xy)(xy)
scatter = ax.scatter(x, y, c=z, s=50, edgecolor="k", alpha=0.25)
scatter = ax.scatter(x, y, c=z, s=50, edgecolor="k", alpha=alpha)
scatter.set_rasterized(True) # Rasterize only the scatter points
plt.colorbar(scatter, ax=ax, label="Density")
else:
scatter = ax.scatter(x, y, edgecolor="k", alpha=0.25)
scatter = ax.scatter(x, y, edgecolor="k", alpha=alpha)

ax.annotate(
f"Pearson: {pearson_corr:.2f}",
Expand Down Expand Up @@ -164,6 +173,6 @@ def _check_input_params():
if "ylabel" not in kwargs:
kwargs["ylabel"] = "Predictions"
if "title" not in kwargs:
kwargs["title"] = f"{class_name}"
kwargs["title"] = f"{class_name}" if class_name else "Targets vs Predictions"

return render_plot(fig, **kwargs)
2 changes: 1 addition & 1 deletion src/crested/tl/modisco/_tfmodisco.py
Original file line number Diff line number Diff line change
Expand Up @@ -545,7 +545,7 @@ def merge_patterns(pattern1: dict, pattern2: dict) -> dict:
if ic_a > ic_b
else pattern2["classes"][cell_type]
)
merged_classes[cell_type]['n_seqlets'] = max(n_seqlets_a, n_seqlets_b) # if patterns from the same class get merged, we keep the max seqlet count between the two of them since they are the same pattern
merged_classes[cell_type]['n_seqlets'] = n_seqlets_a + n_seqlets_b
else:
merged_classes[cell_type] = pattern1["classes"][cell_type]

Expand Down

0 comments on commit 5023544

Please sign in to comment.