diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py index 79a5858b..7769cf1d 100644 --- a/src/state/_cli/_tx/_predict.py +++ b/src/state/_cli/_tx/_predict.py @@ -233,8 +233,8 @@ def load_config(cfg_path: str) -> dict: all_pert_names = [] all_celltypes = [] all_gem_groups = [] - all_pert_barcodes = [] - all_ctrl_barcodes = [] + # all_pert_barcodes = [] + # all_ctrl_barcodes = [] with torch.no_grad(): for batch_idx, batch in enumerate(tqdm(test_loader, desc="Predicting", unit="batch")): @@ -251,12 +251,12 @@ def load_config(cfg_path: str) -> dict: else: all_pert_names.append(batch_preds["pert_name"]) - if isinstance(batch_preds["pert_cell_barcode"], list): - all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) - all_ctrl_barcodes.extend(batch_preds["ctrl_cell_barcode"]) - else: - all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) - all_ctrl_barcodes.append(batch_preds["ctrl_cell_barcode"]) + # if isinstance(batch_preds["pert_cell_barcode"], list): + # all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) + # all_ctrl_barcodes.extend(batch_preds["ctrl_cell_barcode"]) + # else: + # all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) + # all_ctrl_barcodes.append(batch_preds["ctrl_cell_barcode"]) # Handle celltype_name if isinstance(batch_preds["celltype_name"], list): @@ -297,8 +297,8 @@ def load_config(cfg_path: str) -> dict: data_module.pert_col: all_pert_names, data_module.cell_type_key: all_celltypes, data_module.batch_col: all_gem_groups, - "pert_cell_barcode": all_pert_barcodes, - "ctrl_cell_barcode": all_ctrl_barcodes, + # "pert_cell_barcode": all_pert_barcodes, + # "ctrl_cell_barcode": all_ctrl_barcodes, } ) diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py index a36c2167..6409ac23 100644 --- a/src/state/tx/models/base.py +++ b/src/state/tx/models/base.py @@ -154,6 +154,7 @@ def __init__( super().__init__() self.decoder_cfg = decoder_cfg self.save_hyperparameters() + self.gene_decoder_bool = kwargs.get("gene_decoder_bool", True) # Core architecture settings self.input_dim = input_dim @@ -193,6 +194,9 @@ def _build_networks(self): def _build_decoder(self): """Create self.gene_decoder from self.decoder_cfg (or leave None).""" + if self.gene_decoder_bool == False: + self.gene_decoder = None + return if self.decoder_cfg is None: self.gene_decoder = None return @@ -204,6 +208,9 @@ def on_load_checkpoint(self, checkpoint: dict[str, tp.Any]) -> None: Re-create the decoder using the exact hyper-parameters saved in the ckpt, so that parameter shapes match and load_state_dict succeeds. """ + if self.gene_decoder_bool == False: + self.gene_decoder = None + return if "decoder_cfg" in checkpoint["hyper_parameters"]: self.decoder_cfg = checkpoint["hyper_parameters"]["decoder_cfg"] self.gene_decoder = LatentToGeneDecoder(**self.decoder_cfg)