diff --git a/src/plots.py b/src/plots.py index b2068c7..22609be 100644 --- a/src/plots.py +++ b/src/plots.py @@ -13,8 +13,9 @@ def plots_ui(): ui.output_plot("plot_scatter") ), ui.card( - ui.card_header("Violin plots"), - ui.output_plot("plot_violines") + ui.card_header("Histograms"), + ui.output_ui("coloring_histograms"), + ui.output_plot("plot_histograms") ) ) @@ -56,13 +57,26 @@ def plot_scatter(): cbar.set_label(pretty_names[color_col]) return fig + + @output + @render.ui + def coloring_histograms(): + adata = _adata.get() + + if adata is None: + return + + categorical_obs = [None] + [column for column in adata.obs.select_dtypes(include=["object", "category"]).columns] + + return ui.input_select("histo_coloring", "Coloring", categorical_obs) @output @render.plot - def plot_violines(): + def plot_histograms(): adata = _adata.get() pretty_names = _pretty_names.get() distributions = _distributions.get() + coloring = input["histo_coloring"].get() if adata is None: return @@ -75,9 +89,17 @@ def plot_violines(): for i, (col, pretty_name) in enumerate(pretty_names.items()): ax = axes[i // n_cols, i % n_cols] if n_rows > 1 else axes[i] - sns.violinplot(x=adata.obs[col], ax=ax) + + kwargs = { + "x": col, + "ax": ax, + "bins": 50 + } + if coloring: + kwargs["hue"] = coloring + + sns.histplot(adata.obs, **kwargs) ax.set_xlabel(pretty_name) - ax.set_ylabel('Density') current_distribution = distributions[col]