Skip to content

Commit

Permalink
Fix heatmap (#14)
Browse files Browse the repository at this point in the history
* Update plot_cpdb_heatmap.py

* Update test_plot_cpdb_heatmap.py

* Update plot_cpdb_heatmap.py

* Update plot_cpdb_heatmap.py

* Update tutorial.ipynb

* Update pyproject.toml

* Update tutorial.ipynb

* Update plot_cpdb_heatmap.py

* Update test_plot_cpdb_heatmap.py
  • Loading branch information
zktuong authored Dec 17, 2022
1 parent af1ff5a commit 692ad21
Show file tree
Hide file tree
Showing 4 changed files with 179 additions and 56 deletions.
202 changes: 151 additions & 51 deletions docs/notebooks/tutorial.ipynb

Large diffs are not rendered by default.

19 changes: 15 additions & 4 deletions ktplotspy/plot/plot_cpdb_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ def plot_cpdb_heatmap(
cmap: Optional[Union[str, ListedColormap]] = None,
title: str = "",
return_tables: bool = False,
symmetrical: bool = False,
**kwargs
) -> Union[sns.matrix.ClusterGrid, Dict]:
"""Plot cellphonedb results as total counts of interactions.
Expand Down Expand Up @@ -63,6 +64,8 @@ def plot_cpdb_heatmap(
Plot title.
return_tables : bool, optional
Whether to return the dataframes storing the interaction network.
symmetrical : bool, optional
Whether to return the sum of interactions as symmetrical heatmap.
**kwargs
Passed to seaborn.clustermap.
Expand Down Expand Up @@ -91,9 +94,17 @@ def plot_cpdb_heatmap(
count_mat = count_final.pivot_table(index="SOURCE", columns="TARGET", values="COUNT")
count_mat.columns.name, count_mat.index.name = None, None
count_mat[pd.isnull(count_mat)] = 0
all_sum = pd.DataFrame(count_mat.apply(sum, axis=0), columns=["total_interactions"]) + pd.DataFrame(
count_mat.apply(sum, axis=1), columns=["total_interactions"]
)
if symmetrical:
count_matx = np.triu(count_mat) + np.tril(count_mat.T) + np.tril(count_mat) + np.triu(count_mat.T)
count_matx = pd.DataFrame(count_matx)
count_matx.columns = count_mat.columns
count_matx.index = count_mat.index
count_mat = count_matx.copy()
all_sum = pd.DataFrame(count_mat.apply(sum, axis=0), columns=["total_interactions"])
else:
all_sum = pd.DataFrame(count_mat.apply(sum, axis=0), columns=["total_interactions"]) + pd.DataFrame(
count_mat.apply(sum, axis=1), columns=["total_interactions"]
)
if log1p_transform:
count_mat = np.log1p(count_mat)
if cmap is None:
Expand All @@ -114,5 +125,5 @@ def plot_cpdb_heatmap(
g.fig.suptitle(title)
return g
else:
out = {"count_network": count_mat, "interaction_count": all_sum}
out = {"count_network": count_mat, "interaction_count": all_sum, "interaction_edges": count_final}
return out
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
[tool.poetry]
name = "ktplotspy"
version = "0.1.5"
version = "0.1.6"
description = "Python library for plotting Cellphonedb results. Ported from ktplots R package."
authors = ["Kelvin Tuong <[email protected]>"]
license = "MIT"
Expand Down
12 changes: 12 additions & 0 deletions tests/test_plot_cpdb_heatmap.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,18 @@ def test_plot_cpdb_heatmap_log(mock_show, adata, pvals):
g


@patch("matplotlib.pyplot.show")
@pytest.mark.usefixtures("adata", "pvals")
def test_plot_cpdb_heatmap_sym(mock_show, adata, pvals):
g = plot_cpdb_heatmap(
adata=adata,
pvals=pvals,
celltype_key="celltype",
symmetrical=True,
)
g


@patch("matplotlib.pyplot.show")
@pytest.mark.usefixtures("adata", "pvals")
def test_plot_cpdb_heatmap_title(mock_show, adata, pvals):
Expand Down

0 comments on commit 692ad21

Please sign in to comment.