Skip to content

Commit

Permalink
Implement metadata unification process
Browse files Browse the repository at this point in the history
  • Loading branch information
nictru committed Mar 9, 2024
1 parent 28ca539 commit deb408f
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 18 deletions.
18 changes: 15 additions & 3 deletions src/app.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,13 +38,14 @@

def server(input, output, session: Session):
_adata: reactive.Value[ad.AnnData] = reactive.value(None)
_adata_meta: reactive.Value[ad.AnnData] = reactive.value(None)
_adata_filtered: reactive.Value[ad.AnnData] = reactive.value(None)
_file_name = reactive.value(None)
_distributions = reactive.value({})
_metadata = reactive.value(pd.DataFrame)

distributions_server("distributions", _adata, _pretty_names, _distributions)
slider_server("sliders", _adata, _adata_filtered, _pretty_names, _distributions)
distributions_server("distributions", _adata_meta, _pretty_names, _distributions)
slider_server("sliders", _adata_meta, _adata_filtered, _pretty_names, _distributions)
plots_server("plots", _adata_filtered, _pretty_names, _distributions)
metadata_server("metadata", _adata, _metadata)

Expand All @@ -60,9 +61,20 @@ def load_adata():
used_file = file[0]
_file_name.set(used_file["name"])
adata = sc.read_h5ad(used_file["datapath"])
calculate_qc_metrics(adata)
_adata.set(adata)

@reactive.effect
def update_adata_meta():
adata = _adata.get()
metadata = _metadata.get()
if adata is None or metadata is None:
return
adata_meta = adata.copy()
adata_meta.obs = metadata.copy()
calculate_qc_metrics(adata_meta)
print("Updating adata_meta")
_adata_meta.set(adata_meta)

@render.download(
filename = lambda: _file_name.get().replace(".h5ad", "_filtered.h5ad"),
)
Expand Down
57 changes: 47 additions & 10 deletions src/metadata.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,10 +6,10 @@

@module.ui
def metadata_ui():
return ui.accordion(
ui.accordion_panel("Columns",
return ui.navset_tab(
ui.nav_panel("Columns",
ui.output_ui("column_cards")),
ui.accordion_panel("Preview")
ui.nav_panel("Preview", ui.output_data_frame("metadata_table"))
)

@module.server
Expand Down Expand Up @@ -44,30 +44,41 @@ def update_columns():

def column_card(column: str):
select_accession = f"select_type_{column}"
type_string = input[select_accession].get() if select_accession in input else None
type_string = input[select_accession].get() if select_accession in input else "Constant value"
adata = _adata.get()

if adata is None:
print("No adata")
return None

if type_string == "Concat existing columns":
interface = ui.input_selectize(f"select_{column}", "Columns", _input_columns.get(), multiple=True)
select_concat_accession = f"select_concat_{column}"
included_columns = list(input[select_concat_accession].get()) if select_concat_accession in input else []
interface = ui.input_selectize(select_concat_accession, "Columns", _input_columns.get(), selected=included_columns, multiple=True)
elif type_string == "Map existing":
mapcol_accession = f"select_mapcol_{column}"
colselect = ui.input_select(mapcol_accession, "Column", _input_columns.get())
available_columns = _input_columns.get()
colselect = ui.input_select(mapcol_accession, "Column", available_columns)

series = adata.obs[input[mapcol_accession].get()] if mapcol_accession in input else None
series = adata.obs[input[mapcol_accession].get() if mapcol_accession in input else available_columns[0]]
unique_values = series.unique() if series is not None else []

interface = ui.div(
colselect,
ui.accordion(
ui.accordion_panel("Mapping",
*([ui.input_text(f"mapping_{column}_{col}", col, placeholder="New value") for col in series.unique()] if series is not None else [])
*[ui.input_text(f"mapping_{column}_{value}",
value,
input[f"mapping_{column}_{value}"].get() if f"mapping_{column}_{value}" in input else "",
placeholder="Unknown") for value in unique_values]
))
)
elif type_string == "Constant value":
value_accession = f"select_constant_{column}"
existing_value = input[value_accession].get() if value_accession in input else ""
interface = ui.input_text(value_accession, "Value", existing_value, placeholder="Unknown")
else:
interface = ui.input_text(f"input_{column}", "Value", placeholder="Unknown")
raise ValueError(f"Unknown type {type_string}")


return ui.card(
Expand All @@ -94,4 +105,30 @@ def column_cards():
add_card
)


@reactive.effect
def update_metadata():
adata = _adata.get()
if adata is None:
return
metadata = pd.DataFrame(index=adata.obs.index)
for column in _all_columns.get():
select_accession = f"select_type_{column}"
type_string = input[select_accession].get() if select_accession in input else None
if type_string == "Concat existing columns":
included_columns = list(input[f"select_concat_{column}"].get()) if f"select_concat_{column}" in input else []
metadata[column] = adata.obs[included_columns].astype(str).apply(lambda x: "_".join(x), axis=1) if included_columns else "Unknown"
elif type_string == "Map existing":
mapcol_accession = f"select_mapcol_{column}"
available_columns = _input_columns.get()
series = adata.obs[input[mapcol_accession].get() if mapcol_accession in input else available_columns[0]]

mapping = {value: input[f"mapping_{column}_{value}"].get() for value in series.unique()}
metadata[column] = series.map(mapping)
else:
constant_accession = f"select_constant_{column}"
metadata[column] = input[constant_accession].get() if constant_accession in input else "Unknown"
_metadata.set(metadata)

@render.data_frame
def metadata_table():
return _metadata.get()
6 changes: 1 addition & 5 deletions src/sliders.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,6 @@ def slider_server(input, output, session,
):
_adata_sample = reactive.value(None)
_prev_mads = reactive.value({})

@output
@render.text
def out():
return f"Click count is {_adata.get()}\n\nPretty names are {_pretty_names.get()}\n\nDistributions are {_distributions.get()}"

@output
@render.ui
Expand Down Expand Up @@ -60,6 +55,7 @@ def slider_filters():

for col, pretty_name in pretty_names.items():
if distributions[col]['min'] == distributions[col]['max']:
print(f"Skipping {col}")
continue
else:
mads = ui.input_slider(f"{col}_mads", "MADs", 0.25, 10, 2, step=0.25)
Expand Down

0 comments on commit deb408f

Please sign in to comment.