Skip to content

Commit

Permalink
mu.pl.embedding: save the color palette in .uns, like scanpy
Browse files Browse the repository at this point in the history
  • Loading branch information
ilia-kats committed May 8, 2023
1 parent 2548595 commit 19f98d2
Show file tree
Hide file tree
Showing 2 changed files with 15 additions and 4 deletions.
5 changes: 4 additions & 1 deletion muon/_atac/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -164,7 +164,10 @@ def embedding(
adata=adata, keys=keys, average=average, func=func, use_raw=use_raw, layer=layer
)
ad = AnnData(x, obs=adata.obs, obsm=adata.obsm)
return sc.pl.embedding(ad, basis=basis, color=attr_names, **kwargs)
retval = sc.pl.embedding(ad, basis=basis, color=attr_names, **kwargs)
for aname in attr_names:
adata.uns[f"{aname}_colors"] = ad.uns[f"{aname}_colors"]
return retval

else:
return sc.pl.embedding(adata, basis=basis, use_raw=use_raw, layer=layer, **kwargs)
Expand Down
14 changes: 11 additions & 3 deletions muon/_core/plot.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,7 @@ def scatter(
if isinstance(color, str):
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
color_obs = pd.DataFrame({color: color_obs})
color = [color]
else:
# scanpy#311 / scanpy#1497 has to be fixed for this to work
color_obs = _get_values(data, color, use_raw=use_raw, layer=layers[2])
Expand All @@ -83,7 +84,11 @@ def scatter(
# Note that use_raw and layers are not provided to the plotting function
# as the corresponding values were fetched from individual modalities
# and are now stored in .obs
return sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs)
retval = sc.pl.scatter(ad, x=x, y=y, color=color, **kwargs)
if color is not None:
for col in color:
data.uns[f"{col}_colors"] = ad.uns[f"{col}_colors"]
return retval


#
Expand Down Expand Up @@ -170,7 +175,7 @@ def embedding(

# Some `color` has been provided
if isinstance(color, str):
keys = [color]
keys = color = [color]
elif isinstance(color, Iterable):
keys = color
else:
Expand Down Expand Up @@ -252,7 +257,10 @@ def embedding(
color = [mod_key_modifier[k] for k in keys]

ad = AnnData(obs=obs, obsm=adata.obsm, obsp=adata.obsp, uns=adata.uns)
return sc.pl.embedding(ad, basis=basis_mod, color=color, **kwargs)
retval = sc.pl.embedding(ad, basis=basis_mod, color=color, **kwargs)
for key, col in zip(keys, color):
adata.uns[f"{key}_colors"] = ad.uns[f"{col}_colors"]
return retval


def mofa(mdata: MuData, **kwargs) -> Union[Axes, List[Axes], None]:
Expand Down

0 comments on commit 19f98d2

Please sign in to comment.