diff --git a/src/app.py b/src/app.py index 50bb9f2..5c1820c 100644 --- a/src/app.py +++ b/src/app.py @@ -3,10 +3,12 @@ import os import tempfile import anndata as ad +import pandas as pd from sliders import slider_ui, slider_server from distributions import distributions_server from plots import plots_server, plots_ui +from metadata import metadata_server, metadata_ui from helpers import calculate_qc_metrics @@ -19,7 +21,7 @@ app_ui = ui.page_navbar( ui.nav_panel("1. Upload", ui.input_file("file_input", label="Upload your file", accept=".h5ad")), - ui.nav_panel("2. Metadata", "Metadata"), + ui.nav_panel("2. Metadata", metadata_ui("metadata")), ui.nav_panel("3. Quality control", ui.layout_sidebar( ui.sidebar(slider_ui("sliders")), @@ -39,10 +41,12 @@ def server(input, output, session: Session): _adata_filtered: reactive.Value[ad.AnnData] = reactive.value(None) _file_name = reactive.value(None) _distributions = reactive.value({}) + _metadata = reactive.value(pd.DataFrame) distributions_server("distributions", _adata, _pretty_names, _distributions) slider_server("sliders", _adata, _adata_filtered, _pretty_names, _distributions) plots_server("plots", _adata_filtered, _pretty_names, _distributions) + metadata_server("metadata", _adata, _metadata) @reactive.effect def load_adata(): diff --git a/src/metadata.py b/src/metadata.py new file mode 100644 index 0000000..ee05ed7 --- /dev/null +++ b/src/metadata.py @@ -0,0 +1,48 @@ +from shiny import App, module, reactive, render, ui +import anndata as ad +import pandas as pd + +mandatory_columns = ["batch", "cell_type", "condition", "sex", "patient", "tissue"] + +@module.ui +def metadata_ui(): + return ui.div( + ui.output_ui("column_cards") + ) + +@module.server +def metadata_server(input, output, session, + _adata: reactive.Value[ad.AnnData], + _metadata: reactive.Value[pd.DataFrame] + ): + _additional_columns = reactive.value([]) + _all_columns = reactive.value(mandatory_columns) + + @reactive.effect + @reactive.event((input["add_column"])) + def add_column(): + column = input["column_name"].get() + if column and column not in _all_columns.get(): + _additional_columns.set(_additional_columns.get() + [column]) + ui.update_text("column_name", value="") + + @reactive.effect + def update_columns(): + _all_columns.set(mandatory_columns + _additional_columns.get()) + + @render.ui + def column_cards(): + add_card = ui.card( + ui.card_header("Add column"), + ui.input_text("column_name", "", placeholder="Column name"), + ui.input_action_button("add_column", "Add column") + ) + + return ui.layout_columns( + *[ui.card( + ui.card_header(column), + ui.p("This is a column") + ) + for column in _all_columns.get()], + add_card + )