Skip to content

Commit

Permalink
minor changes
Browse files Browse the repository at this point in the history
  • Loading branch information
bio-la committed Apr 27, 2024
1 parent 3580f27 commit 87ead52
Show file tree
Hide file tree
Showing 8 changed files with 16 additions and 3 deletions.
5 changes: 4 additions & 1 deletion panpipes/python_scripts/batch_correct_scvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -149,14 +149,17 @@
)

scvi_model_args = {k: v for k, v in params['rna']['scvi']['model_args'].items() if v is not None}
print(scvi_model_args)
scvi_training_args = {k: v for k, v in params['rna']['scvi']['training_args'].items() if v is not None}
print(scvi_training_args)
scvi_training_plan = {k: v for k, v in params['rna']['scvi']['training_plan'].items() if v is not None}
print(scvi_training_plan)

L.info("Defining model")
vae = scvi.model.SCVI(rna, **scvi_model_args)
L.info("Running scVI")
vae.train(**scvi_training_args, plan_kwargs=scvi_training_plan)

L.info("Finished Training now saving model")
vae.save(os.path.join("batch_correction", "scvi_model"),
anndata=False)

Expand Down
5 changes: 5 additions & 0 deletions panpipes/python_scripts/batch_correct_totalvi.py
Original file line number Diff line number Diff line change
Expand Up @@ -219,6 +219,10 @@
else:
totalvi_training_plan = {k: v for k, v in params['multimodal']['totalvi']['training_plan'].items() if v is not None}

print(totalvi_model_args)
print(totalvi_training_args)
print(totalvi_training_plan)

L.info("Defining model")
vae = scvi.model.TOTALVI(rna, **totalvi_model_args)
L.info("Running totalVI")
Expand Down Expand Up @@ -252,6 +256,7 @@
mdata.obsm["X_totalVI"] = vae.get_latent_representation()

if batch_categories is not None:
L.debug(batch_categories)
if type(batch_categories) is not list:
batch_categories = [batch_categories]
normX, protein = vae.get_normalized_expression(
Expand Down
1 change: 1 addition & 0 deletions panpipes/python_scripts/batch_correct_wnn.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,6 +91,7 @@
else:
dict_graph[x]["obsm"] = None

L.info(dict_graph)

if dict_graph["rna"]["obsm"] == "X_scvi":
dict_graph["rna"]["obsm"] = "X_scVI"
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/make_mudataspatial_from_csv.py
Original file line number Diff line number Diff line change
Expand Up @@ -132,7 +132,7 @@ def check_dir_transform(infile_path, transform_file):

L.info("Resulting AnnData is:")
L.info(adata)
L.info("Creating MuData")
L.info("Creating MuData with .mod['spatial']")

mdata = MuData({"spatial": adata})

Expand Down
2 changes: 2 additions & 0 deletions panpipes/python_scripts/plot_custom_markers_umap.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def main(adata, mod, layer_choice, df, basis):
else:
# we have multimodal object
for mod in modalities:
print(mod)
df_sub = df[df['mod'] == mod]
mdata.update_obs()
try:
Expand All @@ -113,6 +114,7 @@ def main(adata, mod, layer_choice, df, basis):
bb = []
if len(bb) > 0 :
for basis, layer in product(bb, ll):
print(basis,layer)
main(adata=mdata[mod],
mod=mod,
layer_choice = layer,
Expand Down
1 change: 1 addition & 0 deletions panpipes/python_scripts/refmap_scvitools.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,6 +144,7 @@
max_epochs = 200
train_kwargs = {'weight_decay': 0.0}

print(train_kwargs)

if reference_architecture=="scvi":
L.info("Running scVI")
Expand Down
1 change: 1 addition & 0 deletions panpipes/python_scripts/run_filter_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,7 @@ def test_matching_df_ignore_cat(new_df, old_df):
# this will go through the modalities one at a time,
# then the categories max, min and bool
for mod in mdata.mod.keys():
L.info(mod)
if mod in filter_dict.keys():
for marg in filter_dict[mod].keys():
if marg == "obs":
Expand Down
2 changes: 1 addition & 1 deletion panpipes/python_scripts/run_scanpyQC_spatial.py
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@
sys.exit("The path of the cell cycle genes tsv file '%s' could not be found" % args.ccgenes)


# Aug 2023: we now need to update the mdata object to pick the calc proportion outputs made on
#TODO: we now need to update the mdata object to pick the calc proportion outputs made on
# spatial = mdata['spatial']

mdata.update()
Expand Down

0 comments on commit 87ead52

Please sign in to comment.