Skip to content

Commit

Permalink
fix: remove empty subplots and add additional params
Browse files Browse the repository at this point in the history
  • Loading branch information
matq007 committed Sep 27, 2024
1 parent 0ebd757 commit 81349a7
Showing 1 changed file with 45 additions and 10 deletions.
55 changes: 45 additions & 10 deletions src/scanvi_explainer/plots.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import numpy as np
import pandas as pd
import seaborn as sns
from matplotlib.figure import Figure
from scvi import REGISTRY_KEYS

from .scanvi_deep import SCANVIDeep
Expand All @@ -12,7 +13,11 @@ def feature_plot(
shap_values: np.ndarray,
subset: bool = False,
top_n: int = 10,
) -> None:
gene_symbols: None | str = None,
n_cols: int = 2,
figsize: tuple[int, int] = (20, 20),
return_fig: bool = False,
) -> Figure | None:
"""Prints feature contribution (absolute mean SHAP value) for each cell type (top 10).
Parameters
Expand All @@ -26,14 +31,33 @@ def feature_plot(
particual classifier.
When set to false, be generic and return contributing features even when testing set has
different cell types.
top_n: int
Subset for top N number of features
gene_symbols: None | str = None
Column name in `var` for gene symbols
n_cols: int
Number of columns in Figure
figsize : tuple[int, int]
Figure size, by default [20, 20]
return_fig : bool
Flag to return figure object, by default False
"""

if gene_symbols and gene_symbols not in explainer.adata.var.columns:
raise ValueError(
"Specified gene_symbol not present in the 'var' of model's adata!"
)

groupby = explainer.labels_key
classes = explainer.adata.obs[groupby].cat.categories
features = explainer.adata.var_names
features = (
explainer.adata.var[gene_symbols].values
if gene_symbols
else explainer.adata.var_names
)

nrows = classes.size // 2 + classes.size % 2
fig, ax = plt.subplots(nrows, 2, sharex=False, figsize=[20, 40])
nrows = round(classes.size / n_cols)
fig, ax = plt.subplots(nrows, n_cols, sharex=False, figsize=figsize)

for idx, ct in enumerate(classes):
shaps = pd.DataFrame(shap_values[idx], columns=features)
Expand All @@ -60,7 +84,7 @@ def feature_plot(
)

avg = pd.concat([positive, negative])
title = f"Average SHAP value importance for: {ct}"
title = f"Mean(|SHAP value|) importance for: {ct}"

else:
avg = (
Expand All @@ -72,16 +96,27 @@ def feature_plot(
.query("weight > 0")
.head(10)
)
title = f"Mean(|SHAP value|) average importance for: {ct}"
title = f"Mean(|SHAP value|) importance for: {ct}"

sns.barplot(
x="weight",
y="feature",
hue="contribution",
palette=["red", "blue"],
data=avg,
ax=ax[idx // 2, idx % 2],
ax=ax[idx // n_cols, idx % n_cols],
)
ax[idx // 2, idx % 2].set_title(title)
ax[idx // 2, idx % 2].legend(title="Contribution", loc="lower right")
fig.tight_layout()

ax[idx // n_cols, idx % n_cols].set_title(title)
ax[idx // n_cols, idx % n_cols].legend(title="Contribution", loc="lower right")

# clean axes which are empty
# from: https://stackoverflow.com/a/76269136
_ = [fig.delaxes(ax_) for ax_ in ax.flatten() if not ax_.has_data()]

fig.tight_layout()

if return_fig:
return fig

return None

0 comments on commit 81349a7

Please sign in to comment.