From 19f98d2ef80463b7e691b5b69030158838ec216b Mon Sep 17 00:00:00 2001 From: Ilia Kats Date: Mon, 8 May 2023 12:01:24 +0200 Subject: [PATCH] mu.pl.embedding: save the color palette in .uns, like scanpy --- muon/_atac/plot.py | 5 ++++- muon/_core/plot.py | 14 +++++++++++--- 2 files changed, 15 insertions(+), 4 deletions(-) diff --git a/muon/_atac/plot.py b/muon/_atac/plot.py index f0d4e9f..e52c19e 100644 --- a/muon/_atac/plot.py +++ b/muon/_atac/plot.py @@ -164,7 +164,10 @@ def embedding( adata=adata, keys=keys, average=average, func=func, use_raw=use_raw, layer=layer ) ad = AnnData(x, obs=adata.obs, obsm=adata.obsm) - return sc.pl.embedding(ad, basis=basis, color=attr_names, **kwargs) + retval = sc.pl.embedding(ad, basis=basis, color=attr_names, **kwargs) + for aname in attr_names: + adata.uns[f"{aname}_colors"] = ad.uns[f"{aname}_colors"] + return retval else: return sc.pl.embedding(adata, basis=basis, use_raw=use_raw, layer=layer, **kwargs) diff --git a/muon/_core/plot.py b/muon/_core/plot.py index f347b65..c6d413a 100644 --- a/muon/_core/plot.py +++ b/muon/_core/plot.py @@ -72,6 +72,7 @@ def scatter( if isinstance(color, str): color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) color_obs = pd.DataFrame({color: color_obs}) + color = [color] else: # scanpy#311 / scanpy#1497 has to be fixed for this to work color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2]) @@ -83,7 +84,11 @@ def scatter( # Note that use_raw and layers are not provided to the plotting function # as the corresponding values were fetched from individual modalities # and are now stored in .obs - return sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs) + retval = sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs) + if color is not None: + for col in color: + data.uns[f"{col}_colors"] = ad.uns[f"{col}_colors"] + return retval # @@ -170,7 +175,7 @@ def embedding( # Some `color` has been provided if isinstance(color, str): - keys = [color] + keys = color = [color] elif isinstance(color, Iterable): keys = color else: @@ -252,7 +257,10 @@ def embedding( color = [mod_key_modifier[k] for k in keys] ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp, uns=adata.uns) - return sc.pl.embedding(ad, basis=basis_mod, color=color, **kwargs) + retval = sc.pl.embedding(ad, basis=basis_mod, color=color, **kwargs) + for key, col in zip(keys, color): + adata.uns[f"{key}_colors"] = ad.uns[f"{col}_colors"] + return retval def mofa(mdata: MuData, **kwargs) -> Union[Axes, List[Axes], None]: