From e82e987b106797e9ef486e4a456c992e24058e3a Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Fri, 19 Sep 2025 19:24:40 +0000 Subject: [PATCH 1/9] heatmap added --- src/state/__main__.py | 4 + src/state/_cli/__init__.py | 2 + src/state/_cli/_tx/__init__.py | 3 + src/state/_cli/_tx/_heatmap.py | 1003 ++++++++++++++++++++++++++++++++ 4 files changed, 1012 insertions(+) create mode 100644 src/state/_cli/_tx/_heatmap.py diff --git a/src/state/__main__.py b/src/state/__main__.py index 0a7f9236..45a26993 100644 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -11,6 +11,7 @@ run_emb_query, run_emb_preprocess, run_emb_eval, + run_tx_heatmap, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -121,6 +122,9 @@ def main(): case "predict": # For now, predict uses argparse and not hydra run_tx_predict(args) + case "heatmap": + # Run heatmap analysis using argparse + run_tx_heatmap(args) case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index 2507d565..da4fc456 100644 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,6 +1,7 @@ from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval from ._tx import ( add_arguments_tx, + run_tx_heatmap, run_tx_infer, run_tx_predict, run_tx_preprocess_infer, @@ -13,6 +14,7 @@ "add_arguments_tx", "run_tx_train", "run_tx_predict", + "run_tx_heatmap", "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 975fba42..e59f2c70 100644 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -1,5 +1,6 @@ import argparse as ap +from ._heatmap import add_arguments_heatmap, run_tx_heatmap from ._infer import add_arguments_infer, run_tx_infer from ._predict import add_arguments_predict, run_tx_predict from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer @@ -9,6 +10,7 @@ __all__ = [ "run_tx_train", "run_tx_predict", + "run_tx_heatmap", "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", @@ -21,6 +23,7 @@ def add_arguments_tx(parser: ap.ArgumentParser): subparsers = parser.add_subparsers(required=True, dest="subcommand") add_arguments_train(subparsers.add_parser("train", add_help=False)) add_arguments_predict(subparsers.add_parser("predict")) + add_arguments_heatmap(subparsers.add_parser("heatmap")) add_arguments_infer(subparsers.add_parser("infer")) add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) diff --git a/src/state/_cli/_tx/_heatmap.py b/src/state/_cli/_tx/_heatmap.py new file mode 100644 index 00000000..1c71623d --- /dev/null +++ b/src/state/_cli/_tx/_heatmap.py @@ -0,0 +1,1003 @@ +import argparse as ap + + +def add_arguments_heatmap(parser: ap.ArgumentParser): + """ + CLI for pathway heatmap analysis with GO MF pathway upregulation. + """ + + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Path to the output_dir containing the config.yaml file that was saved during training.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default="last.ckpt", + help="Checkpoint filename. Default is 'last.ckpt'. 
Relative to the output directory.", + ) + + parser.add_argument( + "--test-time-finetune", + type=int, + default=0, + help="If >0, run test-time fine-tuning for the specified number of epochs on only control cells.", + ) + + parser.add_argument( + "--profile", + type=str, + default="full", + choices=["full", "minimal", "de", "anndata"], + help="run all metrics, minimal, only de metrics, or only output adatas", + ) + + parser.add_argument( + "--predict-only", + action="store_true", + help="If set, only run prediction without evaluation metrics.", + ) + + parser.add_argument( + "--shared-only", + action="store_true", + help=("If set, restrict predictions/evaluation to perturbations shared between train and test (train ∩ test)."), + ) + + parser.add_argument( + "--eval-train-data", + action="store_true", + help="If set, evaluate the model on the training data rather than on the test data.", + ) + + # Optional: apply directional shift on a chosen index using control distributions + parser.add_argument( + "--shift-index", + type=int, + default=None, + help="If set, apply a ±2σ shift to this index across core_cells using control distributions.", + ) + parser.add_argument( + "--shift-direction", + type=str, + default=None, + choices=["up", "down"], + help="Direction for the 2σ shift applied to --shift-index. Requires --shift-index.", + ) + + parser.add_argument( + "--test-time-heat-map", + action="store_true", + help="If set, run test-time heat map analysis with position upregulation.", + ) + parser.add_argument( + "--heatmap-output-path", + type=str, + default=None, + help="Path to save the matplotlib heatmap visualization. If not provided, defaults to /position_upregulation_heatmap.png", + ) + parser.add_argument( + "--results-dir", + type=str, + default=None, + help="Directory to save results. If not provided, defaults to /eval_", + ) + parser.add_argument( + "--annotation-path", + type=str, + default="/home/dhruvgautam/gene_annotations_1_2000.pkl", + help="Path to the hvg gene annotations file.", + ) + + +def run_tx_heatmap(args: ap.ArgumentParser): + import logging + import os + import sys + + import anndata + import lightning.pytorch as pl + import numpy as np + import pandas as pd + import torch + import yaml + import json + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') # Use non-interactive backend + + # Cell-eval for metrics computation + from cell_eval import MetricsEvaluator + from cell_eval.utils import split_anndata_on_celltype + from cell_load.data_modules import PerturbationDataModule + from tqdm import tqdm + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + torch.multiprocessing.set_sharing_strategy("file_system") + + def run_test_time_finetune(model, dataloader, ft_epochs, control_pert, device): + """ + Perform test-time fine-tuning on only control cells. 
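+        Only batches whose first ``pert_name`` entry equals ``control_pert`` are used;
+        all other batches are skipped. Training uses a plain Adam optimizer (lr=1e-5)
+        for ``ft_epochs`` passes over ``dataloader``, and the model is returned to
+        eval mode when finished.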
+ """ + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + logger.info(f"Starting test-time fine-tuning for {ft_epochs} epoch(s) on control cells only.") + for epoch in range(ft_epochs): + epoch_losses = [] + pbar = tqdm(dataloader, desc=f"Finetune epoch {epoch + 1}/{ft_epochs}", leave=True) + for batch in pbar: + # Check if this batch contains control cells + first_pert = ( + batch["pert_name"][0] if isinstance(batch["pert_name"], list) else batch["pert_name"][0].item() + ) + if first_pert != control_pert: + continue + + # Move batch data to device + batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()} + + optimizer.zero_grad() + loss = model.training_step(batch, batch_idx=0, padded=False) + if loss is None: + continue + loss.backward() + optimizer.step() + epoch_losses.append(loss.item()) + pbar.set_postfix(loss=f"{loss.item():.4f}") + + mean_loss = np.mean(epoch_losses) if epoch_losses else float("nan") + logger.info(f"Finetune epoch {epoch + 1}/{ft_epochs}, mean loss: {mean_loss}") + model.eval() + + def load_config(cfg_path: str) -> dict: + """Load config from the YAML file that was dumped during training.""" + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + return cfg + + # 1. Load the config + config_path = os.path.join(args.output_dir, "config.yaml") + cfg = load_config(config_path) + logger.info(f"Loaded config from {config_path}") + + # 2. Find run output directory & load data module + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}?") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + + # Seed everything + pl.seed_everything(cfg["training"]["train_seed"]) + + # 3. Load the trained model + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + checkpoint_path = os.path.join(checkpoint_dir, args.checkpoint) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Could not find checkpoint at {checkpoint_path}.\nSpecify a correct checkpoint filename with --checkpoint." 
+ ) + logger.info("Loading model from %s", checkpoint_path) + + # Determine model class and load + model_class_name = cfg["model"]["name"] + model_kwargs = cfg["model"]["kwargs"] + + # Import the correct model class + if model_class_name.lower() == "embedsum": + from ...tx.models.embed_sum import EmbedSumPerturbationModel + + ModelClass = EmbedSumPerturbationModel + elif model_class_name.lower() == "old_neuralot": + from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel + + ModelClass = OldNeuralOTPerturbationModel + elif model_class_name.lower() in ["neuralot", "pertsets", "state"]: + from ...tx.models.state_transition import StateTransitionPerturbationModel + + ModelClass = StateTransitionPerturbationModel + + elif model_class_name.lower() in ["globalsimplesum", "perturb_mean"]: + from ...tx.models.perturb_mean import PerturbMeanPerturbationModel + + ModelClass = PerturbMeanPerturbationModel + elif model_class_name.lower() in ["celltypemean", "context_mean"]: + from ...tx.models.context_mean import ContextMeanPerturbationModel + + ModelClass = ContextMeanPerturbationModel + elif model_class_name.lower() == "decoder_only": + from ...tx.models.decoder_only import DecoderOnlyPerturbationModel + + ModelClass = DecoderOnlyPerturbationModel + else: + raise ValueError(f"Unknown model class: {model_class_name}") + + var_dims = data_module.get_var_dims() + model_init_kwargs = { + "input_dim": var_dims["input_dim"], + "hidden_dim": model_kwargs["hidden_dim"], + "gene_dim": var_dims["gene_dim"], + "hvg_dim": var_dims["hvg_dim"], + "output_dim": var_dims["output_dim"], + "pert_dim": var_dims["pert_dim"], + **model_kwargs, + } + + model = ModelClass.load_from_checkpoint(checkpoint_path, **model_init_kwargs) + model.eval() + logger.info("Model loaded successfully.") + + # 4. Test-time fine-tuning if requested + data_module.batch_size = 1 + if args.test_time_finetune > 0: + control_pert = data_module.get_control_pert() + if args.eval_train_data: + test_loader = data_module.train_dataloader(test=True) + else: + test_loader = data_module.test_dataloader() + + run_test_time_finetune( + model, test_loader, args.test_time_finetune, control_pert, device=next(model.parameters()).device + ) + logger.info("Test-time fine-tuning complete.") + + # 5. Run inference on test set + data_module.setup(stage="test") + if args.eval_train_data: + scan_loader = data_module.train_dataloader(test=True) + else: + scan_loader = data_module.test_dataloader() + + if scan_loader is None: + logger.warning("No test dataloader found. 
Exiting.") + sys.exit(0) + + logger.info("Preparing a fixed batch of 64 control cells (core_cells) and enumerating perturbations...") + + # Helper to normalize values to python lists + def _to_list(value): + if isinstance(value, list): + return value + if isinstance(value, torch.Tensor): + try: + return [x.item() if x.dim() == 0 else x for x in value] + except Exception: + return value.tolist() + return [value] + + control_pert = data_module.get_control_pert() + + # Collect unique perturbation names from the loader without running the model + unique_perts = [] + seen_perts = set() + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + for n in names: + if isinstance(n, torch.Tensor): + try: + n = n.item() + except Exception: + n = str(n) + if n not in seen_perts: + seen_perts.add(n) + unique_perts.append(n) + + if control_pert in seen_perts: + logger.info(f"Found {len(unique_perts)} total perturbations (including control '{control_pert}').") + else: + logger.warning("Control perturbation not observed in test loader perturbation names.") + + # Build a single fixed batch of exactly 64 control cells + target_core_n = 64 + core_cells = None + accum = {} + + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + # Iterate again to collect control cells only + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + # Build a mask for control entries when possible + mask = None + if len(names) > 0: + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + else: + # If no names provided in batch, skip (cannot verify control) + continue + + # Slice each tensor field by mask and accumulate until we have 64 + current_count = 0 if "_count" not in accum else accum["_count"] + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + # Identify keys to carry forward; prefer tensors and essential metadata + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + try: + vsel = v[mask][:take].detach().clone() + except Exception: + # fallback: try first dimension slice + vsel = v[:take].detach().clone() + _append_field(accum, k, vsel) + else: + # For non-tensor fields, convert to list and slice by mask when possible + vals = _to_list(v) + try: + selected_vals = [vals[i] for i, m in enumerate(mask.tolist()) if m][:take] + except Exception: + selected_vals = vals[:take] + _append_field(accum, k, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError(f"Could not assemble {target_core_n} control cells for core_cells; gathered {accum.get('_count', 0)}.") + + # Collate accumulated pieces into a single batch dict of length 64 + core_cells = {} + for k, parts in accum.items(): + if k == "_count": + continue + if len(parts) == 1: + val = parts[0] + else: + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for p in parts: + merged.extend(_to_list(p)) + val = merged + # Ensure final length == 64 + if isinstance(val, torch.Tensor): + core_cells[k] = val[:target_core_n] + else: + core_cells[k] = _to_list(val)[:target_core_n] + + logger.info(f"Constructed core_cells batch with size {target_core_n}.") + + # Compute distributions for each position across ALL control cells in the test loader + # Strategy: determine a 2D vector key from the first batch, then 
aggregate all control rows + vector_key_candidates = ["ctrl_cell_emb", "pert_cell_emb", "X"] + dist_source_key = None + # Find key by peeking one batch + for b in scan_loader: + for cand in vector_key_candidates: + if cand in b and isinstance(b[cand], torch.Tensor) and b[cand].dim() == 2: + dist_source_key = cand + break + if dist_source_key is None: + # fallback: any 2D tensor + for k, v in b.items(): + if isinstance(v, torch.Tensor) and v.dim() == 2: + dist_source_key = k + break + # break after first batch inspected + break + if dist_source_key is None: + raise RuntimeError("Could not find a 2D tensor in test loader batches to compute per-dimension distributions.") + + # Aggregate all control rows for the chosen key + control_rows = [] + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + if len(names) == 0: + continue + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + vec = batch.get(dist_source_key, None) + if isinstance(vec, torch.Tensor) and vec.dim() == 2: + try: + control_rows.append(vec[mask].detach().cpu().float()) + except Exception: + # fallback: take leading rows equal to mask sum + take = int(mask.sum().item()) + if take > 0: + control_rows.append(vec[:take].detach().cpu().float()) + + if len(control_rows) == 0: + raise RuntimeError("No control rows found to compute distributions.") + + control_vectors_all = torch.cat(control_rows, dim=0) # [Nc, D] + D = control_vectors_all.shape[1] + if D != 2000: + logger.warning(f"Expected vector dimension 2000; found {D}. Proceeding with {D} dimensions.") + + control_mean = control_vectors_all.mean(dim=0) + control_std = control_vectors_all.std(dim=0, unbiased=False).clamp_min(1e-8) + + # Save distributions to results directory later; keep in scope for optional shifting + distributions = { + "key": dist_source_key, + "mean": control_mean.numpy(), + "std": control_std.numpy(), + "dim": int(D), + "num_cells": int(control_vectors_all.shape[0]), + } + + def apply_shift_to_core_cells(index: int, upregulate: bool): + """Apply ±2σ shift at a single index across all vectors in core_cells. + + - index: integer in [0, D) + - upregulate: True for +2σ, False for -2σ + Operates in-place on the tensor stored at distributions['key'] inside core_cells. 
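+
+        Illustrative example (editor's sketch; the index and std value are made up)::
+
+            # if distributions["std"][1337] == 0.5, an "up" shift adds
+            # +2 * 0.5 = +1.0 to column 1337 of every row in core_cells[key]
+            apply_shift_to_core_cells(index=1337, upregulate=True)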
+ """ + nonlocal core_cells, distributions + if index < 0 or index >= distributions["dim"]: + raise ValueError(f"Index {index} is out of bounds for dimension {distributions['dim']}") + shift_value = (2.0 if upregulate else -2.0) * float(distributions["std"][index]) + key = distributions["key"] + tensor = core_cells[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + tensor[:, index] = tensor[:, index] + shift_value + core_cells[key] = tensor + + # Optionally apply shift based on CLI flags before running inference + if args.shift_index is not None: + if args.shift_direction is None: + raise ValueError("--shift-direction is required when --shift-index is provided") + apply_shift_to_core_cells(index=int(args.shift_index), upregulate=(args.shift_direction == "up")) + logger.info(f"Applied 2σ {'up' if args.shift_direction=='up' else 'down'} shift at index {int(args.shift_index)} across core_cells") + + # Prepare output arrays sized by num_perts * 64 + # Keep all perturbations including control to be explicit + perts_order = list(unique_perts) + num_cells = len(perts_order) * target_core_n + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + + logger.info("Generating predictions: one forward pass per perturbation on core_cells...") + device = next(model.parameters()).device + + # Prepare perturbation one-hot/embedding map for the pert encoder + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + try: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + else: + logger.warning(f"pert_onehot_map.pt not found at {map_path}; proceeding without explicit pert_emb overrides") + pert_onehot_map = {} + except Exception as e: + logger.warning(f"Failed to load pert_onehot_map.pt: {e}") + pert_onehot_map = {} + + def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): + vec = None + try: + vec = pert_onehot_map.get(pert_name, None) + if vec is None and control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + except Exception: + vec = None + if vec is None: + # Fallback to zeros with model.pert_dim if mapping is unavailable + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("Could not determine pert_dim to build pert_emb") + vec = torch.zeros(pert_dim) + return vec.float().unsqueeze(0).repeat(length, 1).to(device) + + # Phase 1: Normal inference on all perturbations + final_preds = np.empty((num_cells, output_dim), dtype=np.float32) + final_reals = np.empty((num_cells, output_dim), dtype=np.float32) + + # Phase 2: Store normal predictions for distance computation + normal_preds_per_pert = {} # pert_name -> [64, output_dim] array + + store_raw_expression = ( + data_module.embed_key is not None + and data_module.embed_key != "X_hvg" + and cfg["data"]["kwargs"]["output_space"] == "gene" + ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") + + final_X_hvg = None + final_pert_cell_counts_preds = None + if store_raw_expression: + # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. 
+ if cfg["data"]["kwargs"]["output_space"] == "gene": + final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) + if cfg["data"]["kwargs"]["output_space"] == "all": + final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) + + current_idx = 0 + + # Initialize aggregation variables directly + all_pert_names = [] + all_celltypes = [] + all_gem_groups = [] + all_pert_barcodes = [] + all_ctrl_barcodes = [] + + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="Predicting", unit="pert")): + # Build a batch by copying core_cells and swapping perturbation + batch = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields to target pert + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + # Best-effort: update any index fields if present and mapping exists + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + batch_preds = model.predict_step(batch, p_idx, padded=False) + + # Extract metadata and data directly from batch_preds + # Handle pert_name + batch_pert_names = [] + if isinstance(batch_preds["pert_name"], list): + all_pert_names.extend(batch_preds["pert_name"]) + batch_pert_names = batch_preds["pert_name"] + else: + all_pert_names.append(batch_preds["pert_name"]) + batch_pert_names = [batch_preds["pert_name"]] + + if "pert_cell_barcode" in batch_preds: + if isinstance(batch_preds["pert_cell_barcode"], list): + all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.extend(batch_preds.get("ctrl_cell_barcode", [None] * len(batch_preds["pert_cell_barcode"])) ) + else: + all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.append(batch_preds.get("ctrl_cell_barcode", None)) + + # Handle celltype_name + if isinstance(batch_preds["celltype_name"], list): + all_celltypes.extend(batch_preds["celltype_name"]) + else: + all_celltypes.append(batch_preds["celltype_name"]) + + # Handle gem_group + if isinstance(batch_preds["batch"], list): + all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) + elif isinstance(batch_preds["batch"], torch.Tensor): + all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) + else: + all_gem_groups.append(str(batch_preds["batch"])) + + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + batch_size = batch_pred_np.shape[0] + final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np + final_reals[current_idx : current_idx + batch_size, :] = batch_real_np + + # Store normal predictions for this perturbation for distance computation + normal_preds_per_pert[pert] = batch_pred_np.copy() + + current_idx += batch_size + + # Handle X_hvg for HVG space ground truth + if final_X_hvg is not None: + batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) + final_X_hvg[current_idx - batch_size : current_idx, :] 
= batch_real_gene_np + + # Handle decoded gene predictions if available + if final_pert_cell_counts_preds is not None: + batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) + final_pert_cell_counts_preds[current_idx - batch_size : current_idx, :] = batch_gene_pred_np + + logger.info("Phase 1 complete: Normal inference on all perturbations.") + + # Phase 2: Run inference with GO MF pathway groups upregulated (only if requested) + if args.test_time_heat_map: + logger.info("Phase 2: Loading GO MF pathway annotations and running pathway-based upregulation...") + + # Load gene annotations + import pickle + with open(args.annotation_path, 'rb') as f: + gene_annotations = pickle.load(f) + + # Group genes by GO MF pathways + from collections import defaultdict + pathway_to_genes = defaultdict(list) + + for idx, data in gene_annotations.items(): + mf_paths = data['go_cc_paths'] + if mf_paths: # If gene has MF pathways + pathways = mf_paths.split(';') + for pathway in pathways: + # Convert 1-indexed to 0-indexed + pathway_to_genes[pathway].append(idx - 1) + + # Filter out pathways with too few genes (less than 3) to avoid noise + filtered_pathways = {pathway: genes for pathway, genes in pathway_to_genes.items() if len(genes) >= 3} + + logger.info(f"Found {len(pathway_to_genes)} total GO MF pathways") + logger.info(f"Using {len(filtered_pathways)} pathways with 3+ genes for upregulation") + + # Initialize heatmap array: [num_pathways, num_perturbations] + num_pathways = len(filtered_pathways) + heatmap_distances = np.zeros((num_pathways, len(perts_order)), dtype=np.float32) + pathway_names = list(filtered_pathways.keys()) + + # Create a copy of core_cells for upregulation experiments + original_core_cells = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + original_core_cells[k] = v.clone() + else: + original_core_cells[k] = v.copy() if hasattr(v, 'copy') else v + + def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, target_norm: float = 2.0): + """Apply shift to multiple gene indices with equivalent euclidean norm across pathways. + + This function ensures that all pathways receive the same euclidean norm perturbation: + 1. Compute individual shifts based on 2σ for each gene + 2. Calculate the euclidean norm of the shift vector + 3. Rescale the entire shift vector to match the target euclidean norm + + - gene_indices: list of 0-indexed gene positions + - upregulate: True for positive shift, False for negative shift + - target_norm: target euclidean norm for the perturbation (default: 2.0) + Operates in-place on the tensor stored at distributions['key'] inside core_cells. 
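+
+        Worked example (editor's sketch with made-up numbers)::
+
+            raw = np.array([1.0, 2.0, 2.0])    # 2*sigma shifts for a 3-gene pathway
+            scale = 2.0 / np.linalg.norm(raw)  # target_norm / current_norm = 2 / 3
+            applied = raw * scale              # [0.667, 1.333, 1.333], norm == 2.0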
+ """ + nonlocal core_cells, distributions + key = distributions['key'] + tensor = core_cells[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + + if len(gene_indices) == 0: + return + + # Step 1: Compute raw shift values based on 2σ for each gene + raw_shifts = {} + for idx in gene_indices: + if 0 <= idx < distributions["dim"]: + base_shift = 2.0 * float(distributions["std"][idx]) + raw_shifts[idx] = base_shift if upregulate else -base_shift + + if len(raw_shifts) == 0: + return + + # Step 2: Calculate euclidean norm of the raw shift vector + shift_values = np.array(list(raw_shifts.values())) + current_norm = np.linalg.norm(shift_values) + + # Step 3: Rescale to target norm if current norm > 0 + if current_norm > 1e-8: # Avoid division by zero + scale_factor = target_norm / current_norm + + # Apply rescaled shifts + for idx, raw_shift in raw_shifts.items(): + scaled_shift = raw_shift * scale_factor + tensor[:, idx] = tensor[:, idx] + scaled_shift + else: + # Fallback: if all std deviations are zero, apply uniform shift + uniform_shift = target_norm / np.sqrt(len(raw_shifts)) + for idx in raw_shifts.keys(): + shift_value = uniform_shift if upregulate else -uniform_shift + tensor[:, idx] = tensor[:, idx] + shift_value + + with torch.no_grad(): + for pathway_idx, (pathway_name, gene_indices) in enumerate(tqdm(filtered_pathways.items(), desc="Upregulating pathways", unit="pathway")): + # Apply downregulation to all genes in this pathway + apply_pathway_shift_to_core_cells(gene_indices, upregulate=True) + + # Run inference for all perturbations with this pathway upregulated + for p_idx, pert in enumerate(perts_order): + # Build batch by copying upregulated core_cells + batch = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + # Get predictions with upregulated pathway + batch_preds = model.predict_step(batch, p_idx, padded=False) + upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + + # Compute euclidean distance between normal and upregulated predictions + normal_preds = normal_preds_per_pert[pert] # [64, output_dim] + distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 64 cells + heatmap_distances[pathway_idx, p_idx] = distance + + # Restore original core_cells for next pathway + for k, v in original_core_cells.items(): + if isinstance(v, torch.Tensor): + core_cells[k] = v.clone() + else: + core_cells[k] = v.copy() if hasattr(v, 'copy') else v + + logger.info("Phase 2 complete: Upregulated inference for all GO MF pathways.") + + # Save heatmap data + try: + # Determine results directory + if args.results_dir is not None: + results_dir = args.results_dir + else: + results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + os.makedirs(results_dir, exist_ok=True) + + heatmap_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.npy") + 
np.save(heatmap_path, heatmap_distances) + + # Save pathway information + pathway_info_path = os.path.join(results_dir, "go_cc_pathways_info.json") + pathway_info = { + "pathway_names": pathway_names, + "pathway_to_genes": {pathway: genes for pathway, genes in filtered_pathways.items()}, + "total_pathways": len(pathway_to_genes), + "filtered_pathways": len(filtered_pathways), + "min_genes_per_pathway": 3 + } + with open(pathway_info_path, "w") as f: + json.dump(pathway_info, f, indent=2) + + # Save metadata for the heatmap + heatmap_meta = { + "shape": [num_pathways, len(perts_order)], + "description": "Euclidean distance heatmap: rows=GO MF pathways, cols=perturbations", + "perturbations": perts_order, + "pathway_names": pathway_names, + "distance_type": "mean_euclidean_norm_across_64_cells", + "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene" + } + heatmap_meta_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.meta.json") + with open(heatmap_meta_path, "w") as f: + json.dump(heatmap_meta, f, indent=2) + + logger.info(f"Saved GO MF pathway upregulation heatmap to {heatmap_path}") + logger.info(f"Heatmap shape: {heatmap_distances.shape} (pathways x perturbations)") + except Exception as e: + logger.warning(f"Failed to save heatmap data: {e}") + + # Create and save matplotlib heatmap visualization + try: + # Determine output path for heatmap image + if args.heatmap_output_path is not None: + heatmap_img_path = args.heatmap_output_path + else: + heatmap_img_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.png") + + # Ensure directory exists + os.makedirs(os.path.dirname(heatmap_img_path), exist_ok=True) + + # Create the heatmap with appropriate size + fig_width = max(12, len(perts_order) * 0.3) + fig_height = max(8, num_pathways * 0.05) # Smaller height per pathway since we have fewer rows + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # Create heatmap with proper labels + im = ax.imshow(heatmap_distances, cmap='viridis', aspect='auto') + + # Set labels and title + ax.set_xlabel('Perturbations') + ax.set_ylabel('GO MF Pathways') + ax.set_title('GO MF Pathway Upregulation Impact Heatmap\n(Euclidean Distance from Normal Predictions)') + + # Set x-axis labels (perturbations) + ax.set_xticks(range(len(perts_order))) + ax.set_xticklabels(perts_order, rotation=45, ha='right', fontsize=8) + + # Set y-axis labels (pathways) - show pathway names, truncated if too long + ax.set_yticks(range(num_pathways)) + truncated_pathway_names = [] + for pathway_name in pathway_names: + # Remove GOMF_ prefix and truncate long names + clean_name = pathway_name.replace('GOMF_', '') + if len(clean_name) > 30: + clean_name = clean_name[:27] + '...' 
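+                # e.g. 'GOMF_TRANSCRIPTION_COACTIVATOR_ACTIVITY' is shown as
+                # 'TRANSCRIPTION_COACTIVATOR_A...' after prefix removal and truncation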
+ truncated_pathway_names.append(clean_name) + ax.set_yticklabels(truncated_pathway_names, fontsize=6) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label('Mean Euclidean Distance', rotation=270, labelpad=20) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the figure + plt.savefig(heatmap_img_path, dpi=300, bbox_inches='tight') + plt.close(fig) # Close to free memory + + logger.info(f"Saved GO MF pathway heatmap visualization to {heatmap_img_path}") + + except Exception as e: + logger.warning(f"Failed to create heatmap visualization: {e}") + else: + logger.info("Skipping heatmap analysis (--test-time-heat-map not set)") + + logger.info("Creating anndatas from predictions from manual loop...") + + # Build pandas DataFrame for obs and var + df_dict = { + data_module.pert_col: all_pert_names, + data_module.cell_type_key: all_celltypes, + data_module.batch_col: all_gem_groups, + } + + if len(all_pert_barcodes) > 0: + df_dict["pert_cell_barcode"] = all_pert_barcodes + df_dict["ctrl_cell_barcode"] = all_ctrl_barcodes + + obs = pd.DataFrame(df_dict) + + gene_names = var_dims["gene_names"] + var = pd.DataFrame({"gene_names": gene_names}) + + if final_X_hvg is not None: + if len(gene_names) != final_pert_cell_counts_preds.shape[1]: + gene_names = np.load( + "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + ) + var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - using the decoded gene expression values + adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs, var=var) + # Create adata for real - using the true gene expression values + adata_real = anndata.AnnData(X=final_X_hvg, obs=obs, var=var) + + # add the embedding predictions + adata_pred.obsm[data_module.embed_key] = final_preds + adata_real.obsm[data_module.embed_key] = final_reals + logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") + else: + # if len(gene_names) != final_preds.shape[1]: + # gene_names = np.load( + # "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + # ) + # var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - model was trained on gene expression space already + # adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) + adata_pred = anndata.AnnData(X=final_preds, obs=obs) + # Create adata for real - using the true gene expression values + # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) + adata_real = anndata.AnnData(X=final_reals, obs=obs) + + # Optionally filter to perturbations seen in at least one training context + if args.shared_only: + try: + shared_perts = data_module.get_shared_perturbations() + if len(shared_perts) == 0: + logger.warning("No shared perturbations between train and test; skipping filtering.") + else: + logger.info( + "Filtering to %d shared perturbations present in train ∩ test.", + len(shared_perts), + ) + mask = adata_pred.obs[data_module.pert_col].isin(shared_perts) + before_n = adata_pred.n_obs + adata_pred = adata_pred[mask].copy() + adata_real = adata_real[mask].copy() + logger.info( + "Filtered cells: %d -> %d (kept only seen perturbations)", + before_n, + adata_pred.n_obs, + ) + except Exception as e: + logger.warning( + "Failed to filter by shared perturbations (%s). 
Proceeding without filter.", + str(e), + ) + + # Save the AnnData objects + results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + os.makedirs(results_dir, exist_ok=True) + adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") + adata_real_path = os.path.join(results_dir, "adata_real.h5ad") + + adata_pred.write_h5ad(adata_pred_path) + adata_real.write_h5ad(adata_real_path) + + logger.info(f"Saved adata_pred to {adata_pred_path}") + logger.info(f"Saved adata_real to {adata_real_path}") + + # Save per-dimension control-cell distributions for reproducibility + try: + dist_out = { + "key": distributions["key"], + "dim": distributions["dim"], + "num_cells": distributions["num_cells"], + } + dist_out_path = os.path.join(results_dir, "control_distributions.meta.json") + with open(dist_out_path, "w") as f: + json.dump(dist_out, f) + np.save(os.path.join(results_dir, "control_mean.npy"), distributions["mean"]) # [D] + np.save(os.path.join(results_dir, "control_std.npy"), distributions["std"]) # [D] + logger.info("Saved control-cell per-dimension mean/std distributions") + except Exception as e: + logger.warning(f"Failed to save control-cell distributions: {e}") + + if not args.predict_only: + # 6. Compute metrics using cell-eval + logger.info("Computing metrics using cell-eval...") + + control_pert = data_module.get_control_pert() + + ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key) + ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key) + + assert len(ct_split_real) == len(ct_split_pred), ( + f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" + ) + + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for ct in ct_split_real.keys(): + real_ct = ct_split_real[ct] + pred_ct = ct_split_pred[ct] + + evaluator = MetricsEvaluator( + adata_pred=pred_ct, + adata_real=real_ct, + control_pert=control_pert, + pert_col=data_module.pert_col, + outdir=results_dir, + prefix=ct, + pdex_kwargs=pdex_kwargs, + batch_size=2048, + ) + + evaluator.compute( + profile=args.profile, + metric_configs={ + "discrimination_score": { + "embed_key": data_module.embed_key, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + "pearson_edistance": { + "embed_key": data_module.embed_key, + "n_jobs": -1, # set to all available cores + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else { + "n_jobs": -1, + }, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + skip_metrics=["pearson_edistance", "clustering_agreement"], + ) From 48917567db6dfae557afda96573b7b941b95fee0 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Fri, 19 Sep 2025 21:00:04 +0000 Subject: [PATCH 2/9] bump semvar --- pyproject.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pyproject.toml b/pyproject.toml index cbac9752..d5a1c54f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [project] name = "arc-state" -version = "0.9.27" +version = "0.9.28" description = "State is a machine learning model that predicts cellular perturbation response across diverse contexts." 
readme = "README.md" authors = [ From 0cd1990540621685a92b58ac8ddd773f25d1f3f5 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Fri, 19 Sep 2025 21:06:19 +0000 Subject: [PATCH 3/9] annotation field --- src/state/_cli/_tx/_heatmap.py | 51 +++++++++++++++++++++------------- 1 file changed, 32 insertions(+), 19 deletions(-) diff --git a/src/state/_cli/_tx/_heatmap.py b/src/state/_cli/_tx/_heatmap.py index 1c71623d..bf9fb1d9 100644 --- a/src/state/_cli/_tx/_heatmap.py +++ b/src/state/_cli/_tx/_heatmap.py @@ -90,7 +90,13 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): default="/home/dhruvgautam/gene_annotations_1_2000.pkl", help="Path to the hvg gene annotations file.", ) - + parser.add_argument( + "--annotation-field", + type=str, + default="go_cc_paths", + help="Field name in the annotation data to use for pathway grouping (e.g., 'go_cc_paths', 'go_mf_paths', 'go_bp_paths', etc.).", + ) + def run_tx_heatmap(args: ap.ArgumentParser): import logging @@ -622,14 +628,14 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): with open(args.annotation_path, 'rb') as f: gene_annotations = pickle.load(f) - # Group genes by GO MF pathways + # Group genes by pathways using the specified annotation field from collections import defaultdict pathway_to_genes = defaultdict(list) for idx, data in gene_annotations.items(): - mf_paths = data['go_cc_paths'] - if mf_paths: # If gene has MF pathways - pathways = mf_paths.split(';') + pathway_data = data[args.annotation_field] + if pathway_data: # If gene has pathways + pathways = pathway_data.split(';') for pathway in pathways: # Convert 1-indexed to 0-indexed pathway_to_genes[pathway].append(idx - 1) @@ -637,7 +643,7 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): # Filter out pathways with too few genes (less than 3) to avoid noise filtered_pathways = {pathway: genes for pathway, genes in pathway_to_genes.items() if len(genes) >= 3} - logger.info(f"Found {len(pathway_to_genes)} total GO MF pathways") + logger.info(f"Found {len(pathway_to_genes)} total pathways from field '{args.annotation_field}'") logger.info(f"Using {len(filtered_pathways)} pathways with 3+ genes for upregulation") # Initialize heatmap array: [num_pathways, num_perturbations] @@ -748,7 +754,10 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ else: core_cells[k] = v.copy() if hasattr(v, 'copy') else v - logger.info("Phase 2 complete: Upregulated inference for all GO MF pathways.") + logger.info(f"Phase 2 complete: Upregulated inference for all pathways from field '{args.annotation_field}'.") + + # Create filename based on annotation field + field_suffix = args.annotation_field.replace('_', '').lower() # Save heatmap data try: @@ -759,11 +768,11 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) os.makedirs(results_dir, exist_ok=True) - heatmap_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.npy") + heatmap_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") np.save(heatmap_path, heatmap_distances) # Save pathway information - pathway_info_path = os.path.join(results_dir, "go_cc_pathways_info.json") + pathway_info_path = os.path.join(results_dir, f"{field_suffix}_pathways_info.json") pathway_info = { "pathway_names": pathway_names, "pathway_to_genes": {pathway: genes for pathway, genes in filtered_pathways.items()}, @@ -777,17 
+786,18 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ # Save metadata for the heatmap heatmap_meta = { "shape": [num_pathways, len(perts_order)], - "description": "Euclidean distance heatmap: rows=GO MF pathways, cols=perturbations", + "description": f"Euclidean distance heatmap: rows={args.annotation_field} pathways, cols=perturbations", "perturbations": perts_order, "pathway_names": pathway_names, "distance_type": "mean_euclidean_norm_across_64_cells", - "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene" + "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene", + "annotation_field": args.annotation_field } - heatmap_meta_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.meta.json") + heatmap_meta_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") with open(heatmap_meta_path, "w") as f: json.dump(heatmap_meta, f, indent=2) - logger.info(f"Saved GO MF pathway upregulation heatmap to {heatmap_path}") + logger.info(f"Saved {args.annotation_field} pathway upregulation heatmap to {heatmap_path}") logger.info(f"Heatmap shape: {heatmap_distances.shape} (pathways x perturbations)") except Exception as e: logger.warning(f"Failed to save heatmap data: {e}") @@ -798,7 +808,7 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ if args.heatmap_output_path is not None: heatmap_img_path = args.heatmap_output_path else: - heatmap_img_path = os.path.join(results_dir, "go_cc_pathway_upregulation_heatmap.png") + heatmap_img_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.png") # Ensure directory exists os.makedirs(os.path.dirname(heatmap_img_path), exist_ok=True) @@ -813,8 +823,8 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ # Set labels and title ax.set_xlabel('Perturbations') - ax.set_ylabel('GO MF Pathways') - ax.set_title('GO MF Pathway Upregulation Impact Heatmap\n(Euclidean Distance from Normal Predictions)') + ax.set_ylabel(f'{args.annotation_field.replace("_", " ").title()} Pathways') + ax.set_title(f'{args.annotation_field.replace("_", " ").title()} Pathway Upregulation Impact Heatmap\n(Euclidean Distance from Normal Predictions)') # Set x-axis labels (perturbations) ax.set_xticks(range(len(perts_order))) @@ -824,8 +834,11 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ ax.set_yticks(range(num_pathways)) truncated_pathway_names = [] for pathway_name in pathway_names: - # Remove GOMF_ prefix and truncate long names - clean_name = pathway_name.replace('GOMF_', '') + # Remove common prefixes and truncate long names + clean_name = pathway_name + # Remove common GO prefixes + for prefix in ['GOMF_', 'GOCC_', 'GOBP_']: + clean_name = clean_name.replace(prefix, '') if len(clean_name) > 30: clean_name = clean_name[:27] + '...' 
truncated_pathway_names.append(clean_name) @@ -842,7 +855,7 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ plt.savefig(heatmap_img_path, dpi=300, bbox_inches='tight') plt.close(fig) # Close to free memory - logger.info(f"Saved GO MF pathway heatmap visualization to {heatmap_img_path}") + logger.info(f"Saved {args.annotation_field} pathway heatmap visualization to {heatmap_img_path}") except Exception as e: logger.warning(f"Failed to create heatmap visualization: {e}") From 277dd259d3f88f4ba54ff01a5079751342972280 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Thu, 25 Sep 2025 02:11:24 +0000 Subject: [PATCH 4/9] changes --- .github/CODEOWNERS | 0 .github/workflows/release.yml | 0 .gitignore | 0 .gitmodules | 0 .python-version | 0 LICENSE | 0 MODEL_ACCEPTABLE_USE_POLICY.md | 0 MODEL_LICENSE.md | 0 README.md | 0 assets/generalization_task.png | Bin examples/fewshot.toml | 0 examples/mixed.toml | 0 examples/random.h5ad | Bin examples/zeroshot.toml | 0 pyproject.toml | 0 ruff.toml | 0 scripts/state_embed_anndata.py | 0 singularity.def | 0 src/state/__init__.py | 0 src/state/__main__.py | 0 src/state/_cli/__init__.py | 0 src/state/_cli/_emb/__init__.py | 0 src/state/_cli/_emb/_eval.py | 0 src/state/_cli/_emb/_fit.py | 0 src/state/_cli/_emb/_preprocess.py | 0 src/state/_cli/_emb/_query.py | 0 src/state/_cli/_emb/_transform.py | 0 src/state/_cli/_tx/__init__.py | 0 src/state/_cli/_tx/_heatmap.py | 214 +++++++++++++++--- src/state/_cli/_tx/_infer.py | 0 src/state/_cli/_tx/_predict.py | 0 src/state/_cli/_tx/_preprocess_infer.py | 0 src/state/_cli/_tx/_preprocess_train.py | 0 src/state/_cli/_tx/_train.py | 0 src/state/configs/__init__.py | 0 src/state/configs/config.yaml | 0 src/state/configs/data/default.yaml | 0 src/state/configs/data/perturbation.yaml | 0 src/state/configs/model/celltypemean.yaml | 0 src/state/configs/model/context_mean.yaml | 0 src/state/configs/model/cpa.yaml | 0 src/state/configs/model/decoder_only.yaml | 0 src/state/configs/model/embedsum.yaml | 0 src/state/configs/model/globalsimplesum.yaml | 0 src/state/configs/model/old_neuralot.yaml | 0 src/state/configs/model/pertsets.yaml | 0 src/state/configs/model/perturb_mean.yaml | 0 src/state/configs/model/scgpt-chemical.yaml | 0 src/state/configs/model/scgpt-genetic.yaml | 0 src/state/configs/model/scvi.yaml | 0 src/state/configs/model/state.yaml | 0 src/state/configs/model/state_lg.yaml | 0 src/state/configs/model/state_sm.yaml | 0 src/state/configs/model/tahoe_best.yaml | 0 .../configs/model/tahoe_llama_212693232.yaml | 0 .../configs/model/tahoe_llama_62089464.yaml | 0 src/state/configs/state-defaults.yaml | 0 src/state/configs/training/cpa.yaml | 0 src/state/configs/training/default.yaml | 0 src/state/configs/training/scgpt.yaml | 0 src/state/configs/training/scvi.yaml | 0 src/state/configs/wandb/default.yaml | 0 src/state/emb/__init__.py | 0 src/state/emb/data/__init__.py | 0 src/state/emb/data/loader.py | 0 src/state/emb/eval/__init__.py | 0 src/state/emb/eval/emb.py | 0 src/state/emb/finetune_decoder.py | 0 src/state/emb/inference.py | 0 src/state/emb/nn/__init__.py | 0 src/state/emb/nn/eval_utils.py | 0 src/state/emb/nn/flash_transformer.py | 0 src/state/emb/nn/loss.py | 0 src/state/emb/nn/model.py | 0 src/state/emb/tools/__init__.py | 0 src/state/emb/tools/slurm.py | 0 src/state/emb/train/__init__.py | 0 src/state/emb/train/__main__.py | 0 src/state/emb/train/callbacks.py | 0 src/state/emb/train/trainer.py | 0 src/state/emb/utils.py | 0 src/state/emb/vectordb.py | 0 src/state/py.typed 
| 0 src/state/tx/__init__.py | 0 src/state/tx/callbacks/__init__.py | 0 src/state/tx/callbacks/batch_speed_monitor.py | 0 src/state/tx/callbacks/cumulative_flops.py | 0 .../tx/callbacks/model_flops_utilization.py | 0 src/state/tx/data/dataset/__init__.py | 0 .../dataset/scgpt_perturbation_dataset.py | 0 src/state/tx/models/__init__.py | 0 src/state/tx/models/base.py | 0 src/state/tx/models/context_mean.py | 0 src/state/tx/models/cpa/__init__.py | 0 src/state/tx/models/cpa/_base_modules.py | 0 src/state/tx/models/cpa/_callbacks.py | 0 src/state/tx/models/cpa/_dists.py | 0 src/state/tx/models/cpa/_model.py | 0 src/state/tx/models/cpa/_module.py | 0 src/state/tx/models/cpa/_task.py | 0 src/state/tx/models/decoder_only.py | 0 src/state/tx/models/decoders.py | 0 src/state/tx/models/decoders_nb.py | 0 src/state/tx/models/embed_sum.py | 0 src/state/tx/models/old_neural_ot.py | 0 src/state/tx/models/perturb_mean.py | 0 src/state/tx/models/pseudobulk.py | 0 src/state/tx/models/scgpt/__init__.py | 0 src/state/tx/models/scgpt/dsbn.py | 0 src/state/tx/models/scgpt/gene_tokenizer.py | 0 src/state/tx/models/scgpt/generation_model.py | 0 src/state/tx/models/scgpt/grad_reverse.py | 0 src/state/tx/models/scgpt/lightning_model.py | 0 src/state/tx/models/scgpt/loss.py | 0 src/state/tx/models/scgpt/model.py | 0 src/state/tx/models/scgpt/utils.py | 0 src/state/tx/models/scvi/__init__.py | 0 src/state/tx/models/scvi/_base_modules.py | 0 src/state/tx/models/scvi/_callbacks.py | 0 src/state/tx/models/scvi/_dists.py | 0 src/state/tx/models/scvi/_model.py | 0 src/state/tx/models/scvi/_module.py | 0 src/state/tx/models/scvi/_task.py | 0 src/state/tx/models/state_transition.py | 0 src/state/tx/models/utils.py | 0 src/state/tx/utils/__init__.py | 0 src/state/tx/utils/singleton.py | 0 tests/test_callbacks.py | 0 128 files changed, 186 insertions(+), 28 deletions(-) mode change 100644 => 100755 .github/CODEOWNERS mode change 100644 => 100755 .github/workflows/release.yml mode change 100644 => 100755 .gitignore mode change 100644 => 100755 .gitmodules mode change 100644 => 100755 .python-version mode change 100644 => 100755 LICENSE mode change 100644 => 100755 MODEL_ACCEPTABLE_USE_POLICY.md mode change 100644 => 100755 MODEL_LICENSE.md mode change 100644 => 100755 README.md mode change 100644 => 100755 assets/generalization_task.png mode change 100644 => 100755 examples/fewshot.toml mode change 100644 => 100755 examples/mixed.toml mode change 100644 => 100755 examples/random.h5ad mode change 100644 => 100755 examples/zeroshot.toml mode change 100644 => 100755 pyproject.toml mode change 100644 => 100755 ruff.toml mode change 100644 => 100755 scripts/state_embed_anndata.py mode change 100644 => 100755 singularity.def mode change 100644 => 100755 src/state/__init__.py mode change 100644 => 100755 src/state/__main__.py mode change 100644 => 100755 src/state/_cli/__init__.py mode change 100644 => 100755 src/state/_cli/_emb/__init__.py mode change 100644 => 100755 src/state/_cli/_emb/_eval.py mode change 100644 => 100755 src/state/_cli/_emb/_fit.py mode change 100644 => 100755 src/state/_cli/_emb/_preprocess.py mode change 100644 => 100755 src/state/_cli/_emb/_query.py mode change 100644 => 100755 src/state/_cli/_emb/_transform.py mode change 100644 => 100755 src/state/_cli/_tx/__init__.py mode change 100644 => 100755 src/state/_cli/_tx/_heatmap.py mode change 100644 => 100755 src/state/_cli/_tx/_infer.py mode change 100644 => 100755 src/state/_cli/_tx/_predict.py mode change 100644 => 100755 
src/state/_cli/_tx/_preprocess_infer.py mode change 100644 => 100755 src/state/_cli/_tx/_preprocess_train.py mode change 100644 => 100755 src/state/_cli/_tx/_train.py mode change 100644 => 100755 src/state/configs/__init__.py mode change 100644 => 100755 src/state/configs/config.yaml mode change 100644 => 100755 src/state/configs/data/default.yaml mode change 100644 => 100755 src/state/configs/data/perturbation.yaml mode change 100644 => 100755 src/state/configs/model/celltypemean.yaml mode change 100644 => 100755 src/state/configs/model/context_mean.yaml mode change 100644 => 100755 src/state/configs/model/cpa.yaml mode change 100644 => 100755 src/state/configs/model/decoder_only.yaml mode change 100644 => 100755 src/state/configs/model/embedsum.yaml mode change 100644 => 100755 src/state/configs/model/globalsimplesum.yaml mode change 100644 => 100755 src/state/configs/model/old_neuralot.yaml mode change 100644 => 100755 src/state/configs/model/pertsets.yaml mode change 100644 => 100755 src/state/configs/model/perturb_mean.yaml mode change 100644 => 100755 src/state/configs/model/scgpt-chemical.yaml mode change 100644 => 100755 src/state/configs/model/scgpt-genetic.yaml mode change 100644 => 100755 src/state/configs/model/scvi.yaml mode change 100644 => 100755 src/state/configs/model/state.yaml mode change 100644 => 100755 src/state/configs/model/state_lg.yaml mode change 100644 => 100755 src/state/configs/model/state_sm.yaml mode change 100644 => 100755 src/state/configs/model/tahoe_best.yaml mode change 100644 => 100755 src/state/configs/model/tahoe_llama_212693232.yaml mode change 100644 => 100755 src/state/configs/model/tahoe_llama_62089464.yaml mode change 100644 => 100755 src/state/configs/state-defaults.yaml mode change 100644 => 100755 src/state/configs/training/cpa.yaml mode change 100644 => 100755 src/state/configs/training/default.yaml mode change 100644 => 100755 src/state/configs/training/scgpt.yaml mode change 100644 => 100755 src/state/configs/training/scvi.yaml mode change 100644 => 100755 src/state/configs/wandb/default.yaml mode change 100644 => 100755 src/state/emb/__init__.py mode change 100644 => 100755 src/state/emb/data/__init__.py mode change 100644 => 100755 src/state/emb/data/loader.py mode change 100644 => 100755 src/state/emb/eval/__init__.py mode change 100644 => 100755 src/state/emb/eval/emb.py mode change 100644 => 100755 src/state/emb/finetune_decoder.py mode change 100644 => 100755 src/state/emb/inference.py mode change 100644 => 100755 src/state/emb/nn/__init__.py mode change 100644 => 100755 src/state/emb/nn/eval_utils.py mode change 100644 => 100755 src/state/emb/nn/flash_transformer.py mode change 100644 => 100755 src/state/emb/nn/loss.py mode change 100644 => 100755 src/state/emb/nn/model.py mode change 100644 => 100755 src/state/emb/tools/__init__.py mode change 100644 => 100755 src/state/emb/tools/slurm.py mode change 100644 => 100755 src/state/emb/train/__init__.py mode change 100644 => 100755 src/state/emb/train/__main__.py mode change 100644 => 100755 src/state/emb/train/callbacks.py mode change 100644 => 100755 src/state/emb/train/trainer.py mode change 100644 => 100755 src/state/emb/utils.py mode change 100644 => 100755 src/state/emb/vectordb.py mode change 100644 => 100755 src/state/py.typed mode change 100644 => 100755 src/state/tx/__init__.py mode change 100644 => 100755 src/state/tx/callbacks/__init__.py mode change 100644 => 100755 src/state/tx/callbacks/batch_speed_monitor.py mode change 100644 => 100755 
src/state/tx/callbacks/cumulative_flops.py mode change 100644 => 100755 src/state/tx/callbacks/model_flops_utilization.py mode change 100644 => 100755 src/state/tx/data/dataset/__init__.py mode change 100644 => 100755 src/state/tx/data/dataset/scgpt_perturbation_dataset.py mode change 100644 => 100755 src/state/tx/models/__init__.py mode change 100644 => 100755 src/state/tx/models/base.py mode change 100644 => 100755 src/state/tx/models/context_mean.py mode change 100644 => 100755 src/state/tx/models/cpa/__init__.py mode change 100644 => 100755 src/state/tx/models/cpa/_base_modules.py mode change 100644 => 100755 src/state/tx/models/cpa/_callbacks.py mode change 100644 => 100755 src/state/tx/models/cpa/_dists.py mode change 100644 => 100755 src/state/tx/models/cpa/_model.py mode change 100644 => 100755 src/state/tx/models/cpa/_module.py mode change 100644 => 100755 src/state/tx/models/cpa/_task.py mode change 100644 => 100755 src/state/tx/models/decoder_only.py mode change 100644 => 100755 src/state/tx/models/decoders.py mode change 100644 => 100755 src/state/tx/models/decoders_nb.py mode change 100644 => 100755 src/state/tx/models/embed_sum.py mode change 100644 => 100755 src/state/tx/models/old_neural_ot.py mode change 100644 => 100755 src/state/tx/models/perturb_mean.py mode change 100644 => 100755 src/state/tx/models/pseudobulk.py mode change 100644 => 100755 src/state/tx/models/scgpt/__init__.py mode change 100644 => 100755 src/state/tx/models/scgpt/dsbn.py mode change 100644 => 100755 src/state/tx/models/scgpt/gene_tokenizer.py mode change 100644 => 100755 src/state/tx/models/scgpt/generation_model.py mode change 100644 => 100755 src/state/tx/models/scgpt/grad_reverse.py mode change 100644 => 100755 src/state/tx/models/scgpt/lightning_model.py mode change 100644 => 100755 src/state/tx/models/scgpt/loss.py mode change 100644 => 100755 src/state/tx/models/scgpt/model.py mode change 100644 => 100755 src/state/tx/models/scgpt/utils.py mode change 100644 => 100755 src/state/tx/models/scvi/__init__.py mode change 100644 => 100755 src/state/tx/models/scvi/_base_modules.py mode change 100644 => 100755 src/state/tx/models/scvi/_callbacks.py mode change 100644 => 100755 src/state/tx/models/scvi/_dists.py mode change 100644 => 100755 src/state/tx/models/scvi/_model.py mode change 100644 => 100755 src/state/tx/models/scvi/_module.py mode change 100644 => 100755 src/state/tx/models/scvi/_task.py mode change 100644 => 100755 src/state/tx/models/state_transition.py mode change 100644 => 100755 src/state/tx/models/utils.py mode change 100644 => 100755 src/state/tx/utils/__init__.py mode change 100644 => 100755 src/state/tx/utils/singleton.py mode change 100644 => 100755 tests/test_callbacks.py diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS old mode 100644 new mode 100755 diff --git a/.github/workflows/release.yml b/.github/workflows/release.yml old mode 100644 new mode 100755 diff --git a/.gitignore b/.gitignore old mode 100644 new mode 100755 diff --git a/.gitmodules b/.gitmodules old mode 100644 new mode 100755 diff --git a/.python-version b/.python-version old mode 100644 new mode 100755 diff --git a/LICENSE b/LICENSE old mode 100644 new mode 100755 diff --git a/MODEL_ACCEPTABLE_USE_POLICY.md b/MODEL_ACCEPTABLE_USE_POLICY.md old mode 100644 new mode 100755 diff --git a/MODEL_LICENSE.md b/MODEL_LICENSE.md old mode 100644 new mode 100755 diff --git a/README.md b/README.md old mode 100644 new mode 100755 diff --git a/assets/generalization_task.png b/assets/generalization_task.png old mode 
100644 new mode 100755 diff --git a/examples/fewshot.toml b/examples/fewshot.toml old mode 100644 new mode 100755 diff --git a/examples/mixed.toml b/examples/mixed.toml old mode 100644 new mode 100755 diff --git a/examples/random.h5ad b/examples/random.h5ad old mode 100644 new mode 100755 diff --git a/examples/zeroshot.toml b/examples/zeroshot.toml old mode 100644 new mode 100755 diff --git a/pyproject.toml b/pyproject.toml old mode 100644 new mode 100755 diff --git a/ruff.toml b/ruff.toml old mode 100644 new mode 100755 diff --git a/scripts/state_embed_anndata.py b/scripts/state_embed_anndata.py old mode 100644 new mode 100755 diff --git a/singularity.def b/singularity.def old mode 100644 new mode 100755 diff --git a/src/state/__init__.py b/src/state/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/__main__.py b/src/state/__main__.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/__init__.py b/src/state/_cli/_emb/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/_eval.py b/src/state/_cli/_emb/_eval.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/_fit.py b/src/state/_cli/_emb/_fit.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/_preprocess.py b/src/state/_cli/_emb/_preprocess.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/_query.py b/src/state/_cli/_emb/_query.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_emb/_transform.py b/src/state/_cli/_emb/_transform.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/_heatmap.py b/src/state/_cli/_tx/_heatmap.py old mode 100644 new mode 100755 index bf9fb1d9..0d8253b4 --- a/src/state/_cli/_tx/_heatmap.py +++ b/src/state/_cli/_tx/_heatmap.py @@ -87,14 +87,17 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): parser.add_argument( "--annotation-path", type=str, - default="/home/dhruvgautam/gene_annotations_1_2000.pkl", + default="/home/dhruvgautam/annotations/replogle_go_annotations.pkl", #/home/dhruvgautam/annotations/var_dims_gene_go_annotations.json help="Path to the hvg gene annotations file.", ) parser.add_argument( "--annotation-field", type=str, default="go_cc_paths", - help="Field name in the annotation data to use for pathway grouping (e.g., 'go_cc_paths', 'go_mf_paths', 'go_bp_paths', etc.).", + help=( + "Field name in structured annotation data to use for pathway grouping (e.g., 'go_cc_paths'). " + "Ignored when loading JSON files that map genes directly to pathways." + ), ) @@ -175,6 +178,23 @@ def load_config(cfg_path: str) -> dict: # 2. Find run output directory & load data module run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + if not os.path.isabs(run_output_dir): + run_output_dir = os.path.abspath(run_output_dir) + + if not os.path.exists(run_output_dir): + inferred_run_dir = args.output_dir + if os.path.exists(inferred_run_dir): + logger.warning( + "Run directory %s not found; falling back to config directory %s", + run_output_dir, + inferred_run_dir, + ) + run_output_dir = inferred_run_dir + else: + raise FileNotFoundError( + "Could not resolve run directory. 
Checked: %s and %s" % (run_output_dir, inferred_run_dir) + ) + data_module_path = os.path.join(run_output_dir, "data_module.torch") if not os.path.exists(data_module_path): raise FileNotFoundError(f"Could not find data module at {data_module_path}?") @@ -622,34 +642,163 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): # Phase 2: Run inference with GO MF pathway groups upregulated (only if requested) if args.test_time_heat_map: logger.info("Phase 2: Loading GO MF pathway annotations and running pathway-based upregulation...") - + + if args.results_dir is not None: + results_dir = args.results_dir + else: + results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + os.makedirs(results_dir, exist_ok=True) + annotation_ext = os.path.splitext(args.annotation_path)[1].lower() + annotation_source_type = "unknown" + annotation_label = args.annotation_field + field_suffix = ( + (annotation_label or "pathways").replace('_', '').lower() + if (annotation_label or "").strip() + else "pathways" + ) + # Load gene annotations import pickle - with open(args.annotation_path, 'rb') as f: - gene_annotations = pickle.load(f) - - # Group genes by pathways using the specified annotation field from collections import defaultdict + pathway_to_genes = defaultdict(list) - - for idx, data in gene_annotations.items(): - pathway_data = data[args.annotation_field] - if pathway_data: # If gene has pathways - pathways = pathway_data.split(';') + gene_names = var_dims.get("gene_names") + gene_name_to_index = {str(name): idx for idx, name in enumerate(gene_names)} if gene_names is not None else {} + + if annotation_ext == ".json": + annotation_source_type = "json" + annotation_label = os.path.splitext(os.path.basename(args.annotation_path))[0] + field_suffix = annotation_label.replace('_', '').lower() or "pathways" + + with open(args.annotation_path, 'r') as f: + gene_annotations = json.load(f) + + if not isinstance(gene_annotations, dict): + raise ValueError( + f"Expected JSON annotation file {args.annotation_path} to map gene names to pathway collections." + ) + + missing_genes = set() + for gene_name, pathway_data in gene_annotations.items(): + if not pathway_data: + continue + + idx = gene_name_to_index.get(str(gene_name)) + if idx is None: + missing_genes.add(str(gene_name)) + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [] + for entry in pathway_data: + if entry is None: + continue + pathways.append(str(entry).strip()) + pathways = [p for p in pathways if p] + + for pathway in pathways: + pathway_to_genes[pathway].append(idx) + + if missing_genes: + sample_missing = ", ".join(sorted(missing_genes)[:5]) + logger.warning( + "Skipped %d gene(s) from annotation file not present in model gene names (e.g., %s)", + len(missing_genes), + sample_missing, + ) + elif annotation_ext in {".pkl", ".pickle"}: + annotation_source_type = "pickle" + field_suffix = ( + (args.annotation_field or "pathways").replace('_', '').lower() + if (args.annotation_field or "").strip() + else "pathways" + ) + + with open(args.annotation_path, 'rb') as f: + gene_annotations = pickle.load(f) + + if not args.annotation_field: + raise ValueError( + "--annotation-field must be provided when loading pickle annotation files." 
+ ) + + for idx, data in gene_annotations.items(): + pathway_data = None + if isinstance(data, dict): + pathway_data = data.get(args.annotation_field) + else: + try: + pathway_data = data[args.annotation_field] + except (KeyError, TypeError): + pathway_data = getattr(data, args.annotation_field, None) + + if not pathway_data: + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [str(p).strip() for p in pathway_data if str(p).strip()] + + try: + gene_index = int(idx) - 1 + except (TypeError, ValueError): + gene_index = gene_name_to_index.get(str(idx)) + + if gene_index is None or gene_index < 0: + continue + for pathway in pathways: - # Convert 1-indexed to 0-indexed - pathway_to_genes[pathway].append(idx - 1) + pathway_to_genes[pathway].append(gene_index) + else: + raise ValueError( + f"Unsupported annotation file extension '{annotation_ext}' for {args.annotation_path}." + ) # Filter out pathways with too few genes (less than 3) to avoid noise filtered_pathways = {pathway: genes for pathway, genes in pathway_to_genes.items() if len(genes) >= 3} - logger.info(f"Found {len(pathway_to_genes)} total pathways from field '{args.annotation_field}'") + logger.info( + "Found %d total pathways from annotation source '%s' (%s)", + len(pathway_to_genes), + annotation_label, + annotation_source_type, + ) logger.info(f"Using {len(filtered_pathways)} pathways with 3+ genes for upregulation") # Initialize heatmap array: [num_pathways, num_perturbations] num_pathways = len(filtered_pathways) heatmap_distances = np.zeros((num_pathways, len(perts_order)), dtype=np.float32) pathway_names = list(filtered_pathways.keys()) + + annotation_label_pretty = (annotation_label or "Annotation").replace('_', ' ').strip() + if annotation_label_pretty: + annotation_label_pretty = annotation_label_pretty.title() + else: + annotation_label_pretty = "Annotation" + + upregulated_preds_path = None + upregulated_preds_memmap = None + if num_pathways == 0: + logger.warning("No pathways passed filtering; skipping upregulated prediction storage.") + else: + try: + upregulated_preds_path = os.path.join( + results_dir, + f"{field_suffix}_pathway_upregulated_preds.npy", + ) + upregulated_preds_memmap = np.memmap( + upregulated_preds_path, + dtype=np.float32, + mode="w+", + shape=(num_pathways, len(perts_order), target_core_n, output_dim), + ) + except Exception as e: + logger.warning("Failed to initialize storage for upregulated predictions: %s", e) + upregulated_preds_path = None + upregulated_preds_memmap = None # Create a copy of core_cells for upregulation experiments original_core_cells = {} @@ -742,6 +891,9 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ batch_preds = model.predict_step(batch, p_idx, padded=False) upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + if upregulated_preds_memmap is not None: + upregulated_preds_memmap[pathway_idx, p_idx, :, :] = upregulated_preds + # Compute euclidean distance between normal and upregulated predictions normal_preds = normal_preds_per_pert[pert] # [64, output_dim] distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 64 cells @@ -754,20 +906,18 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ else: core_cells[k] = v.copy() if hasattr(v, 'copy') else v - logger.info(f"Phase 2 complete: Upregulated inference for all pathways from field 
'{args.annotation_field}'.") + logger.info( + "Phase 2 complete: Upregulated inference for all pathways from annotation source '%s'.", + annotation_label or annotation_source_type, + ) + + if upregulated_preds_memmap is not None: + upregulated_preds_memmap.flush() + logger.info(f"Saved upregulated prediction tensors to {upregulated_preds_path}") # Create filename based on annotation field - field_suffix = args.annotation_field.replace('_', '').lower() - # Save heatmap data try: - # Determine results directory - if args.results_dir is not None: - results_dir = args.results_dir - else: - results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) - os.makedirs(results_dir, exist_ok=True) - heatmap_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") np.save(heatmap_path, heatmap_distances) @@ -786,18 +936,26 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ # Save metadata for the heatmap heatmap_meta = { "shape": [num_pathways, len(perts_order)], - "description": f"Euclidean distance heatmap: rows={args.annotation_field} pathways, cols=perturbations", + "description": ( + f"Euclidean distance heatmap: rows={annotation_label_pretty} pathways, cols=perturbations" + ), "perturbations": perts_order, "pathway_names": pathway_names, "distance_type": "mean_euclidean_norm_across_64_cells", "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene", - "annotation_field": args.annotation_field + "annotation_field": annotation_label if annotation_source_type != "json" else None, + "annotation_source_type": annotation_source_type, + "upregulated_preds_path": upregulated_preds_path, } heatmap_meta_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") with open(heatmap_meta_path, "w") as f: json.dump(heatmap_meta, f, indent=2) - logger.info(f"Saved {args.annotation_field} pathway upregulation heatmap to {heatmap_path}") + logger.info( + "Saved %s pathway upregulation heatmap to %s", + annotation_label_pretty, + heatmap_path, + ) logger.info(f"Heatmap shape: {heatmap_distances.shape} (pathways x perturbations)") except Exception as e: logger.warning(f"Failed to save heatmap data: {e}") diff --git a/src/state/_cli/_tx/_infer.py b/src/state/_cli/_tx/_infer.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/_predict.py b/src/state/_cli/_tx/_predict.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/_preprocess_infer.py b/src/state/_cli/_tx/_preprocess_infer.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/_preprocess_train.py b/src/state/_cli/_tx/_preprocess_train.py old mode 100644 new mode 100755 diff --git a/src/state/_cli/_tx/_train.py b/src/state/_cli/_tx/_train.py old mode 100644 new mode 100755 diff --git a/src/state/configs/__init__.py b/src/state/configs/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/configs/config.yaml b/src/state/configs/config.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/data/default.yaml b/src/state/configs/data/default.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/data/perturbation.yaml b/src/state/configs/data/perturbation.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/celltypemean.yaml b/src/state/configs/model/celltypemean.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/context_mean.yaml b/src/state/configs/model/context_mean.yaml old mode 100644 new mode 
100755 diff --git a/src/state/configs/model/cpa.yaml b/src/state/configs/model/cpa.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/decoder_only.yaml b/src/state/configs/model/decoder_only.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/embedsum.yaml b/src/state/configs/model/embedsum.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/globalsimplesum.yaml b/src/state/configs/model/globalsimplesum.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/old_neuralot.yaml b/src/state/configs/model/old_neuralot.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/pertsets.yaml b/src/state/configs/model/pertsets.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/perturb_mean.yaml b/src/state/configs/model/perturb_mean.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/scgpt-chemical.yaml b/src/state/configs/model/scgpt-chemical.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/scgpt-genetic.yaml b/src/state/configs/model/scgpt-genetic.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/scvi.yaml b/src/state/configs/model/scvi.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/state_lg.yaml b/src/state/configs/model/state_lg.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/state_sm.yaml b/src/state/configs/model/state_sm.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/tahoe_best.yaml b/src/state/configs/model/tahoe_best.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/tahoe_llama_212693232.yaml b/src/state/configs/model/tahoe_llama_212693232.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/model/tahoe_llama_62089464.yaml b/src/state/configs/model/tahoe_llama_62089464.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/state-defaults.yaml b/src/state/configs/state-defaults.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/training/cpa.yaml b/src/state/configs/training/cpa.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/training/default.yaml b/src/state/configs/training/default.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/training/scgpt.yaml b/src/state/configs/training/scgpt.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/training/scvi.yaml b/src/state/configs/training/scvi.yaml old mode 100644 new mode 100755 diff --git a/src/state/configs/wandb/default.yaml b/src/state/configs/wandb/default.yaml old mode 100644 new mode 100755 diff --git a/src/state/emb/__init__.py b/src/state/emb/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/data/__init__.py b/src/state/emb/data/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/data/loader.py b/src/state/emb/data/loader.py old mode 100644 new mode 100755 diff --git a/src/state/emb/eval/__init__.py b/src/state/emb/eval/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/eval/emb.py b/src/state/emb/eval/emb.py old mode 100644 new mode 100755 diff --git a/src/state/emb/finetune_decoder.py b/src/state/emb/finetune_decoder.py old mode 100644 new mode 100755 diff --git a/src/state/emb/inference.py b/src/state/emb/inference.py old mode 100644 new mode 100755 diff --git 
a/src/state/emb/nn/__init__.py b/src/state/emb/nn/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/nn/eval_utils.py b/src/state/emb/nn/eval_utils.py old mode 100644 new mode 100755 diff --git a/src/state/emb/nn/flash_transformer.py b/src/state/emb/nn/flash_transformer.py old mode 100644 new mode 100755 diff --git a/src/state/emb/nn/loss.py b/src/state/emb/nn/loss.py old mode 100644 new mode 100755 diff --git a/src/state/emb/nn/model.py b/src/state/emb/nn/model.py old mode 100644 new mode 100755 diff --git a/src/state/emb/tools/__init__.py b/src/state/emb/tools/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/tools/slurm.py b/src/state/emb/tools/slurm.py old mode 100644 new mode 100755 diff --git a/src/state/emb/train/__init__.py b/src/state/emb/train/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/train/__main__.py b/src/state/emb/train/__main__.py old mode 100644 new mode 100755 diff --git a/src/state/emb/train/callbacks.py b/src/state/emb/train/callbacks.py old mode 100644 new mode 100755 diff --git a/src/state/emb/train/trainer.py b/src/state/emb/train/trainer.py old mode 100644 new mode 100755 diff --git a/src/state/emb/utils.py b/src/state/emb/utils.py old mode 100644 new mode 100755 diff --git a/src/state/emb/vectordb.py b/src/state/emb/vectordb.py old mode 100644 new mode 100755 diff --git a/src/state/py.typed b/src/state/py.typed old mode 100644 new mode 100755 diff --git a/src/state/tx/__init__.py b/src/state/tx/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/callbacks/__init__.py b/src/state/tx/callbacks/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/callbacks/batch_speed_monitor.py b/src/state/tx/callbacks/batch_speed_monitor.py old mode 100644 new mode 100755 diff --git a/src/state/tx/callbacks/cumulative_flops.py b/src/state/tx/callbacks/cumulative_flops.py old mode 100644 new mode 100755 diff --git a/src/state/tx/callbacks/model_flops_utilization.py b/src/state/tx/callbacks/model_flops_utilization.py old mode 100644 new mode 100755 diff --git a/src/state/tx/data/dataset/__init__.py b/src/state/tx/data/dataset/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/data/dataset/scgpt_perturbation_dataset.py b/src/state/tx/data/dataset/scgpt_perturbation_dataset.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/__init__.py b/src/state/tx/models/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/base.py b/src/state/tx/models/base.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/context_mean.py b/src/state/tx/models/context_mean.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/__init__.py b/src/state/tx/models/cpa/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_base_modules.py b/src/state/tx/models/cpa/_base_modules.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_callbacks.py b/src/state/tx/models/cpa/_callbacks.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_dists.py b/src/state/tx/models/cpa/_dists.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_model.py b/src/state/tx/models/cpa/_model.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_module.py b/src/state/tx/models/cpa/_module.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/cpa/_task.py b/src/state/tx/models/cpa/_task.py old mode 100644 new mode 100755 diff --git 
a/src/state/tx/models/decoder_only.py b/src/state/tx/models/decoder_only.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/decoders.py b/src/state/tx/models/decoders.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/decoders_nb.py b/src/state/tx/models/decoders_nb.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/embed_sum.py b/src/state/tx/models/embed_sum.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/old_neural_ot.py b/src/state/tx/models/old_neural_ot.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/perturb_mean.py b/src/state/tx/models/perturb_mean.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/pseudobulk.py b/src/state/tx/models/pseudobulk.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/__init__.py b/src/state/tx/models/scgpt/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/dsbn.py b/src/state/tx/models/scgpt/dsbn.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/gene_tokenizer.py b/src/state/tx/models/scgpt/gene_tokenizer.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/generation_model.py b/src/state/tx/models/scgpt/generation_model.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/grad_reverse.py b/src/state/tx/models/scgpt/grad_reverse.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/lightning_model.py b/src/state/tx/models/scgpt/lightning_model.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/loss.py b/src/state/tx/models/scgpt/loss.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/model.py b/src/state/tx/models/scgpt/model.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scgpt/utils.py b/src/state/tx/models/scgpt/utils.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/__init__.py b/src/state/tx/models/scvi/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_base_modules.py b/src/state/tx/models/scvi/_base_modules.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_callbacks.py b/src/state/tx/models/scvi/_callbacks.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_dists.py b/src/state/tx/models/scvi/_dists.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_model.py b/src/state/tx/models/scvi/_model.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_module.py b/src/state/tx/models/scvi/_module.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/scvi/_task.py b/src/state/tx/models/scvi/_task.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py old mode 100644 new mode 100755 diff --git a/src/state/tx/models/utils.py b/src/state/tx/models/utils.py old mode 100644 new mode 100755 diff --git a/src/state/tx/utils/__init__.py b/src/state/tx/utils/__init__.py old mode 100644 new mode 100755 diff --git a/src/state/tx/utils/singleton.py b/src/state/tx/utils/singleton.py old mode 100644 new mode 100755 diff --git a/tests/test_callbacks.py b/tests/test_callbacks.py old mode 100644 new mode 100755 From 3197e24a0905b0f11f56ebb159e2b130d2b042a3 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Wed, 1 Oct 2025 06:44:03 +0000 Subject: [PATCH 5/9] working --- README.md | 19 + pyproject.toml | 2 +- src/state/__main__.py | 4 + 
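For reference, each cell of the pathway-upregulation heatmap produced by the `_heatmap.py` changes above is the mean Euclidean distance between the normal and pathway-upregulated predictions over the fixed core-cell batch. The following is a minimal sketch of that computation with small in-memory arrays instead of the memmap used in the patch; shapes and names here are illustrative only.

```python
import numpy as np

def pathway_shift_distance(normal_preds: np.ndarray, upregulated_preds: np.ndarray) -> float:
    """Mean Euclidean distance across the core-cell batch; inputs are [n_cells, output_dim]."""
    return float(np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean())

# Hypothetical sizes: 3 pathways, 4 perturbations, 64 core cells, 128-dim model outputs.
rng = np.random.default_rng(0)
normal = rng.normal(size=(4, 64, 128)).astype(np.float32)          # baseline prediction per perturbation
upregulated = rng.normal(size=(3, 4, 64, 128)).astype(np.float32)  # prediction per (pathway, perturbation)

heatmap = np.zeros((3, 4), dtype=np.float32)
for pathway_idx in range(3):
    for pert_idx in range(4):
        heatmap[pathway_idx, pert_idx] = pathway_shift_distance(
            normal[pert_idx], upregulated[pathway_idx, pert_idx]
        )

# Mirrors the .npy heatmap artifact written by the command (rows=pathways, cols=perturbations).
np.save("pathway_upregulation_heatmap.npy", heatmap)
```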
src/state/_cli/__init__.py | 2 + src/state/_cli/_tx/__init__.py | 3 + src/state/_cli/_tx/_double.py | 794 ++++++++++++++ src/state/_cli/_tx/_heatmap.py | 521 ++++++---- src/state/_cli/_tx/_heatmap_og.py | 1301 +++++++++++++++++++++++ src/state/_cli/_tx/_heatmap_train.py | 1424 ++++++++++++++++++++++++++ src/state/_cli/_tx/test.py | 42 + src/state/_cli/_tx/test_gpt.py | 28 + 11 files changed, 3942 insertions(+), 198 deletions(-) create mode 100644 src/state/_cli/_tx/_double.py create mode 100644 src/state/_cli/_tx/_heatmap_og.py create mode 100644 src/state/_cli/_tx/_heatmap_train.py create mode 100644 src/state/_cli/_tx/test.py create mode 100644 src/state/_cli/_tx/test_gpt.py diff --git a/README.md b/README.md index 85a5528b..5b0eaa43 100755 --- a/README.md +++ b/README.md @@ -59,6 +59,25 @@ options: -h, --help show this help message and exit ``` +### Double Perturbation Analysis + +Generate double perturbation sweeps against a baseline core cell batch: + +```bash +uv run state tx double \ + --output-dir /path/to/training/run \ + --target-cell-type RPE1 \ + --checkpoint last.ckpt +``` + +Key arguments: +- `--target-cell-type`: cell type used to seed the core control cells +- `--checkpoint`: checkpoint filename inside `/checkpoints` +- `--results-dir`: optional override for where to dump results; defaults to `/eval_` +- `--phase-one-only`: stop after saving single-perturbation predictions + +The command emits `.npy` snapshots, AnnData files, and metric reports mirroring the single-perturbation heatmap command. + ## State Transition Model (ST) To start an experiment, write a TOML file (see `examples/zeroshot.toml` or diff --git a/pyproject.toml b/pyproject.toml index e6eb7541..5069594b 100755 --- a/pyproject.toml +++ b/pyproject.toml @@ -13,7 +13,7 @@ authors = [ requires-python = ">=3.10,<3.13" dependencies = [ "anndata>=0.11.4", - "cell-load>=0.8.3", + "cell-load>=0.8.5", "numpy>=2.2.6", "pandas>=2.2.3", "pyyaml>=6.0.2", diff --git a/src/state/__main__.py b/src/state/__main__.py index 45a26993..532320b9 100755 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -11,6 +11,7 @@ run_emb_query, run_emb_preprocess, run_emb_eval, + run_tx_double, run_tx_heatmap, run_tx_infer, run_tx_predict, @@ -125,6 +126,9 @@ def main(): case "heatmap": # Run heatmap analysis using argparse run_tx_heatmap(args) + case "double": + # Run double perturbation analysis using argparse + run_tx_double(args) case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index da4fc456..d1eb55dd 100755 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -1,6 +1,7 @@ from ._emb import add_arguments_emb, run_emb_fit, run_emb_transform, run_emb_query, run_emb_preprocess, run_emb_eval from ._tx import ( add_arguments_tx, + run_tx_double, run_tx_heatmap, run_tx_infer, run_tx_predict, @@ -14,6 +15,7 @@ "add_arguments_tx", "run_tx_train", "run_tx_predict", + "run_tx_double", "run_tx_heatmap", "run_tx_infer", "run_tx_preprocess_train", diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index e59f2c70..3e76062f 100755 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -1,5 +1,6 @@ import argparse as ap +from ._double import add_arguments_double, run_tx_double from ._heatmap import add_arguments_heatmap, run_tx_heatmap from ._infer import add_arguments_infer, run_tx_infer from ._predict import add_arguments_predict, run_tx_predict @@ -11,6 +12,7 @@ "run_tx_train", 
"run_tx_predict", "run_tx_heatmap", + "run_tx_double", "run_tx_infer", "run_tx_preprocess_train", "run_tx_preprocess_infer", @@ -24,6 +26,7 @@ def add_arguments_tx(parser: ap.ArgumentParser): add_arguments_train(subparsers.add_parser("train", add_help=False)) add_arguments_predict(subparsers.add_parser("predict")) add_arguments_heatmap(subparsers.add_parser("heatmap")) + add_arguments_double(subparsers.add_parser("double")) add_arguments_infer(subparsers.add_parser("infer")) add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) diff --git a/src/state/_cli/_tx/_double.py b/src/state/_cli/_tx/_double.py new file mode 100644 index 00000000..cf934cd8 --- /dev/null +++ b/src/state/_cli/_tx/_double.py @@ -0,0 +1,794 @@ +import argparse as ap + + +def add_arguments_double(parser: ap.ArgumentParser) -> None: + """CLI for double perturbation analysis on a target cell line.""" + + parser.add_argument( + "--output-dir", + type=str, + required=True, + help=( + "Path to the output_dir containing the config.yaml file that was saved during training." + ), + ) + parser.add_argument( + "--checkpoint", + type=str, + default="last.ckpt", + help="Checkpoint filename relative to the output directory (default: last.ckpt).", + ) + parser.add_argument( + "--test-time-finetune", + type=int, + default=0, + help="If >0, run test-time fine-tuning for the specified number of epochs on control cells only.", + ) + parser.add_argument( + "--profile", + type=str, + default="full", + choices=["full", "minimal", "de", "anndata"], + help="Evaluation profile to run after inference.", + ) + parser.add_argument( + "--predict-only", + action="store_true", + help="Skip metric computation and only run inference.", + ) + parser.add_argument( + "--shared-only", + action="store_true", + help="Restrict outputs to perturbations present in both train and test sets.", + ) + parser.add_argument( + "--eval-train-data", + action="store_true", + help="Evaluate the model on the training data instead of the test data.", + ) + parser.add_argument( + "--target-cell-type", + type=str, + required=True, + help="Cell type to construct the base core cells for double perturbations.", + ) + parser.add_argument( + "--results-dir", + type=str, + default=None, + help="Directory to save results. 
Defaults to /eval_.", + ) + + +def run_tx_double(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> None: + import logging + import os + import sys + import copy + + import anndata + import lightning.pytorch as pl + import numpy as np + import pandas as pd + import torch + import yaml + from tqdm import tqdm + + from cell_eval import MetricsEvaluator + from cell_eval.utils import split_anndata_on_celltype + from cell_load.data_modules import PerturbationDataModule + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + def _prepare_for_serialization(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().copy() + if isinstance(obj, dict): + return {k: _prepare_for_serialization(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_serialization(v) for v in obj] + return obj + + def _save_numpy_snapshot(obj, path, description=None): + serializable = _prepare_for_serialization(obj) + try: + np.save(path, serializable, allow_pickle=True) + if description: + logger.info("Saved %s to %s", description, path) + else: + logger.info("Saved snapshot to %s", path) + except Exception as exc: + logger.warning("Failed to save %s to %s: %s", description or "snapshot", path, exc) + + def _clone_core_cells(src): + cloned = {} + for key, value in src.items(): + if isinstance(value, torch.Tensor): + cloned[key] = value.clone() + else: + try: + cloned[key] = copy.deepcopy(value) + except Exception: + cloned[key] = value + return cloned + + def _to_list(value): + if isinstance(value, list): + return value + if isinstance(value, torch.Tensor): + try: + return [x.item() if x.dim() == 0 else x for x in value] + except Exception: + return value.tolist() + if isinstance(value, (tuple, set)): + return list(value) + if value is None: + return [] + return [value] + + def _normalize_field(values, length, filler=None): + items = list(_to_list(values)) + if len(items) == 1 and length > 1: + items = items * length + if len(items) < length: + items.extend([filler] * (length - len(items))) + elif len(items) > length: + items = items[:length] + return items + + def _resolve_celltype_key(batch, module): + candidate_keys = [] + base_key = getattr(module, "cell_type_key", None) + if base_key: + candidate_keys.append(base_key) + alias_keys = getattr(module, "cell_type_key_aliases", None) + if isinstance(alias_keys, (list, tuple)): + candidate_keys.extend(alias_keys) + alias_keys_alt = getattr(module, "celltype_key_aliases", None) + if isinstance(alias_keys_alt, (list, tuple)): + candidate_keys.extend(alias_keys_alt) + candidate_keys.extend([ + "celltype_name", + "cell_type", + "celltype", + "cell_line", + ]) + seen = set() + ordered_candidates = [] + for key in candidate_keys: + if not key or key in seen: + continue + seen.add(key) + ordered_candidates.append(key) + if key in batch: + return key, ordered_candidates + return None, ordered_candidates + + torch.multiprocessing.set_sharing_strategy("file_system") + + config_path = os.path.join(args.output_dir, "config.yaml") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Could not find config file: {config_path}") + with open(config_path, "r", encoding="utf-8") as file: + cfg = yaml.safe_load(file) + logger.info("Loaded config from %s", config_path) + + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + if not os.path.isabs(run_output_dir): + run_output_dir = os.path.abspath(run_output_dir) + if not os.path.exists(run_output_dir): + inferred_run_dir = 
args.output_dir + if os.path.exists(inferred_run_dir): + logger.warning( + "Run directory %s not found; falling back to config directory %s", + run_output_dir, + inferred_run_dir, + ) + run_output_dir = inferred_run_dir + else: + raise FileNotFoundError( + "Could not resolve run directory. Checked: %s and %s" % (run_output_dir, inferred_run_dir) + ) + + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + + pl.seed_everything(cfg["training"]["train_seed"]) + + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + checkpoint_path = os.path.join(checkpoint_dir, args.checkpoint) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Could not find checkpoint at {checkpoint_path}. Specify --checkpoint with a valid file." + ) + logger.info("Loading model from %s", checkpoint_path) + + model_name = cfg["model"]["name"] + model_kwargs = cfg["model"]["kwargs"] + var_dims = data_module.get_var_dims() + + if model_name.lower() == "embedsum": + from ...tx.models.embed_sum import EmbedSumPerturbationModel + + ModelClass = EmbedSumPerturbationModel + elif model_name.lower() == "old_neuralot": + from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel + + ModelClass = OldNeuralOTPerturbationModel + elif model_name.lower() in {"neuralot", "pertsets", "state"}: + from ...tx.models.state_transition import StateTransitionPerturbationModel + + ModelClass = StateTransitionPerturbationModel + elif model_name.lower() in {"globalsimplesum", "perturb_mean"}: + from ...tx.models.perturb_mean import PerturbMeanPerturbationModel + + ModelClass = PerturbMeanPerturbationModel + elif model_name.lower() in {"celltypemean", "context_mean"}: + from ...tx.models.context_mean import ContextMeanPerturbationModel + + ModelClass = ContextMeanPerturbationModel + elif model_name.lower() == "decoder_only": + from ...tx.models.decoder_only import DecoderOnlyPerturbationModel + + ModelClass = DecoderOnlyPerturbationModel + else: + raise ValueError(f"Unknown model class: {model_name}") + + model = ModelClass.load_from_checkpoint( + checkpoint_path, + input_dim=var_dims["input_dim"], + hidden_dim=model_kwargs["hidden_dim"], + gene_dim=var_dims.get("gene_dim"), + hvg_dim=var_dims.get("hvg_dim"), + output_dim=var_dims["output_dim"], + pert_dim=var_dims["pert_dim"], + **model_kwargs, + ) + model.eval() + logger.info("Model loaded successfully.") + + results_dir_default = ( + args.results_dir + if args.results_dir is not None + else os.path.join(args.output_dir, f"eval_{os.path.basename(args.checkpoint)}") + ) + + data_module.batch_size = 1 + target_celltype = getattr(args, "target_cell_type") + + def _create_filtered_loader(module): + base_loader = ( + module.train_dataloader(test=True) + if args.eval_train_data + else module.test_dataloader() + ) + + celltype_key, attempted = _resolve_celltype_key({}, module) + + def _generator(): + found_target = False + for batch in base_loader: + if target_celltype is None: + found_target = True + yield batch + continue + + key = celltype_key + if key is None: + key, attempted_keys = _resolve_celltype_key(batch, module) + if key is None: + available_keys = [k for k in batch.keys() if isinstance(k, str)] + available_preview = ", ".join(sorted(available_keys)[:10]) 
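The cell-type filter above slices every field of a batch with one boolean mask: tensor fields are indexed directly on their device, list-like fields are filtered positionally, and batches with no matching rows are skipped. A standalone sketch of that masking step follows; the batch contents and cell-type names are hypothetical and no dataloader or model is involved.

```python
import torch

def filter_batch_by_celltype(batch: dict, celltypes: list[str], target: str) -> dict | None:
    """Keep only the rows of `batch` whose cell type matches `target` (case-insensitive)."""
    keep = [str(ct).lower() == target.lower() for ct in celltypes]
    if not any(keep):
        return None  # nothing from the target cell type in this batch
    mask = torch.tensor(keep, dtype=torch.bool)
    filtered = {}
    for key, value in batch.items():
        if isinstance(value, torch.Tensor):
            filtered[key] = value[mask.to(value.device)]
        else:
            filtered[key] = [v for v, k in zip(value, keep) if k]
    return filtered

# Toy 3-cell batch with two matching RPE1 cells.
batch = {
    "ctrl_cell_emb": torch.randn(3, 8),
    "pert_name": ["non-targeting", "AARS", "non-targeting"],
}
print(filter_batch_by_celltype(batch, ["RPE1", "K562", "RPE1"], "RPE1"))
```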
+ raise ValueError( + "--target-cell-type requested filtering but none of the expected keys (%s) were present." + " Available batch keys: %s%s" + % ( + ", ".join(attempted_keys) if attempted_keys else "none", + available_preview, + "..." if len(available_keys) > 10 else "", + ) + ) + + celltypes = _to_list(batch[key]) + mask_values = [str(ct).lower() == target_celltype.lower() for ct in celltypes] + if not mask_values or not any(mask_values): + continue + + mask = torch.tensor(mask_values, dtype=torch.bool) + filtered = {} + for batch_key, value in batch.items(): + if isinstance(value, torch.Tensor): + mask_device = mask.to(value.device) + selected = value[mask_device] + if selected.shape[0] == 0: + continue + filtered[batch_key] = selected + else: + vals = _to_list(value) + selected = [vals[idx] for idx, keep in enumerate(mask_values) if keep] + if not selected: + continue + filtered[batch_key] = selected + if filtered: + found_target = True + yield filtered + + if target_celltype and not found_target: + raise ValueError( + f"Target cell type '{target_celltype}' not found in any batches for evaluation." + ) + + return _generator() + + eval_loader = _create_filtered_loader(data_module) + + if args.test_time_finetune > 0: + control_pert = data_module.get_control_pert() + run_test_time_finetune( + model, + eval_loader, + args.test_time_finetune, + control_pert, + device=next(model.parameters()).device, + filter_batch_fn=None, + ) + eval_loader = _create_filtered_loader(data_module) + logger.info("Test-time fine-tuning complete.") + + logger.info("Preparing a fixed batch of 64 control cells and enumerating perturbations...") + + control_pert = data_module.get_control_pert() + unique_perts = [] + seen_perts = set() + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + for name in names: + name_value = name.item() if isinstance(name, torch.Tensor) else str(name) + if name_value not in seen_perts: + seen_perts.add(name_value) + unique_perts.append(name_value) + if not unique_perts: + raise RuntimeError("No perturbations found in the provided dataloader.") + + eval_loader = _create_filtered_loader(data_module) + + target_core_n = 64 + accum = {} + + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + if not names: + continue + mask = torch.tensor([str(item) == str(control_pert) for item in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + + current_count = accum.get("_count", 0) + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + mask_device = mask.to(value.device) + selected = value[mask_device][:take].detach().clone() + _append_field(accum, key, selected) + else: + vals = _to_list(value) + selected_vals = [vals[idx] for idx, keep in enumerate(mask.tolist()) if keep][:take] + _append_field(accum, key, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError( + f"Could not assemble {target_core_n} control cells; gathered only {accum.get('_count', 0)}." 
+ ) + + core_cells = {} + for key, parts in accum.items(): + if key == "_count": + continue + if len(parts) == 1: + val = parts[0] + else: + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for part in parts: + merged.extend(_to_list(part)) + val = merged + core_cells[key] = val[:target_core_n] if isinstance(val, torch.Tensor) else _to_list(val)[:target_core_n] + + logger.info("Constructed core_cells batch with size %d.", target_core_n) + + os.makedirs(results_dir_default, exist_ok=True) + baseline_path = os.path.join(results_dir_default, "core_cells_baseline.npy") + _save_numpy_snapshot(core_cells, baseline_path, "baseline core_cells batch") + + perts_order = list(unique_perts) + num_perts = len(perts_order) + output_dim = var_dims["output_dim"] + gene_dim = var_dims.get("gene_dim", 0) + hvg_dim = var_dims.get("hvg_dim", 0) + + logger.info("Running first-pass predictions across %d perturbations...", num_perts) + + first_pass_preds = np.empty((num_perts, target_core_n, output_dim), dtype=np.float32) + first_pass_real = np.empty((num_perts, target_core_n, output_dim), dtype=np.float32) + + embed_key = getattr(data_module, "embed_key", None) or "latent_embedding" + output_space = cfg["data"]["kwargs"].get("output_space", "embedding") + store_counts = output_space in {"gene", "all"} + + first_pass_counts = None + first_pass_counts_pred = None + if store_counts: + feature_dim = hvg_dim if output_space == "gene" and hvg_dim else gene_dim + if feature_dim > 0: + first_pass_counts = np.empty((num_perts, target_core_n, feature_dim), dtype=np.float32) + first_pass_counts_pred = np.empty((num_perts, target_core_n, feature_dim), dtype=np.float32) + else: + store_counts = False + + metadata = { + "pert_name": [], + "celltype_name": [], + "batch": [], + "pert_cell_barcode": [], + "ctrl_cell_barcode": [], + } + + device = next(model.parameters()).device + + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + else: + logger.warning("pert_onehot_map.pt not found at %s; proceeding with zero embeddings", map_path) + pert_onehot_map = {} + + def _prepare_pert_emb(pert_name, length): + vec = pert_onehot_map.get(pert_name) + if vec is None and control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + if vec is None: + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("pert_dim is undefined; cannot create perturbation embedding") + vec = torch.zeros(pert_dim) + return vec.float().unsqueeze(0).repeat(length, 1).to(device) + + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="First pass", unit="pert")): + batch = {} + for key, value in core_cells.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.clone().to(device) + else: + batch[key] = list(value) + + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + try: + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n) + + batch_preds = model.predict_step(batch, p_idx, padded=False) + + batch_size = batch_preds["preds"].shape[0] + 
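The first pass above runs one `predict_step` per perturbation on the same fixed batch of control cells; the perturbation embedding is looked up in the data module's one-hot map and repeated across the batch, falling back to the control vector (or zeros) when the name is missing. A rough sketch of just that lookup-and-repeat step, with a made-up one-hot map and batch size:

```python
import torch

def make_pert_emb(pert_name: str, pert_onehot_map: dict, control_pert: str,
                  pert_dim: int, n_cells: int) -> torch.Tensor:
    """Return a [n_cells, pert_dim] embedding by repeating the one-hot vector for `pert_name`."""
    vec = pert_onehot_map.get(pert_name)
    if vec is None:
        # Fall back to the control perturbation's vector, or zeros if even that is unavailable.
        vec = pert_onehot_map.get(control_pert, torch.zeros(pert_dim))
    return vec.float().unsqueeze(0).repeat(n_cells, 1)

# Hypothetical 4-perturbation one-hot map and a 64-cell core batch.
names = ["non-targeting", "AARS", "POLR2A", "RPL3"]
pert_onehot_map = {name: torch.eye(4)[i] for i, name in enumerate(names)}
emb = make_pert_emb("AARS", pert_onehot_map, "non-targeting", pert_dim=4, n_cells=64)
print(emb.shape)  # torch.Size([64, 4])
```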
metadata["pert_name"].extend(_normalize_field(batch_preds.get("pert_name", pert), batch_size, pert)) + metadata["celltype_name"].extend( + _normalize_field(batch_preds.get("celltype_name"), batch_size, target_celltype) + ) + metadata["batch"].extend( + [None if b is None else str(b) for b in _normalize_field(batch_preds.get("batch"), batch_size)] + ) + metadata["pert_cell_barcode"].extend( + _normalize_field(batch_preds.get("pert_cell_barcode"), batch_size) + ) + metadata["ctrl_cell_barcode"].extend( + _normalize_field(batch_preds.get("ctrl_cell_barcode"), batch_size) + ) + + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + + first_pass_preds[p_idx, :, :] = batch_pred_np + first_pass_real[p_idx, :, :] = batch_real_np + + if store_counts and first_pass_counts is not None and batch_preds.get("pert_cell_counts") is not None: + counts_np = batch_preds["pert_cell_counts"].detach().cpu().numpy().astype(np.float32) + first_pass_counts[p_idx, :, :] = counts_np + + if ( + store_counts + and first_pass_counts_pred is not None + and batch_preds.get("pert_cell_counts_preds") is not None + ): + counts_pred_np = batch_preds["pert_cell_counts_preds"].detach().cpu().numpy().astype(np.float32) + first_pass_counts_pred[p_idx, :, :] = counts_pred_np + + logger.info("First pass complete across %d perturbations.", num_perts) + + if phase_one_only: + real_preds_path = os.path.join(results_dir_default, "core_cells_real_preds_per_pert.npy") + np.save(real_preds_path, first_pass_real, allow_pickle=True) + logger.info( + "Saved real perturbed embeddings for %d perturbations to %s", + num_perts, + real_preds_path, + ) + return + + logger.info("Preparing cached first-pass outputs as inputs for second-pass perturbation sweep...") + + embedding_field_candidates = [ + key + for key, value in core_cells.items() + if isinstance(value, torch.Tensor) and value.dim() == 2 + ] + embedding_field_key = embedding_field_candidates[0] if embedding_field_candidates else None + if embedding_field_key is None: + raise RuntimeError("Unable to identify a 2D tensor field in core_cells for second-pass initialization.") + + double_core_cells = [] + for idx, first_pert in enumerate(perts_order): + snapshot = _clone_core_cells(core_cells) + preds_tensor = torch.tensor(first_pass_preds[idx], device=device, dtype=torch.float32) + real_tensor = torch.tensor(first_pass_real[idx], device=device, dtype=torch.float32) + + snapshot[embedding_field_key] = preds_tensor.clone() + if embedding_field_key != "ctrl_cell_emb" and "ctrl_cell_emb" in snapshot: + snapshot["ctrl_cell_emb"] = preds_tensor.clone() + snapshot["pert_cell_emb"] = real_tensor.clone() + + if store_counts and first_pass_counts is not None: + snapshot["pert_cell_counts"] = torch.tensor( + first_pass_counts[idx], device=device, dtype=torch.float32 + ) + if store_counts and first_pass_counts_pred is not None: + snapshot["pert_cell_counts_preds"] = torch.tensor( + first_pass_counts_pred[idx], device=device, dtype=torch.float32 + ) + + double_core_cells.append((first_pert, snapshot)) + + second_pass_preds = np.empty((num_perts, num_perts, target_core_n, output_dim), dtype=np.float32) + second_pass_real = np.empty_like(second_pass_preds) + second_pass_counts = ( + np.empty((num_perts, num_perts, target_core_n, first_pass_counts.shape[-1]), dtype=np.float32) + if store_counts and first_pass_counts is not None + else None + ) + second_pass_counts_pred = ( + 
np.empty((num_perts, num_perts, target_core_n, first_pass_counts_pred.shape[-1]), dtype=np.float32) + if store_counts and first_pass_counts_pred is not None + else None + ) + + with torch.no_grad(): + for first_idx, (first_pert, pert_batch) in enumerate( + tqdm(double_core_cells, desc="Second pass", unit="core") + ): + for second_idx, second_pert in enumerate(perts_order): + batch = {} + for key, value in pert_batch.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.clone().to(device) + else: + batch[key] = list(value) + + if "pert_name" in batch: + batch["pert_name"] = [second_pert for _ in range(target_core_n)] + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + try: + idx_val = int(data_module.get_pert_index(second_pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + batch["pert_emb"] = _prepare_pert_emb(second_pert, target_core_n) + + batch_preds = model.predict_step(batch, second_idx, padded=False) + + second_pass_preds[first_idx, second_idx, :, :] = ( + batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + ) + second_pass_real[first_idx, second_idx, :, :] = ( + batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + ) + + if second_pass_counts is not None and batch_preds.get("pert_cell_counts") is not None: + second_pass_counts[first_idx, second_idx, :, :] = ( + batch_preds["pert_cell_counts"].detach().cpu().numpy().astype(np.float32) + ) + + if ( + second_pass_counts_pred is not None + and batch_preds.get("pert_cell_counts_preds") is not None + ): + second_pass_counts_pred[first_idx, second_idx, :, :] = ( + batch_preds["pert_cell_counts_preds"].detach().cpu().numpy().astype(np.float32) + ) + + logger.info( + "Second pass complete: generated double-perturbation predictions across %d x %d combinations.", + num_perts, + num_perts, + ) + + metadata_df = pd.DataFrame(metadata) + if metadata_df.empty: + raise RuntimeError("No metadata collected during first pass; cannot proceed.") + + pert_col = getattr(data_module, "pert_col", None) or "perturbation" + cell_type_col = getattr(data_module, "cell_type_key", None) or "cell_type" + batch_col = getattr(data_module, "batch_col", None) or "batch" + + obs_df = pd.DataFrame( + { + pert_col: metadata_df["pert_name"], + cell_type_col: metadata_df["celltype_name"], + batch_col: metadata_df["batch"], + } + ) + if metadata_df["pert_cell_barcode"].notna().any(): + obs_df["pert_cell_barcode"] = metadata_df["pert_cell_barcode"] + if metadata_df["ctrl_cell_barcode"].notna().any(): + obs_df["ctrl_cell_barcode"] = metadata_df["ctrl_cell_barcode"] + + first_pass_pred_flat = first_pass_preds.reshape(num_perts * target_core_n, output_dim) + first_pass_real_flat = first_pass_real.reshape(num_perts * target_core_n, output_dim) + + if store_counts and first_pass_counts is not None and first_pass_counts_pred is not None: + feature_dim = first_pass_counts.shape[-1] + gene_names = var_dims.get("gene_names") + if gene_names is not None and len(gene_names) == feature_dim: + var_index = pd.Index([str(name) for name in gene_names], name="gene") + else: + var_index = pd.Index([f"feature_{idx}" for idx in range(feature_dim)], name="feature") + var_df = pd.DataFrame(index=var_index) + + pred_X = first_pass_counts_pred.reshape(num_perts * target_core_n, feature_dim) + real_X = first_pass_counts.reshape(num_perts * target_core_n, feature_dim) + else: + var_index = pd.Index([f"embedding_{idx}" for idx in range(output_dim)], name="embedding") + var_df = 
pd.DataFrame(index=var_index) + pred_X = first_pass_pred_flat + real_X = first_pass_real_flat + + first_pass_pred_adata = anndata.AnnData(X=pred_X, obs=obs_df.copy(), var=var_df.copy()) + first_pass_real_adata = anndata.AnnData(X=real_X, obs=obs_df.copy(), var=var_df.copy()) + first_pass_pred_adata.obsm[embed_key] = first_pass_pred_flat + first_pass_real_adata.obsm[embed_key] = first_pass_real_flat + + second_pass_dir = os.path.join(results_dir_default, "second_pass") + os.makedirs(second_pass_dir, exist_ok=True) + np.save(os.path.join(second_pass_dir, "second_pass_preds.npy"), second_pass_preds) + np.save(os.path.join(second_pass_dir, "second_pass_real.npy"), second_pass_real) + if second_pass_counts is not None: + np.save(os.path.join(second_pass_dir, "second_pass_counts.npy"), second_pass_counts) + if second_pass_counts_pred is not None: + np.save(os.path.join(second_pass_dir, "second_pass_counts_pred.npy"), second_pass_counts_pred) + + second_pass_obs = pd.DataFrame( + { + "first_pert": np.repeat(perts_order, num_perts * target_core_n), + "second_pert": np.tile(np.repeat(perts_order, target_core_n), num_perts), + "core_cell_index": np.tile(np.arange(target_core_n), num_perts * num_perts), + } + ) + second_pass_obs.index = [f"second_pass_cell_{idx}" for idx in range(second_pass_obs.shape[0])] + second_pass_var = pd.DataFrame( + index=pd.Index([f"embedding_{idx}" for idx in range(output_dim)], name="embedding"), + ) + + second_pass_pred_flat = second_pass_preds.reshape(num_perts * num_perts * target_core_n, output_dim) + second_pass_real_flat = second_pass_real.reshape(num_perts * num_perts * target_core_n, output_dim) + + second_pass_adata = anndata.AnnData( + X=second_pass_pred_flat, + obs=second_pass_obs, + var=second_pass_var, + ) + second_pass_adata.obsm[embed_key] = second_pass_pred_flat + second_pass_adata.obsm[f"{embed_key}_baseline"] = second_pass_real_flat + second_pass_adata.write_h5ad(os.path.join(second_pass_dir, "second_pass_preds.h5ad")) + + first_pass_pred_path = os.path.join(results_dir_default, "first_pass_preds.h5ad") + first_pass_real_path = os.path.join(results_dir_default, "first_pass_real.h5ad") + first_pass_pred_adata.write_h5ad(first_pass_pred_path) + first_pass_real_adata.write_h5ad(first_pass_real_path) + logger.info("Saved first-pass predicted adata to %s", first_pass_pred_path) + logger.info("Saved first-pass real adata to %s", first_pass_real_path) + + np.save(os.path.join(results_dir_default, "first_pass_preds.npy"), first_pass_preds) + np.save(os.path.join(results_dir_default, "first_pass_real.npy"), first_pass_real) + if first_pass_counts is not None: + np.save(os.path.join(results_dir_default, "first_pass_counts.npy"), first_pass_counts) + if first_pass_counts_pred is not None: + np.save(os.path.join(results_dir_default, "first_pass_counts_pred.npy"), first_pass_counts_pred) + + if args.predict_only: + return + + if cell_type_col not in first_pass_real_adata.obs.columns: + logger.warning( + "Cell type column '%s' not found in observations; skipping metric computation.", + cell_type_col, + ) + return + + control_pert = data_module.get_control_pert() + ct_split_real = split_anndata_on_celltype( + adata=first_pass_real_adata, + celltype_col=cell_type_col, + ) + ct_split_pred = split_anndata_on_celltype( + adata=first_pass_pred_adata, + celltype_col=cell_type_col, + ) + + if len(ct_split_real) != len(ct_split_pred): + logger.warning( + "Number of celltypes in real and predicted AnnData objects differ (%d vs %d); skipping metrics.", + len(ct_split_real), + 
len(ct_split_pred), + ) + return + + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for celltype in ct_split_real.keys(): + real_ct = ct_split_real[celltype] + pred_ct = ct_split_pred[celltype] + + metric_configs = {} + if data_module.embed_key and data_module.embed_key != "X_hvg": + metric_configs = { + "discrimination_score": {"embed_key": embed_key}, + "pearson_edistance": {"embed_key": embed_key, "n_jobs": -1}, + } + else: + metric_configs = {"pearson_edistance": {"n_jobs": -1}} + + evaluator = MetricsEvaluator( + adata_pred=pred_ct, + adata_real=real_ct, + control_pert=control_pert, + pert_col=pert_col, + outdir=results_dir_default, + prefix=str(celltype), + pdex_kwargs=pdex_kwargs, + batch_size=2048, + ) + evaluator.compute( + profile=args.profile, + metric_configs=metric_configs, + skip_metrics=["pearson_edistance", "clustering_agreement"], + ) + + +def save_core_cells_real_preds(args: ap.ArgumentParser) -> None: + """Run only phase one of the pipeline and persist real core-cell embeddings per perturbation.""" + return run_tx_double(args, phase_one_only=True) diff --git a/src/state/_cli/_tx/_heatmap.py b/src/state/_cli/_tx/_heatmap.py index 0d8253b4..c162b3f9 100755 --- a/src/state/_cli/_tx/_heatmap.py +++ b/src/state/_cli/_tx/_heatmap.py @@ -72,6 +72,11 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): action="store_true", help="If set, run test-time heat map analysis with position upregulation.", ) + parser.add_argument( + "--phase-one-only", + action="store_true", + help="If set, run only phase one to save core cell real embeddings per perturbation.", + ) parser.add_argument( "--heatmap-output-path", type=str, @@ -84,6 +89,13 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): default=None, help="Directory to save results. If not provided, defaults to /eval_", ) + parser.add_argument( + "--heatmap-snapshots-only", + action="store_true", + help=( + "Compute and persist pathway-upregulated core cell batches without running model inference or generating heatmaps." + ), + ) parser.add_argument( "--annotation-path", type=str, @@ -93,7 +105,7 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): parser.add_argument( "--annotation-field", type=str, - default="go_cc_paths", + default="go_reactome_paths", help=( "Field name in structured annotation data to use for pathway grouping (e.g., 'go_cc_paths'). " "Ignored when loading JSON files that map genes directly to pathways." 
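For context on the two annotation layouts handled by the heatmap command: a JSON file is expected to map gene names directly to pathway collections (so `--annotation-field` is ignored), while a pickle maps gene indices or names to records carrying a ';'-separated pathway string under the chosen field. The sketch below illustrates the grouping into `pathway_to_genes` and the ≥3-gene filter with made-up gene and pathway names; it is not taken from the patch verbatim.

```python
from collections import defaultdict

# Hypothetical JSON-style annotations: gene name -> list of pathway names.
gene_annotations = {
    "GENE_A": ["ribosome biogenesis", "rRNA processing"],
    "GENE_B": ["ribosome biogenesis"],
    "GENE_C": ["ribosome biogenesis"],
    "GENE_D": ["unrelated pathway"],
}
gene_name_to_index = {"GENE_A": 0, "GENE_B": 1, "GENE_C": 2, "GENE_D": 3}

pathway_to_genes: dict[str, list[int]] = defaultdict(list)
for gene_name, pathways in gene_annotations.items():
    idx = gene_name_to_index.get(gene_name)
    if idx is None:
        continue  # genes absent from the model's gene list are skipped (the patch logs a warning)
    for pathway in pathways:
        pathway_to_genes[pathway].append(idx)

# Keep only pathways with at least 3 member genes, matching the filtering step above.
filtered_pathways = {p: genes for p, genes in pathway_to_genes.items() if len(genes) >= 3}
print(filtered_pathways)  # {'ribosome biogenesis': [0, 1, 2]}
```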
@@ -101,10 +113,11 @@ def add_arguments_heatmap(parser: ap.ArgumentParser): ) -def run_tx_heatmap(args: ap.ArgumentParser): +def run_tx_heatmap(args: ap.ArgumentParser, *, phase_one_only: bool = False): import logging import os import sys + import copy import anndata import lightning.pytorch as pl @@ -113,6 +126,8 @@ def run_tx_heatmap(args: ap.ArgumentParser): import torch import yaml import json + import uuid + from datetime import datetime import matplotlib.pyplot as plt import matplotlib matplotlib.use('Agg') # Use non-interactive backend @@ -126,6 +141,47 @@ def run_tx_heatmap(args: ap.ArgumentParser): logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + def _prepare_for_serialization(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().copy() + if isinstance(obj, dict): + return {k: _prepare_for_serialization(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_serialization(v) for v in obj] + return obj + + def _save_numpy_snapshot(obj, path, description=None): + serializable = _prepare_for_serialization(obj) + try: + np.save(path, serializable, allow_pickle=True) + if description: + logger.info("Saved %s to %s", description, path) + else: + logger.info("Saved snapshot to %s", path) + except Exception as e: + log_desc = description or "snapshot" + logger.warning("Failed to save %s to %s: %s", log_desc, path, e) + + def _clone_core_cells(src): + cloned = {} + for k, v in src.items(): + if isinstance(v, torch.Tensor): + cloned[k] = v.clone() + else: + try: + cloned[k] = copy.deepcopy(v) + except Exception: + cloned[k] = v + return cloned + + results_dir_default = ( + args.results_dir + if args.results_dir is not None + else os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + ) + + snapshots_only = getattr(args, "heatmap_snapshots_only", False) + torch.multiprocessing.set_sharing_strategy("file_system") def run_test_time_finetune(model, dataloader, ft_epochs, control_pert, device): @@ -287,7 +343,7 @@ def load_config(cfg_path: str) -> dict: logger.warning("No test dataloader found. 
Exiting.") sys.exit(0) - logger.info("Preparing a fixed batch of 64 control cells (core_cells) and enumerating perturbations...") + logger.info("Preparing a fixed batch of 256 control cells (core_cells) and enumerating perturbations...") # Helper to normalize values to python lists def _to_list(value): @@ -322,8 +378,8 @@ def _to_list(value): else: logger.warning("Control perturbation not observed in test loader perturbation names.") - # Build a single fixed batch of exactly 64 control cells - target_core_n = 64 + # Build a single fixed batch of exactly 256 control cells + target_core_n = 256 core_cells = None accum = {} @@ -345,7 +401,7 @@ def _append_field(store, key, value): # If no names provided in batch, skip (cannot verify control) continue - # Slice each tensor field by mask and accumulate until we have 64 + # Slice each tensor field by mask and accumulate until we have 256 current_count = 0 if "_count" not in accum else accum["_count"] take = min(target_core_n - current_count, int(mask.sum().item())) if take <= 0: @@ -376,7 +432,7 @@ def _append_field(store, key, value): if accum.get("_count", 0) < target_core_n: raise RuntimeError(f"Could not assemble {target_core_n} control cells for core_cells; gathered {accum.get('_count', 0)}.") - # Collate accumulated pieces into a single batch dict of length 64 + # Collate accumulated pieces into a single batch dict of length 256 core_cells = {} for k, parts in accum.items(): if k == "_count": @@ -391,7 +447,7 @@ def _append_field(store, key, value): for p in parts: merged.extend(_to_list(p)) val = merged - # Ensure final length == 64 + # Ensure final length == 256 if isinstance(val, torch.Tensor): core_cells[k] = val[:target_core_n] else: @@ -399,6 +455,10 @@ def _append_field(store, key, value): logger.info(f"Constructed core_cells batch with size {target_core_n}.") + os.makedirs(results_dir_default, exist_ok=True) + baseline_core_cells_path = os.path.join(results_dir_default, "core_cells_baseline.npy") + _save_numpy_snapshot(core_cells, baseline_core_cells_path, description="baseline core_cells batch (control cells)") + # Compute distributions for each position across ALL control cells in the test loader # Strategy: determine a 2D vector key from the first batch, then aggregate all control rows vector_key_candidates = ["ctrl_cell_emb", "pert_cell_emb", "X"] @@ -484,15 +544,52 @@ def apply_shift_to_core_cells(index: int, upregulate: bool): apply_shift_to_core_cells(index=int(args.shift_index), upregulate=(args.shift_direction == "up")) logger.info(f"Applied 2σ {'up' if args.shift_direction=='up' else 'down'} shift at index {int(args.shift_index)} across core_cells") - # Prepare output arrays sized by num_perts * 64 - # Keep all perturbations including control to be explicit + # Prepare perturbation ordering and, if needed, buffers for forward passes perts_order = list(unique_perts) - num_cells = len(perts_order) * target_core_n - output_dim = var_dims["output_dim"] - gene_dim = var_dims["gene_dim"] - hvg_dim = var_dims["hvg_dim"] - logger.info("Generating predictions: one forward pass per perturbation on core_cells...") + if snapshots_only: + logger.info("Heatmap snapshots flag set; skipping phase-one forward passes through the model.") + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + final_preds = None + final_reals = None + final_X_hvg = None + final_pert_cell_counts_preds = None + normal_preds_per_pert = {} + real_preds_per_pert = {} + store_raw_expression = False + else: + 
logger.info("Generating predictions: one forward pass per perturbation on core_cells...") + num_cells = len(perts_order) * target_core_n + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + + # Phase 1: Normal inference on all perturbations + final_preds = np.empty((num_cells, output_dim), dtype=np.float32) + final_reals = np.empty((num_cells, output_dim), dtype=np.float32) + + # Phase 2: Store normal predictions for distance computation + normal_preds_per_pert = {} # pert_name -> [256, output_dim] array + real_preds_per_pert = {} # pert_name -> [256, output_dim] array + + store_raw_expression = ( + data_module.embed_key is not None + and data_module.embed_key != "X_hvg" + and cfg["data"]["kwargs"]["output_space"] == "gene" + ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") + + final_X_hvg = None + final_pert_cell_counts_preds = None + if store_raw_expression: + # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. + if cfg["data"]["kwargs"]["output_space"] == "gene": + final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) + if cfg["data"]["kwargs"]["output_space"] == "all": + final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) device = next(model.parameters()).device # Prepare perturbation one-hot/embedding map for the pert encoder @@ -525,30 +622,6 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): vec = torch.zeros(pert_dim) return vec.float().unsqueeze(0).repeat(length, 1).to(device) - # Phase 1: Normal inference on all perturbations - final_preds = np.empty((num_cells, output_dim), dtype=np.float32) - final_reals = np.empty((num_cells, output_dim), dtype=np.float32) - - # Phase 2: Store normal predictions for distance computation - normal_preds_per_pert = {} # pert_name -> [64, output_dim] array - - store_raw_expression = ( - data_module.embed_key is not None - and data_module.embed_key != "X_hvg" - and cfg["data"]["kwargs"]["output_space"] == "gene" - ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") - - final_X_hvg = None - final_pert_cell_counts_preds = None - if store_raw_expression: - # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. 
- if cfg["data"]["kwargs"]["output_space"] == "gene": - final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) - final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) - if cfg["data"]["kwargs"]["output_space"] == "all": - final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) - final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) - current_idx = 0 # Initialize aggregation variables directly @@ -558,96 +631,119 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): all_pert_barcodes = [] all_ctrl_barcodes = [] - with torch.no_grad(): - for p_idx, pert in enumerate(tqdm(perts_order, desc="Predicting", unit="pert")): - # Build a batch by copying core_cells and swapping perturbation - batch = {} - for k, v in core_cells.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.to(device) - else: - batch[k] = list(v) - - # Overwrite perturbation fields to target pert - if "pert_name" in batch: - batch["pert_name"] = [pert for _ in range(target_core_n)] - # Best-effort: update any index fields if present and mapping exists - try: - if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): - idx_val = int(data_module.get_pert_index(pert)) - batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) - except Exception: - pass + if not snapshots_only: + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="Predicting", unit="pert")): + # Build a batch by copying core_cells and swapping perturbation + batch = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) - # Ensure perturbation embedding is set for the encoder - batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + # Overwrite perturbation fields to target pert + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + # Best-effort: update any index fields if present and mapping exists + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass - batch_preds = model.predict_step(batch, p_idx, padded=False) + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) - # Extract metadata and data directly from batch_preds - # Handle pert_name - batch_pert_names = [] - if isinstance(batch_preds["pert_name"], list): - all_pert_names.extend(batch_preds["pert_name"]) - batch_pert_names = batch_preds["pert_name"] - else: - all_pert_names.append(batch_preds["pert_name"]) - batch_pert_names = [batch_preds["pert_name"]] + batch_preds = model.predict_step(batch, p_idx, padded=False) - if "pert_cell_barcode" in batch_preds: - if isinstance(batch_preds["pert_cell_barcode"], list): - all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) - all_ctrl_barcodes.extend(batch_preds.get("ctrl_cell_barcode", [None] * len(batch_preds["pert_cell_barcode"])) ) + # Extract metadata and data directly from batch_preds + # Handle pert_name + batch_pert_names = [] + if isinstance(batch_preds["pert_name"], list): + all_pert_names.extend(batch_preds["pert_name"]) + batch_pert_names = batch_preds["pert_name"] else: - all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) - all_ctrl_barcodes.append(batch_preds.get("ctrl_cell_barcode", None)) + 
all_pert_names.append(batch_preds["pert_name"]) + batch_pert_names = [batch_preds["pert_name"]] - # Handle celltype_name - if isinstance(batch_preds["celltype_name"], list): - all_celltypes.extend(batch_preds["celltype_name"]) - else: - all_celltypes.append(batch_preds["celltype_name"]) + if "pert_cell_barcode" in batch_preds: + if isinstance(batch_preds["pert_cell_barcode"], list): + all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.extend(batch_preds.get("ctrl_cell_barcode", [None] * len(batch_preds["pert_cell_barcode"])) ) + else: + all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.append(batch_preds.get("ctrl_cell_barcode", None)) - # Handle gem_group - if isinstance(batch_preds["batch"], list): - all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) - elif isinstance(batch_preds["batch"], torch.Tensor): - all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) - else: - all_gem_groups.append(str(batch_preds["batch"])) + # Handle celltype_name + if isinstance(batch_preds["celltype_name"], list): + all_celltypes.extend(batch_preds["celltype_name"]) + else: + all_celltypes.append(batch_preds["celltype_name"]) - batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) - batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) - batch_size = batch_pred_np.shape[0] - final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np - final_reals[current_idx : current_idx + batch_size, :] = batch_real_np - - # Store normal predictions for this perturbation for distance computation - normal_preds_per_pert[pert] = batch_pred_np.copy() - - current_idx += batch_size + # Handle gem_group + if isinstance(batch_preds["batch"], list): + all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) + elif isinstance(batch_preds["batch"], torch.Tensor): + all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) + else: + all_gem_groups.append(str(batch_preds["batch"])) - # Handle X_hvg for HVG space ground truth - if final_X_hvg is not None: - batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) - final_X_hvg[current_idx - batch_size : current_idx, :] = batch_real_gene_np + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + batch_size = batch_pred_np.shape[0] + final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np + final_reals[current_idx : current_idx + batch_size, :] = batch_real_np + + # Store normal predictions for this perturbation for distance computation + normal_preds_per_pert[pert] = batch_pred_np.copy() + real_preds_per_pert[pert] = batch_real_np.copy() + + current_idx += batch_size + + # Handle X_hvg for HVG space ground truth + if final_X_hvg is not None: + batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) + final_X_hvg[current_idx - batch_size : current_idx, :] = batch_real_gene_np - # Handle decoded gene predictions if available - if final_pert_cell_counts_preds is not None: - batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) - final_pert_cell_counts_preds[current_idx - batch_size : current_idx, :] = batch_gene_pred_np + # Handle decoded gene predictions if available + if final_pert_cell_counts_preds is not None: + batch_gene_pred_np = 
batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) + final_pert_cell_counts_preds[current_idx - batch_size : current_idx, :] = batch_gene_pred_np - logger.info("Phase 1 complete: Normal inference on all perturbations.") + logger.info("Phase 1 complete: Normal inference on all perturbations.") # Phase 2: Run inference with GO MF pathway groups upregulated (only if requested) - if args.test_time_heat_map: + run_phase_one_only = phase_one_only or getattr(args, "phase_one_only", False) + + if run_phase_one_only and not snapshots_only: + os.makedirs(results_dir_default, exist_ok=True) + if not snapshots_only: + real_preds_path = os.path.join(results_dir_default, "core_cells_real_preds_per_pert.npy") + try: + np.save(real_preds_path, real_preds_per_pert, allow_pickle=True) + logger.info( + "Saved real perturbed core cell embeddings for %d perturbations to %s", + len(real_preds_per_pert), + real_preds_path, + ) + except Exception as e: + logger.error("Failed to save core cell real predictions to %s: %s", real_preds_path, e) + raise + return + + if args.test_time_heat_map or snapshots_only: logger.info("Phase 2: Loading GO MF pathway annotations and running pathway-based upregulation...") - if args.results_dir is not None: - results_dir = args.results_dir - else: - results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) - os.makedirs(results_dir, exist_ok=True) + results_dir = results_dir_default + + # Ensure unique heatmap directory per invocation to avoid overwriting prior outputs + timestamp = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + unique_suffix = f"{timestamp}_{uuid.uuid4().hex[:8]}" + heatmap_results_dir = os.path.join(results_dir, "heatmap_runs", unique_suffix) + + os.makedirs(heatmap_results_dir, exist_ok=True) annotation_ext = os.path.splitext(args.annotation_path)[1].lower() annotation_source_type = "unknown" annotation_label = args.annotation_field @@ -781,34 +877,30 @@ def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): upregulated_preds_path = None upregulated_preds_memmap = None - if num_pathways == 0: - logger.warning("No pathways passed filtering; skipping upregulated prediction storage.") - else: - try: - upregulated_preds_path = os.path.join( - results_dir, - f"{field_suffix}_pathway_upregulated_preds.npy", - ) - upregulated_preds_memmap = np.memmap( - upregulated_preds_path, - dtype=np.float32, - mode="w+", - shape=(num_pathways, len(perts_order), target_core_n, output_dim), - ) - except Exception as e: - logger.warning("Failed to initialize storage for upregulated predictions: %s", e) - upregulated_preds_path = None - upregulated_preds_memmap = None + if not snapshots_only: + if num_pathways == 0: + logger.warning("No pathways passed filtering; skipping upregulated prediction storage.") + else: + try: + upregulated_preds_path = os.path.join( + heatmap_results_dir, + f"{field_suffix}_pathway_upregulated_preds.npy", + ) + upregulated_preds_memmap = np.memmap( + upregulated_preds_path, + dtype=np.float32, + mode="w+", + shape=(num_pathways, len(perts_order), target_core_n, output_dim), + ) + except Exception as e: + logger.warning("Failed to initialize storage for upregulated predictions: %s", e) + upregulated_preds_path = None + upregulated_preds_memmap = None # Create a copy of core_cells for upregulation experiments - original_core_cells = {} - for k, v in core_cells.items(): - if isinstance(v, torch.Tensor): - original_core_cells[k] = v.clone() - else: - original_core_cells[k] = v.copy() if 
hasattr(v, 'copy') else v + original_core_cells = _clone_core_cells(core_cells) - def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, target_norm: float = 2.0): + def apply_pathway_shift_to_core_cells(cell_batch: dict, gene_indices: list, upregulate: bool, target_norm: float = 2.0): """Apply shift to multiple gene indices with equivalent euclidean norm across pathways. This function ensures that all pathways receive the same euclidean norm perturbation: @@ -819,11 +911,11 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ - gene_indices: list of 0-indexed gene positions - upregulate: True for positive shift, False for negative shift - target_norm: target euclidean norm for the perturbation (default: 2.0) - Operates in-place on the tensor stored at distributions['key'] inside core_cells. + Operates in-place on the tensor stored at distributions['key'] inside the provided cell_batch. """ - nonlocal core_cells, distributions + nonlocal distributions key = distributions['key'] - tensor = core_cells[key] + tensor = cell_batch[key] if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") @@ -858,54 +950,82 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ for idx in raw_shifts.keys(): shift_value = uniform_shift if upregulate else -uniform_shift tensor[:, idx] = tensor[:, idx] + shift_value + + def compute_pathway_core_cell_snapshots(base_core_cells: dict, pathways: dict) -> list: + snapshots = [] + for pathway_name, gene_indices in pathways.items(): + shifted_cells = _clone_core_cells(base_core_cells) + apply_pathway_shift_to_core_cells(shifted_cells, gene_indices, upregulate=True) + snapshots.append( + { + "pathway_name": pathway_name, + "gene_indices": list(gene_indices), + "core_cells": shifted_cells, + } + ) + return snapshots + + shifted_core_cells_path = os.path.join(heatmap_results_dir, f"{field_suffix}_core_cells_upregulated.npy") + pathway_core_cells_snapshots = compute_pathway_core_cell_snapshots(original_core_cells, filtered_pathways) + _save_numpy_snapshot( + pathway_core_cells_snapshots, + shifted_core_cells_path, + description=f"core_cells upregulated snapshots ({len(pathway_core_cells_snapshots)} pathways)", + ) + if len(pathway_core_cells_snapshots) == 0: + logger.warning("No pathway core cell snapshots generated (0 pathways passed filtering).") + + if not snapshots_only: + with torch.no_grad(): + for pathway_idx, snapshot in enumerate( + tqdm(pathway_core_cells_snapshots, desc="Upregulating pathways", unit="pathway") + ): + core_cells_upregulated = snapshot["core_cells"] + # Run inference for all perturbations with this pathway upregulated + for p_idx, pert in enumerate(perts_order): + # Build batch by copying upregulated core_cells + batch = {} + for k, v in core_cells_upregulated.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + # Get predictions with upregulated pathway + batch_preds = 
model.predict_step(batch, p_idx, padded=False) + upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + + if upregulated_preds_memmap is not None: + upregulated_preds_memmap[pathway_idx, p_idx, :, :] = upregulated_preds + + # Compute euclidean distance between normal and upregulated predictions + normal_preds = normal_preds_per_pert[pert] # [256, output_dim] + distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 256 cells + heatmap_distances[pathway_idx, p_idx] = distance - with torch.no_grad(): - for pathway_idx, (pathway_name, gene_indices) in enumerate(tqdm(filtered_pathways.items(), desc="Upregulating pathways", unit="pathway")): - # Apply downregulation to all genes in this pathway - apply_pathway_shift_to_core_cells(gene_indices, upregulate=True) - - # Run inference for all perturbations with this pathway upregulated - for p_idx, pert in enumerate(perts_order): - # Build batch by copying upregulated core_cells - batch = {} - for k, v in core_cells.items(): - if isinstance(v, torch.Tensor): - batch[k] = v.to(device) - else: - batch[k] = list(v) - - # Overwrite perturbation fields - if "pert_name" in batch: - batch["pert_name"] = [pert for _ in range(target_core_n)] - try: - if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): - idx_val = int(data_module.get_pert_index(pert)) - batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) - except Exception: - pass - - # Ensure perturbation embedding is set for the encoder - batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) - - # Get predictions with upregulated pathway - batch_preds = model.predict_step(batch, p_idx, padded=False) - upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) - - if upregulated_preds_memmap is not None: - upregulated_preds_memmap[pathway_idx, p_idx, :, :] = upregulated_preds - - # Compute euclidean distance between normal and upregulated predictions - normal_preds = normal_preds_per_pert[pert] # [64, output_dim] - distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 64 cells - heatmap_distances[pathway_idx, p_idx] = distance - - # Restore original core_cells for next pathway - for k, v in original_core_cells.items(): - if isinstance(v, torch.Tensor): - core_cells[k] = v.clone() - else: - core_cells[k] = v.copy() if hasattr(v, 'copy') else v - + logger.info( + "Phase 2 core cell snapshots ready for %d pathways from annotation source '%s'.", + len(pathway_core_cells_snapshots), + annotation_label or annotation_source_type, + ) + + if snapshots_only: + logger.info("Snapshots-only mode: skipping distance computations, heatmap data, and visualization generation.") + return + logger.info( "Phase 2 complete: Upregulated inference for all pathways from annotation source '%s'.", annotation_label or annotation_source_type, @@ -914,15 +1034,15 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ if upregulated_preds_memmap is not None: upregulated_preds_memmap.flush() logger.info(f"Saved upregulated prediction tensors to {upregulated_preds_path}") - + # Create filename based on annotation field # Save heatmap data try: - heatmap_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") + heatmap_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") np.save(heatmap_path, heatmap_distances) # Save pathway information - pathway_info_path = 
os.path.join(results_dir, f"{field_suffix}_pathways_info.json") + pathway_info_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathways_info.json") pathway_info = { "pathway_names": pathway_names, "pathway_to_genes": {pathway: genes for pathway, genes in filtered_pathways.items()}, @@ -941,13 +1061,13 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ ), "perturbations": perts_order, "pathway_names": pathway_names, - "distance_type": "mean_euclidean_norm_across_64_cells", + "distance_type": "mean_euclidean_norm_across_256_cells", "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene", "annotation_field": annotation_label if annotation_source_type != "json" else None, "annotation_source_type": annotation_source_type, "upregulated_preds_path": upregulated_preds_path, } - heatmap_meta_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") + heatmap_meta_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") with open(heatmap_meta_path, "w") as f: json.dump(heatmap_meta, f, indent=2) @@ -964,9 +1084,11 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ try: # Determine output path for heatmap image if args.heatmap_output_path is not None: - heatmap_img_path = args.heatmap_output_path + # If user provided a path, make it unique per run as well + base, ext = os.path.splitext(args.heatmap_output_path) + heatmap_img_path = f"{base}_{unique_suffix}{ext or '.png'}" else: - heatmap_img_path = os.path.join(results_dir, f"{field_suffix}_pathway_upregulation_heatmap.png") + heatmap_img_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.png") # Ensure directory exists os.makedirs(os.path.dirname(heatmap_img_path), exist_ok=True) @@ -1095,7 +1217,7 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ ) # Save the AnnData objects - results_dir = os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + results_dir = results_dir_default os.makedirs(results_dir, exist_ok=True) adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") adata_real_path = os.path.join(results_dir, "adata_real.h5ad") @@ -1172,3 +1294,8 @@ def apply_pathway_shift_to_core_cells(gene_indices: list, upregulate: bool, targ else {}, skip_metrics=["pearson_edistance", "clustering_agreement"], ) + + +def save_core_cells_real_preds(args: ap.ArgumentParser): + """Run only phase one of the heatmap pipeline and persist real core-cell embeddings per perturbation.""" + return run_tx_heatmap(args, phase_one_only=True) diff --git a/src/state/_cli/_tx/_heatmap_og.py b/src/state/_cli/_tx/_heatmap_og.py new file mode 100644 index 00000000..5e9a772c --- /dev/null +++ b/src/state/_cli/_tx/_heatmap_og.py @@ -0,0 +1,1301 @@ +import argparse as ap + + +def add_arguments_heatmap(parser: ap.ArgumentParser): + """ + CLI for pathway heatmap analysis with GO MF pathway upregulation. + """ + + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Path to the output_dir containing the config.yaml file that was saved during training.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default="last.ckpt", + help="Checkpoint filename. Default is 'last.ckpt'. 
Relative to the output directory.", + ) + + parser.add_argument( + "--test-time-finetune", + type=int, + default=0, + help="If >0, run test-time fine-tuning for the specified number of epochs on only control cells.", + ) + + parser.add_argument( + "--profile", + type=str, + default="full", + choices=["full", "minimal", "de", "anndata"], + help="run all metrics, minimal, only de metrics, or only output adatas", + ) + + parser.add_argument( + "--predict-only", + action="store_true", + help="If set, only run prediction without evaluation metrics.", + ) + + parser.add_argument( + "--shared-only", + action="store_true", + help=("If set, restrict predictions/evaluation to perturbations shared between train and test (train ∩ test)."), + ) + + parser.add_argument( + "--eval-train-data", + action="store_true", + help="If set, evaluate the model on the training data rather than on the test data.", + ) + + # Optional: apply directional shift on a chosen index using control distributions + parser.add_argument( + "--shift-index", + type=int, + default=None, + help="If set, apply a ±2σ shift to this index across core_cells using control distributions.", + ) + parser.add_argument( + "--shift-direction", + type=str, + default=None, + choices=["up", "down"], + help="Direction for the 2σ shift applied to --shift-index. Requires --shift-index.", + ) + + parser.add_argument( + "--test-time-heat-map", + action="store_true", + help="If set, run test-time heat map analysis with position upregulation.", + ) + parser.add_argument( + "--phase-one-only", + action="store_true", + help="If set, run only phase one to save core cell real embeddings per perturbation.", + ) + parser.add_argument( + "--heatmap-output-path", + type=str, + default=None, + help="Path to save the matplotlib heatmap visualization. If not provided, defaults to /position_upregulation_heatmap.png", + ) + parser.add_argument( + "--results-dir", + type=str, + default=None, + help="Directory to save results. If not provided, defaults to /eval_", + ) + parser.add_argument( + "--heatmap-snapshots-only", + action="store_true", + help=( + "Compute and persist pathway-upregulated core cell batches without running model inference or generating heatmaps." + ), + ) + parser.add_argument( + "--annotation-path", + type=str, + default="/home/dhruvgautam/annotations/replogle_go_annotations.pkl", #/home/dhruvgautam/annotations/var_dims_gene_go_annotations.json + help="Path to the hvg gene annotations file.", + ) + parser.add_argument( + "--annotation-field", + type=str, + default="go_reactome_paths", + help=( + "Field name in structured annotation data to use for pathway grouping (e.g., 'go_cc_paths'). " + "Ignored when loading JSON files that map genes directly to pathways." 
+ ), + ) + + +def run_tx_heatmap(args: ap.ArgumentParser, *, phase_one_only: bool = False): + import logging + import os + import sys + import copy + + import anndata + import lightning.pytorch as pl + import numpy as np + import pandas as pd + import torch + import yaml + import json + import uuid + from datetime import datetime + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') # Use non-interactive backend + + # Cell-eval for metrics computation + from cell_eval import MetricsEvaluator + from cell_eval.utils import split_anndata_on_celltype + from cell_load.data_modules import PerturbationDataModule + from tqdm import tqdm + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + def _prepare_for_serialization(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().copy() + if isinstance(obj, dict): + return {k: _prepare_for_serialization(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_serialization(v) for v in obj] + return obj + + def _save_numpy_snapshot(obj, path, description=None): + serializable = _prepare_for_serialization(obj) + try: + np.save(path, serializable, allow_pickle=True) + if description: + logger.info("Saved %s to %s", description, path) + else: + logger.info("Saved snapshot to %s", path) + except Exception as e: + log_desc = description or "snapshot" + logger.warning("Failed to save %s to %s: %s", log_desc, path, e) + + def _clone_core_cells(src): + cloned = {} + for k, v in src.items(): + if isinstance(v, torch.Tensor): + cloned[k] = v.clone() + else: + try: + cloned[k] = copy.deepcopy(v) + except Exception: + cloned[k] = v + return cloned + + results_dir_default = ( + args.results_dir + if args.results_dir is not None + else os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + ) + + snapshots_only = getattr(args, "heatmap_snapshots_only", False) + + torch.multiprocessing.set_sharing_strategy("file_system") + + def run_test_time_finetune(model, dataloader, ft_epochs, control_pert, device): + """ + Perform test-time fine-tuning on only control cells. + """ + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + logger.info(f"Starting test-time fine-tuning for {ft_epochs} epoch(s) on control cells only.") + for epoch in range(ft_epochs): + epoch_losses = [] + pbar = tqdm(dataloader, desc=f"Finetune epoch {epoch + 1}/{ft_epochs}", leave=True) + for batch in pbar: + # Check if this batch contains control cells + first_pert = ( + batch["pert_name"][0] if isinstance(batch["pert_name"], list) else batch["pert_name"][0].item() + ) + if first_pert != control_pert: + continue + + # Move batch data to device + batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()} + + optimizer.zero_grad() + loss = model.training_step(batch, batch_idx=0, padded=False) + if loss is None: + continue + loss.backward() + optimizer.step() + epoch_losses.append(loss.item()) + pbar.set_postfix(loss=f"{loss.item():.4f}") + + mean_loss = np.mean(epoch_losses) if epoch_losses else float("nan") + logger.info(f"Finetune epoch {epoch + 1}/{ft_epochs}, mean loss: {mean_loss}") + model.eval() + + def load_config(cfg_path: str) -> dict: + """Load config from the YAML file that was dumped during training.""" + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + return cfg + + # 1. 
Load the config + config_path = os.path.join(args.output_dir, "config.yaml") + cfg = load_config(config_path) + logger.info(f"Loaded config from {config_path}") + + # 2. Find run output directory & load data module + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + if not os.path.isabs(run_output_dir): + run_output_dir = os.path.abspath(run_output_dir) + + if not os.path.exists(run_output_dir): + inferred_run_dir = args.output_dir + if os.path.exists(inferred_run_dir): + logger.warning( + "Run directory %s not found; falling back to config directory %s", + run_output_dir, + inferred_run_dir, + ) + run_output_dir = inferred_run_dir + else: + raise FileNotFoundError( + "Could not resolve run directory. Checked: %s and %s" % (run_output_dir, inferred_run_dir) + ) + + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}?") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + + # Seed everything + pl.seed_everything(cfg["training"]["train_seed"]) + + # 3. Load the trained model + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + checkpoint_path = os.path.join(checkpoint_dir, args.checkpoint) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Could not find checkpoint at {checkpoint_path}.\nSpecify a correct checkpoint filename with --checkpoint." + ) + logger.info("Loading model from %s", checkpoint_path) + + # Determine model class and load + model_class_name = cfg["model"]["name"] + model_kwargs = cfg["model"]["kwargs"] + + # Import the correct model class + if model_class_name.lower() == "embedsum": + from ...tx.models.embed_sum import EmbedSumPerturbationModel + + ModelClass = EmbedSumPerturbationModel + elif model_class_name.lower() == "old_neuralot": + from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel + + ModelClass = OldNeuralOTPerturbationModel + elif model_class_name.lower() in ["neuralot", "pertsets", "state"]: + from ...tx.models.state_transition import StateTransitionPerturbationModel + + ModelClass = StateTransitionPerturbationModel + + elif model_class_name.lower() in ["globalsimplesum", "perturb_mean"]: + from ...tx.models.perturb_mean import PerturbMeanPerturbationModel + + ModelClass = PerturbMeanPerturbationModel + elif model_class_name.lower() in ["celltypemean", "context_mean"]: + from ...tx.models.context_mean import ContextMeanPerturbationModel + + ModelClass = ContextMeanPerturbationModel + elif model_class_name.lower() == "decoder_only": + from ...tx.models.decoder_only import DecoderOnlyPerturbationModel + + ModelClass = DecoderOnlyPerturbationModel + else: + raise ValueError(f"Unknown model class: {model_class_name}") + + var_dims = data_module.get_var_dims() + model_init_kwargs = { + "input_dim": var_dims["input_dim"], + "hidden_dim": model_kwargs["hidden_dim"], + "gene_dim": var_dims["gene_dim"], + "hvg_dim": var_dims["hvg_dim"], + "output_dim": var_dims["output_dim"], + "pert_dim": var_dims["pert_dim"], + **model_kwargs, + } + + model = ModelClass.load_from_checkpoint(checkpoint_path, **model_init_kwargs) + model.eval() + logger.info("Model loaded successfully.") + + # 4. 
Test-time fine-tuning if requested + data_module.batch_size = 1 + if args.test_time_finetune > 0: + control_pert = data_module.get_control_pert() + if args.eval_train_data: + test_loader = data_module.train_dataloader(test=True) + else: + test_loader = data_module.test_dataloader() + + run_test_time_finetune( + model, test_loader, args.test_time_finetune, control_pert, device=next(model.parameters()).device + ) + logger.info("Test-time fine-tuning complete.") + + # 5. Run inference on test set + data_module.setup(stage="test") + if args.eval_train_data: + scan_loader = data_module.train_dataloader(test=True) + else: + scan_loader = data_module.test_dataloader() + + if scan_loader is None: + logger.warning("No test dataloader found. Exiting.") + sys.exit(0) + + logger.info("Preparing a fixed batch of 64 control cells (core_cells) and enumerating perturbations...") + + # Helper to normalize values to python lists + def _to_list(value): + if isinstance(value, list): + return value + if isinstance(value, torch.Tensor): + try: + return [x.item() if x.dim() == 0 else x for x in value] + except Exception: + return value.tolist() + return [value] + + control_pert = data_module.get_control_pert() + + # Collect unique perturbation names from the loader without running the model + unique_perts = [] + seen_perts = set() + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + for n in names: + if isinstance(n, torch.Tensor): + try: + n = n.item() + except Exception: + n = str(n) + if n not in seen_perts: + seen_perts.add(n) + unique_perts.append(n) + + if control_pert in seen_perts: + logger.info(f"Found {len(unique_perts)} total perturbations (including control '{control_pert}').") + else: + logger.warning("Control perturbation not observed in test loader perturbation names.") + + # Build a single fixed batch of exactly 64 control cells + target_core_n = 64 + core_cells = None + accum = {} + + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + # Iterate again to collect control cells only + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + # Build a mask for control entries when possible + mask = None + if len(names) > 0: + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + else: + # If no names provided in batch, skip (cannot verify control) + continue + + # Slice each tensor field by mask and accumulate until we have 64 + current_count = 0 if "_count" not in accum else accum["_count"] + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + # Identify keys to carry forward; prefer tensors and essential metadata + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + try: + vsel = v[mask][:take].detach().clone() + except Exception: + # fallback: try first dimension slice + vsel = v[:take].detach().clone() + _append_field(accum, k, vsel) + else: + # For non-tensor fields, convert to list and slice by mask when possible + vals = _to_list(v) + try: + selected_vals = [vals[i] for i, m in enumerate(mask.tolist()) if m][:take] + except Exception: + selected_vals = vals[:take] + _append_field(accum, k, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError(f"Could not assemble {target_core_n} control cells for core_cells; gathered {accum.get('_count', 0)}.") + + # 
Collate accumulated pieces into a single batch dict of length 64 + core_cells = {} + for k, parts in accum.items(): + if k == "_count": + continue + if len(parts) == 1: + val = parts[0] + else: + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for p in parts: + merged.extend(_to_list(p)) + val = merged + # Ensure final length == 64 + if isinstance(val, torch.Tensor): + core_cells[k] = val[:target_core_n] + else: + core_cells[k] = _to_list(val)[:target_core_n] + + logger.info(f"Constructed core_cells batch with size {target_core_n}.") + + os.makedirs(results_dir_default, exist_ok=True) + baseline_core_cells_path = os.path.join(results_dir_default, "core_cells_baseline.npy") + _save_numpy_snapshot(core_cells, baseline_core_cells_path, description="baseline core_cells batch (control cells)") + + # Compute distributions for each position across ALL control cells in the test loader + # Strategy: determine a 2D vector key from the first batch, then aggregate all control rows + vector_key_candidates = ["ctrl_cell_emb", "pert_cell_emb", "X"] + dist_source_key = None + # Find key by peeking one batch + for b in scan_loader: + for cand in vector_key_candidates: + if cand in b and isinstance(b[cand], torch.Tensor) and b[cand].dim() == 2: + dist_source_key = cand + break + if dist_source_key is None: + # fallback: any 2D tensor + for k, v in b.items(): + if isinstance(v, torch.Tensor) and v.dim() == 2: + dist_source_key = k + break + # break after first batch inspected + break + if dist_source_key is None: + raise RuntimeError("Could not find a 2D tensor in test loader batches to compute per-dimension distributions.") + + # Aggregate all control rows for the chosen key + control_rows = [] + for batch in scan_loader: + names = _to_list(batch.get("pert_name", [])) + if len(names) == 0: + continue + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + vec = batch.get(dist_source_key, None) + if isinstance(vec, torch.Tensor) and vec.dim() == 2: + try: + control_rows.append(vec[mask].detach().cpu().float()) + except Exception: + # fallback: take leading rows equal to mask sum + take = int(mask.sum().item()) + if take > 0: + control_rows.append(vec[:take].detach().cpu().float()) + + if len(control_rows) == 0: + raise RuntimeError("No control rows found to compute distributions.") + + control_vectors_all = torch.cat(control_rows, dim=0) # [Nc, D] + D = control_vectors_all.shape[1] + if D != 2000: + logger.warning(f"Expected vector dimension 2000; found {D}. Proceeding with {D} dimensions.") + + control_mean = control_vectors_all.mean(dim=0) + control_std = control_vectors_all.std(dim=0, unbiased=False).clamp_min(1e-8) + + # Save distributions to results directory later; keep in scope for optional shifting + distributions = { + "key": dist_source_key, + "mean": control_mean.numpy(), + "std": control_std.numpy(), + "dim": int(D), + "num_cells": int(control_vectors_all.shape[0]), + } + + def apply_shift_to_core_cells(index: int, upregulate: bool): + """Apply ±2σ shift at a single index across all vectors in core_cells. + + - index: integer in [0, D) + - upregulate: True for +2σ, False for -2σ + Operates in-place on the tensor stored at distributions['key'] inside core_cells. 
+ """ + nonlocal core_cells, distributions + if index < 0 or index >= distributions["dim"]: + raise ValueError(f"Index {index} is out of bounds for dimension {distributions['dim']}") + shift_value = (2.0 if upregulate else -2.0) * float(distributions["std"][index]) + key = distributions["key"] + tensor = core_cells[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + tensor[:, index] = tensor[:, index] + shift_value + core_cells[key] = tensor + + # Optionally apply shift based on CLI flags before running inference + if args.shift_index is not None: + if args.shift_direction is None: + raise ValueError("--shift-direction is required when --shift-index is provided") + apply_shift_to_core_cells(index=int(args.shift_index), upregulate=(args.shift_direction == "up")) + logger.info(f"Applied 2σ {'up' if args.shift_direction=='up' else 'down'} shift at index {int(args.shift_index)} across core_cells") + + # Prepare perturbation ordering and, if needed, buffers for forward passes + perts_order = list(unique_perts) + + if snapshots_only: + logger.info("Heatmap snapshots flag set; skipping phase-one forward passes through the model.") + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + final_preds = None + final_reals = None + final_X_hvg = None + final_pert_cell_counts_preds = None + normal_preds_per_pert = {} + real_preds_per_pert = {} + store_raw_expression = False + else: + logger.info("Generating predictions: one forward pass per perturbation on core_cells...") + num_cells = len(perts_order) * target_core_n + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + + # Phase 1: Normal inference on all perturbations + final_preds = np.empty((num_cells, output_dim), dtype=np.float32) + final_reals = np.empty((num_cells, output_dim), dtype=np.float32) + + # Phase 2: Store normal predictions for distance computation + normal_preds_per_pert = {} # pert_name -> [64, output_dim] array + real_preds_per_pert = {} # pert_name -> [64, output_dim] array + + store_raw_expression = ( + data_module.embed_key is not None + and data_module.embed_key != "X_hvg" + and cfg["data"]["kwargs"]["output_space"] == "gene" + ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") + + final_X_hvg = None + final_pert_cell_counts_preds = None + if store_raw_expression: + # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. 
+ if cfg["data"]["kwargs"]["output_space"] == "gene": + final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) + if cfg["data"]["kwargs"]["output_space"] == "all": + final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) + device = next(model.parameters()).device + + # Prepare perturbation one-hot/embedding map for the pert encoder + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + try: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + else: + logger.warning(f"pert_onehot_map.pt not found at {map_path}; proceeding without explicit pert_emb overrides") + pert_onehot_map = {} + except Exception as e: + logger.warning(f"Failed to load pert_onehot_map.pt: {e}") + pert_onehot_map = {} + + def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): + vec = None + try: + vec = pert_onehot_map.get(pert_name, None) + if vec is None and control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + except Exception: + vec = None + if vec is None: + # Fallback to zeros with model.pert_dim if mapping is unavailable + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("Could not determine pert_dim to build pert_emb") + vec = torch.zeros(pert_dim) + return vec.float().unsqueeze(0).repeat(length, 1).to(device) + + current_idx = 0 + + # Initialize aggregation variables directly + all_pert_names = [] + all_celltypes = [] + all_gem_groups = [] + all_pert_barcodes = [] + all_ctrl_barcodes = [] + + if not snapshots_only: + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="Predicting", unit="pert")): + # Build a batch by copying core_cells and swapping perturbation + batch = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields to target pert + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + # Best-effort: update any index fields if present and mapping exists + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + batch_preds = model.predict_step(batch, p_idx, padded=False) + + # Extract metadata and data directly from batch_preds + # Handle pert_name + batch_pert_names = [] + if isinstance(batch_preds["pert_name"], list): + all_pert_names.extend(batch_preds["pert_name"]) + batch_pert_names = batch_preds["pert_name"] + else: + all_pert_names.append(batch_preds["pert_name"]) + batch_pert_names = [batch_preds["pert_name"]] + + if "pert_cell_barcode" in batch_preds: + if isinstance(batch_preds["pert_cell_barcode"], list): + all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.extend(batch_preds.get("ctrl_cell_barcode", [None] * len(batch_preds["pert_cell_barcode"])) ) + else: + all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.append(batch_preds.get("ctrl_cell_barcode", 
None)) + + # Handle celltype_name + if isinstance(batch_preds["celltype_name"], list): + all_celltypes.extend(batch_preds["celltype_name"]) + else: + all_celltypes.append(batch_preds["celltype_name"]) + + # Handle gem_group + if isinstance(batch_preds["batch"], list): + all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) + elif isinstance(batch_preds["batch"], torch.Tensor): + all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) + else: + all_gem_groups.append(str(batch_preds["batch"])) + + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + batch_size = batch_pred_np.shape[0] + final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np + final_reals[current_idx : current_idx + batch_size, :] = batch_real_np + + # Store normal predictions for this perturbation for distance computation + normal_preds_per_pert[pert] = batch_pred_np.copy() + real_preds_per_pert[pert] = batch_real_np.copy() + + current_idx += batch_size + + # Handle X_hvg for HVG space ground truth + if final_X_hvg is not None: + batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) + final_X_hvg[current_idx - batch_size : current_idx, :] = batch_real_gene_np + + # Handle decoded gene predictions if available + if final_pert_cell_counts_preds is not None: + batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) + final_pert_cell_counts_preds[current_idx - batch_size : current_idx, :] = batch_gene_pred_np + + logger.info("Phase 1 complete: Normal inference on all perturbations.") + + # Phase 2: Run inference with GO MF pathway groups upregulated (only if requested) + run_phase_one_only = phase_one_only or getattr(args, "phase_one_only", False) + + if run_phase_one_only and not snapshots_only: + os.makedirs(results_dir_default, exist_ok=True) + if not snapshots_only: + real_preds_path = os.path.join(results_dir_default, "core_cells_real_preds_per_pert.npy") + try: + np.save(real_preds_path, real_preds_per_pert, allow_pickle=True) + logger.info( + "Saved real perturbed core cell embeddings for %d perturbations to %s", + len(real_preds_per_pert), + real_preds_path, + ) + except Exception as e: + logger.error("Failed to save core cell real predictions to %s: %s", real_preds_path, e) + raise + return + + if args.test_time_heat_map or snapshots_only: + logger.info("Phase 2: Loading GO MF pathway annotations and running pathway-based upregulation...") + + results_dir = results_dir_default + + # Ensure unique heatmap directory per invocation to avoid overwriting prior outputs + timestamp = datetime.utcnow().strftime("%Y%m%dT%H%M%S") + unique_suffix = f"{timestamp}_{uuid.uuid4().hex[:8]}" + heatmap_results_dir = os.path.join(results_dir, "heatmap_runs", unique_suffix) + + os.makedirs(heatmap_results_dir, exist_ok=True) + annotation_ext = os.path.splitext(args.annotation_path)[1].lower() + annotation_source_type = "unknown" + annotation_label = args.annotation_field + field_suffix = ( + (annotation_label or "pathways").replace('_', '').lower() + if (annotation_label or "").strip() + else "pathways" + ) + + # Load gene annotations + import pickle + from collections import defaultdict + + pathway_to_genes = defaultdict(list) + gene_names = var_dims.get("gene_names") + gene_name_to_index = {str(name): idx for idx, name in enumerate(gene_names)} if gene_names is not None else {} + + if annotation_ext == ".json": + 
annotation_source_type = "json" + annotation_label = os.path.splitext(os.path.basename(args.annotation_path))[0] + field_suffix = annotation_label.replace('_', '').lower() or "pathways" + + with open(args.annotation_path, 'r') as f: + gene_annotations = json.load(f) + + if not isinstance(gene_annotations, dict): + raise ValueError( + f"Expected JSON annotation file {args.annotation_path} to map gene names to pathway collections." + ) + + missing_genes = set() + for gene_name, pathway_data in gene_annotations.items(): + if not pathway_data: + continue + + idx = gene_name_to_index.get(str(gene_name)) + if idx is None: + missing_genes.add(str(gene_name)) + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [] + for entry in pathway_data: + if entry is None: + continue + pathways.append(str(entry).strip()) + pathways = [p for p in pathways if p] + + for pathway in pathways: + pathway_to_genes[pathway].append(idx) + + if missing_genes: + sample_missing = ", ".join(sorted(missing_genes)[:5]) + logger.warning( + "Skipped %d gene(s) from annotation file not present in model gene names (e.g., %s)", + len(missing_genes), + sample_missing, + ) + elif annotation_ext in {".pkl", ".pickle"}: + annotation_source_type = "pickle" + field_suffix = ( + (args.annotation_field or "pathways").replace('_', '').lower() + if (args.annotation_field or "").strip() + else "pathways" + ) + + with open(args.annotation_path, 'rb') as f: + gene_annotations = pickle.load(f) + + if not args.annotation_field: + raise ValueError( + "--annotation-field must be provided when loading pickle annotation files." + ) + + for idx, data in gene_annotations.items(): + pathway_data = None + if isinstance(data, dict): + pathway_data = data.get(args.annotation_field) + else: + try: + pathway_data = data[args.annotation_field] + except (KeyError, TypeError): + pathway_data = getattr(data, args.annotation_field, None) + + if not pathway_data: + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [str(p).strip() for p in pathway_data if str(p).strip()] + + try: + gene_index = int(idx) - 1 + except (TypeError, ValueError): + gene_index = gene_name_to_index.get(str(idx)) + + if gene_index is None or gene_index < 0: + continue + + for pathway in pathways: + pathway_to_genes[pathway].append(gene_index) + else: + raise ValueError( + f"Unsupported annotation file extension '{annotation_ext}' for {args.annotation_path}." 
+ ) + + # Filter out pathways with too few genes (less than 3) to avoid noise + filtered_pathways = {pathway: genes for pathway, genes in pathway_to_genes.items() if len(genes) >= 3} + + logger.info( + "Found %d total pathways from annotation source '%s' (%s)", + len(pathway_to_genes), + annotation_label, + annotation_source_type, + ) + logger.info(f"Using {len(filtered_pathways)} pathways with 3+ genes for upregulation") + + # Initialize heatmap array: [num_pathways, num_perturbations] + num_pathways = len(filtered_pathways) + heatmap_distances = np.zeros((num_pathways, len(perts_order)), dtype=np.float32) + pathway_names = list(filtered_pathways.keys()) + + annotation_label_pretty = (annotation_label or "Annotation").replace('_', ' ').strip() + if annotation_label_pretty: + annotation_label_pretty = annotation_label_pretty.title() + else: + annotation_label_pretty = "Annotation" + + upregulated_preds_path = None + upregulated_preds_memmap = None + if not snapshots_only: + if num_pathways == 0: + logger.warning("No pathways passed filtering; skipping upregulated prediction storage.") + else: + try: + upregulated_preds_path = os.path.join( + heatmap_results_dir, + f"{field_suffix}_pathway_upregulated_preds.npy", + ) + upregulated_preds_memmap = np.memmap( + upregulated_preds_path, + dtype=np.float32, + mode="w+", + shape=(num_pathways, len(perts_order), target_core_n, output_dim), + ) + except Exception as e: + logger.warning("Failed to initialize storage for upregulated predictions: %s", e) + upregulated_preds_path = None + upregulated_preds_memmap = None + + # Create a copy of core_cells for upregulation experiments + original_core_cells = _clone_core_cells(core_cells) + + def apply_pathway_shift_to_core_cells(cell_batch: dict, gene_indices: list, upregulate: bool, target_norm: float = 2.0): + """Apply shift to multiple gene indices with equivalent euclidean norm across pathways. + + This function ensures that all pathways receive the same euclidean norm perturbation: + 1. Compute individual shifts based on 2σ for each gene + 2. Calculate the euclidean norm of the shift vector + 3. Rescale the entire shift vector to match the target euclidean norm + + - gene_indices: list of 0-indexed gene positions + - upregulate: True for positive shift, False for negative shift + - target_norm: target euclidean norm for the perturbation (default: 2.0) + Operates in-place on the tensor stored at distributions['key'] inside the provided cell_batch. 
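+ Example with assumed values: if two genes have 2-sigma raw shifts of 0.6 and 0.8, the raw shift vector
+ has Euclidean norm 1.0, so both shifts are rescaled by 2.0 / 1.0 = 2.0, giving applied shifts of 1.2 and
+ 1.6 whose combined Euclidean norm equals target_norm (2.0).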
+ """ + nonlocal distributions + key = distributions['key'] + tensor = cell_batch[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + + if len(gene_indices) == 0: + return + + # Step 1: Compute raw shift values based on 2σ for each gene + raw_shifts = {} + for idx in gene_indices: + if 0 <= idx < distributions["dim"]: + base_shift = 2.0 * float(distributions["std"][idx]) + raw_shifts[idx] = base_shift if upregulate else -base_shift + + if len(raw_shifts) == 0: + return + + # Step 2: Calculate euclidean norm of the raw shift vector + shift_values = np.array(list(raw_shifts.values())) + current_norm = np.linalg.norm(shift_values) + + # Step 3: Rescale to target norm if current norm > 0 + if current_norm > 1e-8: # Avoid division by zero + scale_factor = target_norm / current_norm + + # Apply rescaled shifts + for idx, raw_shift in raw_shifts.items(): + scaled_shift = raw_shift * scale_factor + tensor[:, idx] = tensor[:, idx] + scaled_shift + else: + # Fallback: if all std deviations are zero, apply uniform shift + uniform_shift = target_norm / np.sqrt(len(raw_shifts)) + for idx in raw_shifts.keys(): + shift_value = uniform_shift if upregulate else -uniform_shift + tensor[:, idx] = tensor[:, idx] + shift_value + + def compute_pathway_core_cell_snapshots(base_core_cells: dict, pathways: dict) -> list: + snapshots = [] + for pathway_name, gene_indices in pathways.items(): + shifted_cells = _clone_core_cells(base_core_cells) + apply_pathway_shift_to_core_cells(shifted_cells, gene_indices, upregulate=True) + snapshots.append( + { + "pathway_name": pathway_name, + "gene_indices": list(gene_indices), + "core_cells": shifted_cells, + } + ) + return snapshots + + shifted_core_cells_path = os.path.join(heatmap_results_dir, f"{field_suffix}_core_cells_upregulated.npy") + pathway_core_cells_snapshots = compute_pathway_core_cell_snapshots(original_core_cells, filtered_pathways) + _save_numpy_snapshot( + pathway_core_cells_snapshots, + shifted_core_cells_path, + description=f"core_cells upregulated snapshots ({len(pathway_core_cells_snapshots)} pathways)", + ) + if len(pathway_core_cells_snapshots) == 0: + logger.warning("No pathway core cell snapshots generated (0 pathways passed filtering).") + + if not snapshots_only: + with torch.no_grad(): + for pathway_idx, snapshot in enumerate( + tqdm(pathway_core_cells_snapshots, desc="Upregulating pathways", unit="pathway") + ): + core_cells_upregulated = snapshot["core_cells"] + # Run inference for all perturbations with this pathway upregulated + for p_idx, pert in enumerate(perts_order): + # Build batch by copying upregulated core_cells + batch = {} + for k, v in core_cells_upregulated.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + # Get predictions with upregulated pathway + batch_preds = model.predict_step(batch, p_idx, padded=False) + upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + + if 
upregulated_preds_memmap is not None: + upregulated_preds_memmap[pathway_idx, p_idx, :, :] = upregulated_preds + + # Compute euclidean distance between normal and upregulated predictions + normal_preds = normal_preds_per_pert[pert] # [64, output_dim] + distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 64 cells + heatmap_distances[pathway_idx, p_idx] = distance + + logger.info( + "Phase 2 core cell snapshots ready for %d pathways from annotation source '%s'.", + len(pathway_core_cells_snapshots), + annotation_label or annotation_source_type, + ) + + if snapshots_only: + logger.info("Snapshots-only mode: skipping distance computations, heatmap data, and visualization generation.") + return + + logger.info( + "Phase 2 complete: Upregulated inference for all pathways from annotation source '%s'.", + annotation_label or annotation_source_type, + ) + + if upregulated_preds_memmap is not None: + upregulated_preds_memmap.flush() + logger.info(f"Saved upregulated prediction tensors to {upregulated_preds_path}") + + # Create filename based on annotation field + # Save heatmap data + try: + heatmap_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") + np.save(heatmap_path, heatmap_distances) + + # Save pathway information + pathway_info_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathways_info.json") + pathway_info = { + "pathway_names": pathway_names, + "pathway_to_genes": {pathway: genes for pathway, genes in filtered_pathways.items()}, + "total_pathways": len(pathway_to_genes), + "filtered_pathways": len(filtered_pathways), + "min_genes_per_pathway": 3 + } + with open(pathway_info_path, "w") as f: + json.dump(pathway_info, f, indent=2) + + # Save metadata for the heatmap + heatmap_meta = { + "shape": [num_pathways, len(perts_order)], + "description": ( + f"Euclidean distance heatmap: rows={annotation_label_pretty} pathways, cols=perturbations" + ), + "perturbations": perts_order, + "pathway_names": pathway_names, + "distance_type": "mean_euclidean_norm_across_64_cells", + "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene", + "annotation_field": annotation_label if annotation_source_type != "json" else None, + "annotation_source_type": annotation_source_type, + "upregulated_preds_path": upregulated_preds_path, + } + heatmap_meta_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") + with open(heatmap_meta_path, "w") as f: + json.dump(heatmap_meta, f, indent=2) + + logger.info( + "Saved %s pathway upregulation heatmap to %s", + annotation_label_pretty, + heatmap_path, + ) + logger.info(f"Heatmap shape: {heatmap_distances.shape} (pathways x perturbations)") + except Exception as e: + logger.warning(f"Failed to save heatmap data: {e}") + + # Create and save matplotlib heatmap visualization + try: + # Determine output path for heatmap image + if args.heatmap_output_path is not None: + # If user provided a path, make it unique per run as well + base, ext = os.path.splitext(args.heatmap_output_path) + heatmap_img_path = f"{base}_{unique_suffix}{ext or '.png'}" + else: + heatmap_img_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.png") + + # Ensure directory exists + os.makedirs(os.path.dirname(heatmap_img_path), exist_ok=True) + + # Create the heatmap with appropriate size + fig_width = max(12, len(perts_order) * 0.3) + fig_height = max(8, num_pathways * 0.05) # Smaller height per pathway since 
we have fewer rows + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # Create heatmap with proper labels + im = ax.imshow(heatmap_distances, cmap='viridis', aspect='auto') + + # Set labels and title + ax.set_xlabel('Perturbations') + ax.set_ylabel(f'{args.annotation_field.replace("_", " ").title()} Pathways') + ax.set_title(f'{args.annotation_field.replace("_", " ").title()} Pathway Upregulation Impact Heatmap\n(Euclidean Distance from Normal Predictions)') + + # Set x-axis labels (perturbations) + ax.set_xticks(range(len(perts_order))) + ax.set_xticklabels(perts_order, rotation=45, ha='right', fontsize=8) + + # Set y-axis labels (pathways) - show pathway names, truncated if too long + ax.set_yticks(range(num_pathways)) + truncated_pathway_names = [] + for pathway_name in pathway_names: + # Remove common prefixes and truncate long names + clean_name = pathway_name + # Remove common GO prefixes + for prefix in ['GOMF_', 'GOCC_', 'GOBP_']: + clean_name = clean_name.replace(prefix, '') + if len(clean_name) > 30: + clean_name = clean_name[:27] + '...' + truncated_pathway_names.append(clean_name) + ax.set_yticklabels(truncated_pathway_names, fontsize=6) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label('Mean Euclidean Distance', rotation=270, labelpad=20) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the figure + plt.savefig(heatmap_img_path, dpi=300, bbox_inches='tight') + plt.close(fig) # Close to free memory + + logger.info(f"Saved {args.annotation_field} pathway heatmap visualization to {heatmap_img_path}") + + except Exception as e: + logger.warning(f"Failed to create heatmap visualization: {e}") + else: + logger.info("Skipping heatmap analysis (--test-time-heat-map not set)") + + logger.info("Creating anndatas from predictions from manual loop...") + + # Build pandas DataFrame for obs and var + df_dict = { + data_module.pert_col: all_pert_names, + data_module.cell_type_key: all_celltypes, + data_module.batch_col: all_gem_groups, + } + + if len(all_pert_barcodes) > 0: + df_dict["pert_cell_barcode"] = all_pert_barcodes + df_dict["ctrl_cell_barcode"] = all_ctrl_barcodes + + obs = pd.DataFrame(df_dict) + + gene_names = var_dims["gene_names"] + var = pd.DataFrame({"gene_names": gene_names}) + + if final_X_hvg is not None: + if len(gene_names) != final_pert_cell_counts_preds.shape[1]: + gene_names = np.load( + "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + ) + var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - using the decoded gene expression values + adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs, var=var) + # Create adata for real - using the true gene expression values + adata_real = anndata.AnnData(X=final_X_hvg, obs=obs, var=var) + + # add the embedding predictions + adata_pred.obsm[data_module.embed_key] = final_preds + adata_real.obsm[data_module.embed_key] = final_reals + logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") + else: + # if len(gene_names) != final_preds.shape[1]: + # gene_names = np.load( + # "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + # ) + # var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - model was trained on gene expression space already + # adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) + adata_pred = anndata.AnnData(X=final_preds, obs=obs) + # Create adata for real - using 
the true gene expression values + # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) + adata_real = anndata.AnnData(X=final_reals, obs=obs) + + # Optionally filter to perturbations seen in at least one training context + if args.shared_only: + try: + shared_perts = data_module.get_shared_perturbations() + if len(shared_perts) == 0: + logger.warning("No shared perturbations between train and test; skipping filtering.") + else: + logger.info( + "Filtering to %d shared perturbations present in train ∩ test.", + len(shared_perts), + ) + mask = adata_pred.obs[data_module.pert_col].isin(shared_perts) + before_n = adata_pred.n_obs + adata_pred = adata_pred[mask].copy() + adata_real = adata_real[mask].copy() + logger.info( + "Filtered cells: %d -> %d (kept only seen perturbations)", + before_n, + adata_pred.n_obs, + ) + except Exception as e: + logger.warning( + "Failed to filter by shared perturbations (%s). Proceeding without filter.", + str(e), + ) + + # Save the AnnData objects + results_dir = results_dir_default + os.makedirs(results_dir, exist_ok=True) + adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") + adata_real_path = os.path.join(results_dir, "adata_real.h5ad") + + adata_pred.write_h5ad(adata_pred_path) + adata_real.write_h5ad(adata_real_path) + + logger.info(f"Saved adata_pred to {adata_pred_path}") + logger.info(f"Saved adata_real to {adata_real_path}") + + # Save per-dimension control-cell distributions for reproducibility + try: + dist_out = { + "key": distributions["key"], + "dim": distributions["dim"], + "num_cells": distributions["num_cells"], + } + dist_out_path = os.path.join(results_dir, "control_distributions.meta.json") + with open(dist_out_path, "w") as f: + json.dump(dist_out, f) + np.save(os.path.join(results_dir, "control_mean.npy"), distributions["mean"]) # [D] + np.save(os.path.join(results_dir, "control_std.npy"), distributions["std"]) # [D] + logger.info("Saved control-cell per-dimension mean/std distributions") + except Exception as e: + logger.warning(f"Failed to save control-cell distributions: {e}") + + if not args.predict_only: + # 6. 
Compute metrics using cell-eval + logger.info("Computing metrics using cell-eval...") + + control_pert = data_module.get_control_pert() + + ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key) + ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key) + + assert len(ct_split_real) == len(ct_split_pred), ( + f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" + ) + + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for ct in ct_split_real.keys(): + real_ct = ct_split_real[ct] + pred_ct = ct_split_pred[ct] + + evaluator = MetricsEvaluator( + adata_pred=pred_ct, + adata_real=real_ct, + control_pert=control_pert, + pert_col=data_module.pert_col, + outdir=results_dir, + prefix=ct, + pdex_kwargs=pdex_kwargs, + batch_size=2048, + ) + + evaluator.compute( + profile=args.profile, + metric_configs={ + "discrimination_score": { + "embed_key": data_module.embed_key, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + "pearson_edistance": { + "embed_key": data_module.embed_key, + "n_jobs": -1, # set to all available cores + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else { + "n_jobs": -1, + }, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + skip_metrics=["pearson_edistance", "clustering_agreement"], + ) + + +def save_core_cells_real_preds(args: ap.ArgumentParser): + """Run only phase one of the heatmap pipeline and persist real core-cell embeddings per perturbation.""" + return run_tx_heatmap(args, phase_one_only=True) diff --git a/src/state/_cli/_tx/_heatmap_train.py b/src/state/_cli/_tx/_heatmap_train.py new file mode 100644 index 00000000..534c8efd --- /dev/null +++ b/src/state/_cli/_tx/_heatmap_train.py @@ -0,0 +1,1424 @@ +import argparse as ap + + +def add_arguments_heatmap(parser: ap.ArgumentParser): + """ + CLI for pathway heatmap analysis with GO MF pathway upregulation. + """ + + parser.add_argument( + "--output-dir", + type=str, + required=True, + help="Path to the output_dir containing the config.yaml file that was saved during training.", + ) + parser.add_argument( + "--checkpoint", + type=str, + default="last.ckpt", + help="Checkpoint filename. Default is 'last.ckpt'. Relative to the output directory.", + ) + + parser.add_argument( + "--test-time-finetune", + type=int, + default=0, + help="If >0, run test-time fine-tuning for the specified number of epochs on only control cells.", + ) + + parser.add_argument( + "--profile", + type=str, + default="full", + choices=["full", "minimal", "de", "anndata"], + help="run all metrics, minimal, only de metrics, or only output adatas", + ) + + parser.add_argument( + "--predict-only", + action="store_true", + help="If set, only run prediction without evaluation metrics.", + ) + + parser.add_argument( + "--shared-only", + action="store_true", + help=("If set, restrict predictions/evaluation to perturbations shared between train and test (train ∩ test)."), + ) + + parser.add_argument( + "--eval-train-data", + action="store_true", + help="If set, evaluate the model on the training data rather than on the test data.", + ) + parser.add_argument( + "--eval-cell-type", + type=str, + default=None, + help=( + "If provided, restrict inference and metrics to the specified cell type; applies to train/test loaders." 
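+ # Note: cell-type matching in _filter_batch_by_celltype is case-insensitive (both sides are lower-cased).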
+        ),
+    )
+
+    # Optional: apply directional shift on a chosen index using control distributions
+    parser.add_argument(
+        "--shift-index",
+        type=int,
+        default=None,
+        help="If set, apply a ±2σ shift to this index across core_cells using control distributions.",
+    )
+    parser.add_argument(
+        "--shift-direction",
+        type=str,
+        default=None,
+        choices=["up", "down"],
+        help="Direction for the 2σ shift applied to --shift-index. Requires --shift-index.",
+    )
+
+    parser.add_argument(
+        "--test-time-heat-map",
+        action="store_true",
+        help="If set, run test-time heat map analysis with position upregulation.",
+    )
+    parser.add_argument(
+        "--phase-one-only",
+        action="store_true",
+        help="If set, run only phase one to save core cell real embeddings per perturbation.",
+    )
+    parser.add_argument(
+        "--heatmap-output-path",
+        type=str,
+        default=None,
+        help="Path to save the matplotlib heatmap visualization. If not provided, defaults to a '<field>_pathway_upregulation_heatmap.png' file inside the per-run heatmap results directory.",
+    )
+    parser.add_argument(
+        "--results-dir",
+        type=str,
+        default=None,
+        help="Directory to save results. If not provided, defaults to '<output_dir>/eval_<checkpoint>'.",
+    )
+    parser.add_argument(
+        "--heatmap-snapshots-only",
+        action="store_true",
+        help=(
+            "Compute and persist pathway-upregulated core cell batches without running model inference or generating heatmaps."
+        ),
+    )
+    parser.add_argument(
+        "--annotation-path",
+        type=str,
+        default="/home/dhruvgautam/annotations/replogle_go_annotations.pkl",  # /home/dhruvgautam/annotations/var_dims_gene_go_annotations.json
+        help="Path to the HVG gene annotations file.",
+    )
+    parser.add_argument(
+        "--annotation-field",
+        type=str,
+        default="go_reactome_paths",
+        help=(
+            "Field name in structured annotation data to use for pathway grouping (e.g., 'go_cc_paths'). "
+            "Ignored when loading JSON files that map genes directly to pathways."
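+            # e.g. "go_reactome_paths" (the default) or "go_cc_paths"; only consulted for .pkl/.pickle annotation files.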
+ ), + ) + + +def run_tx_heatmap(args: ap.ArgumentParser, *, phase_one_only: bool = False): + import logging + import os + import sys + import copy + + import anndata + import lightning.pytorch as pl + import numpy as np + import pandas as pd + import torch + import yaml + import json + import uuid + from datetime import datetime + import matplotlib.pyplot as plt + import matplotlib + matplotlib.use('Agg') # Use non-interactive backend + + # Cell-eval for metrics computation + from cell_eval import MetricsEvaluator + from cell_eval.utils import split_anndata_on_celltype + from cell_load.data_modules import PerturbationDataModule + from tqdm import tqdm + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + def _prepare_for_serialization(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().copy() + if isinstance(obj, dict): + return {k: _prepare_for_serialization(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_serialization(v) for v in obj] + return obj + + def _save_numpy_snapshot(obj, path, description=None): + serializable = _prepare_for_serialization(obj) + try: + np.save(path, serializable, allow_pickle=True) + if description: + logger.info("Saved %s to %s", description, path) + else: + logger.info("Saved snapshot to %s", path) + except Exception as e: + log_desc = description or "snapshot" + logger.warning("Failed to save %s to %s: %s", log_desc, path, e) + + def _clone_core_cells(src): + cloned = {} + for k, v in src.items(): + if isinstance(v, torch.Tensor): + cloned[k] = v.clone() + else: + try: + cloned[k] = copy.deepcopy(v) + except Exception: + cloned[k] = v + return cloned + + results_dir_default = ( + args.results_dir + if args.results_dir is not None + else os.path.join(args.output_dir, "eval_" + os.path.basename(args.checkpoint)) + ) + + snapshots_only = getattr(args, "heatmap_snapshots_only", False) + + torch.multiprocessing.set_sharing_strategy("file_system") + + def _to_list(value): + if isinstance(value, list): + return value + if isinstance(value, torch.Tensor): + try: + return [x.item() if x.dim() == 0 else x for x in value] + except Exception: + return value.tolist() + if isinstance(value, (tuple, set)): + return list(value) + return [value] + + def _resolve_celltype_key(batch): + candidate_keys = [] + base_key = getattr(data_module, "cell_type_key", None) + if base_key: + candidate_keys.append(base_key) + alias_keys = getattr(data_module, "cell_type_key_aliases", None) + if isinstance(alias_keys, (list, tuple)): + candidate_keys.extend(alias_keys) + alias_keys_alt = getattr(data_module, "celltype_key_aliases", None) + if isinstance(alias_keys_alt, (list, tuple)): + candidate_keys.extend(alias_keys_alt) + candidate_keys.extend( + [ + "celltype_name", + "cell_type", + "celltype", + "cell_line", + ] + ) + seen = set() + ordered_candidates = [] + for key in candidate_keys: + if not key or key in seen: + continue + seen.add(key) + ordered_candidates.append(key) + if key in batch: + return key, ordered_candidates + return None, ordered_candidates + + celltype_batches_seen = False + + def _filter_batch_by_celltype(batch, *, target_celltype): + nonlocal celltype_batches_seen + if target_celltype is None: + return batch + celltype_key, attempted_keys = _resolve_celltype_key(batch) + if celltype_key is None: + available_keys = [k for k in batch.keys() if isinstance(k, str)] + available_preview = ", ".join(sorted(available_keys)[:10]) + raise ValueError( + "--eval-cell-type 
requested cell type filtering but none of the expected keys (%s) were present in batch data. Available batch keys: %s%s" + % ( + ", ".join(attempted_keys) if attempted_keys else "none", + available_preview, + "..." if len(available_keys) > 10 else "", + ) + ) + + target_norm = str(target_celltype).lower() + celltypes = _to_list(batch[celltype_key]) + mask_values = [] + for ct in celltypes: + try: + match = str(ct).lower() == target_norm + except Exception: + match = False + mask_values.append(match) + + if not mask_values: + return None + + mask = torch.tensor(mask_values, dtype=torch.bool) + if mask.sum().item() == 0: + return None + + filtered = {} + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + try: + mask_device = mask.to(v.device) if mask.device != v.device else mask + filtered_val = v[mask_device] + if filtered_val.shape[0] == 0: + return None + filtered[k] = filtered_val + except Exception: + filtered[k] = v + else: + vals = _to_list(v) + mask_list = mask.tolist() + filtered[k] = [vals[i] for i, keep in enumerate(mask_list[: len(vals)]) if keep] + celltype_batches_seen = True + return filtered + + def _iter_batches(dataloader, *, target_celltype=None): + if target_celltype is None: + for batch in dataloader: + yield batch + return + for batch in dataloader: + filtered = _filter_batch_by_celltype(batch, target_celltype=target_celltype) + if filtered is None: + continue + yield filtered + + def run_test_time_finetune(model, dataloader, ft_epochs, control_pert, device, *, filter_batch_fn=None): + """ + Perform test-time fine-tuning on only control cells. + """ + model.train() + optimizer = torch.optim.Adam(model.parameters(), lr=1e-5) + + logger.info(f"Starting test-time fine-tuning for {ft_epochs} epoch(s) on control cells only.") + for epoch in range(ft_epochs): + epoch_losses = [] + pbar = tqdm(dataloader, desc=f"Finetune epoch {epoch + 1}/{ft_epochs}", leave=True) + for batch in pbar: + if filter_batch_fn is not None: + batch = filter_batch_fn(batch) + if batch is None: + continue + # Check if this batch contains control cells + first_pert = ( + batch["pert_name"][0] if isinstance(batch["pert_name"], list) else batch["pert_name"][0].item() + ) + if first_pert != control_pert: + continue + + # Move batch data to device + batch = {k: (v.to(device) if torch.is_tensor(v) else v) for k, v in batch.items()} + + optimizer.zero_grad() + loss = model.training_step(batch, batch_idx=0, padded=False) + if loss is None: + continue + loss.backward() + optimizer.step() + epoch_losses.append(loss.item()) + pbar.set_postfix(loss=f"{loss.item():.4f}") + + mean_loss = np.mean(epoch_losses) if epoch_losses else float("nan") + logger.info(f"Finetune epoch {epoch + 1}/{ft_epochs}, mean loss: {mean_loss}") + model.eval() + + def load_config(cfg_path: str) -> dict: + """Load config from the YAML file that was dumped during training.""" + if not os.path.exists(cfg_path): + raise FileNotFoundError(f"Could not find config file: {cfg_path}") + with open(cfg_path, "r") as f: + cfg = yaml.safe_load(f) + return cfg + + # 1. Load the config + config_path = os.path.join(args.output_dir, "config.yaml") + cfg = load_config(config_path) + logger.info(f"Loaded config from {config_path}") + + # 2. 
Find run output directory & load data module + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + if not os.path.isabs(run_output_dir): + run_output_dir = os.path.abspath(run_output_dir) + + if not os.path.exists(run_output_dir): + inferred_run_dir = args.output_dir + if os.path.exists(inferred_run_dir): + logger.warning( + "Run directory %s not found; falling back to config directory %s", + run_output_dir, + inferred_run_dir, + ) + run_output_dir = inferred_run_dir + else: + raise FileNotFoundError( + "Could not resolve run directory. Checked: %s and %s" % (run_output_dir, inferred_run_dir) + ) + + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}?") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + + # Seed everything + pl.seed_everything(cfg["training"]["train_seed"]) + + # 3. Load the trained model + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + checkpoint_path = os.path.join(checkpoint_dir, args.checkpoint) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Could not find checkpoint at {checkpoint_path}.\nSpecify a correct checkpoint filename with --checkpoint." + ) + logger.info("Loading model from %s", checkpoint_path) + + # Determine model class and load + model_class_name = cfg["model"]["name"] + model_kwargs = cfg["model"]["kwargs"] + + # Import the correct model class + if model_class_name.lower() == "embedsum": + from ...tx.models.embed_sum import EmbedSumPerturbationModel + + ModelClass = EmbedSumPerturbationModel + elif model_class_name.lower() == "old_neuralot": + from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel + + ModelClass = OldNeuralOTPerturbationModel + elif model_class_name.lower() in ["neuralot", "pertsets", "state"]: + from ...tx.models.state_transition import StateTransitionPerturbationModel + + ModelClass = StateTransitionPerturbationModel + + elif model_class_name.lower() in ["globalsimplesum", "perturb_mean"]: + from ...tx.models.perturb_mean import PerturbMeanPerturbationModel + + ModelClass = PerturbMeanPerturbationModel + elif model_class_name.lower() in ["celltypemean", "context_mean"]: + from ...tx.models.context_mean import ContextMeanPerturbationModel + + ModelClass = ContextMeanPerturbationModel + elif model_class_name.lower() == "decoder_only": + from ...tx.models.decoder_only import DecoderOnlyPerturbationModel + + ModelClass = DecoderOnlyPerturbationModel + else: + raise ValueError(f"Unknown model class: {model_class_name}") + + var_dims = data_module.get_var_dims() + model_init_kwargs = { + "input_dim": var_dims["input_dim"], + "hidden_dim": model_kwargs["hidden_dim"], + "gene_dim": var_dims["gene_dim"], + "hvg_dim": var_dims["hvg_dim"], + "output_dim": var_dims["output_dim"], + "pert_dim": var_dims["pert_dim"], + **model_kwargs, + } + + model = ModelClass.load_from_checkpoint(checkpoint_path, **model_init_kwargs) + model.eval() + logger.info("Model loaded successfully.") + + # 4. 
Test-time fine-tuning if requested + data_module.batch_size = 1 + filter_celltype = getattr(args, "eval_cell_type", None) + + if args.test_time_finetune > 0: + control_pert = data_module.get_control_pert() + run_test_time_finetune( + model, + data_module.train_dataloader(test=True) if args.eval_train_data else data_module.test_dataloader(), + args.test_time_finetune, + control_pert, + device=next(model.parameters()).device, + filter_batch_fn=( + (lambda batch: _filter_batch_by_celltype(batch, target_celltype=filter_celltype)) + if filter_celltype is not None + else None + ), + ) + logger.info("Test-time fine-tuning complete.") + + # 5. Run inference on test set + data_module.setup(stage="test") + base_scan_loader = data_module.train_dataloader(test=True) if args.eval_train_data else data_module.test_dataloader() + + celltype_filter = getattr(args, "eval_cell_type", None) + + def _scan_batches(): + return _iter_batches(base_scan_loader, target_celltype=celltype_filter) if celltype_filter is not None else base_scan_loader + + scan_loader = _scan_batches() + + if scan_loader is None: + logger.warning("No test dataloader found. Exiting.") + sys.exit(0) + + logger.info("Preparing a fixed batch of 64 control cells (core_cells) and enumerating perturbations...") + + control_pert = data_module.get_control_pert() + + # Collect unique perturbation names from the loader without running the model + unique_perts = [] + seen_perts = set() + for batch in _scan_batches(): + names = _to_list(batch.get("pert_name", [])) + for n in names: + if isinstance(n, torch.Tensor): + try: + n = n.item() + except Exception: + n = str(n) + if n not in seen_perts: + seen_perts.add(n) + unique_perts.append(n) + + if celltype_filter is not None and not celltype_batches_seen: + raise ValueError( + f"Requested eval cell type '{celltype_filter}' not found in data loader batches; cannot proceed." 
+ ) + + if control_pert in seen_perts: + logger.info(f"Found {len(unique_perts)} total perturbations (including control '{control_pert}').") + else: + logger.warning("Control perturbation not observed in test loader perturbation names.") + + # Build a single fixed batch of exactly 64 control cells + target_core_n = 64 + core_cells = None + accum = {} + + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + # Iterate again to collect control cells only + for batch in _scan_batches(): + names = _to_list(batch.get("pert_name", [])) + # Build a mask for control entries when possible + mask = None + if len(names) > 0: + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + else: + # If no names provided in batch, skip (cannot verify control) + continue + + # Slice each tensor field by mask and accumulate until we have 64 + current_count = 0 if "_count" not in accum else accum["_count"] + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + # Identify keys to carry forward; prefer tensors and essential metadata + for k, v in batch.items(): + if isinstance(v, torch.Tensor): + try: + vsel = v[mask][:take].detach().clone() + except Exception: + # fallback: try first dimension slice + vsel = v[:take].detach().clone() + _append_field(accum, k, vsel) + else: + # For non-tensor fields, convert to list and slice by mask when possible + vals = _to_list(v) + try: + selected_vals = [vals[i] for i, m in enumerate(mask.tolist()) if m][:take] + except Exception: + selected_vals = vals[:take] + _append_field(accum, k, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError(f"Could not assemble {target_core_n} control cells for core_cells; gathered {accum.get('_count', 0)}.") + + # Collate accumulated pieces into a single batch dict of length 64 + core_cells = {} + for k, parts in accum.items(): + if k == "_count": + continue + if len(parts) == 1: + val = parts[0] + else: + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for p in parts: + merged.extend(_to_list(p)) + val = merged + # Ensure final length == 64 + if isinstance(val, torch.Tensor): + core_cells[k] = val[:target_core_n] + else: + core_cells[k] = _to_list(val)[:target_core_n] + + logger.info(f"Constructed core_cells batch with size {target_core_n}.") + + os.makedirs(results_dir_default, exist_ok=True) + baseline_core_cells_path = os.path.join(results_dir_default, "core_cells_baseline.npy") + _save_numpy_snapshot(core_cells, baseline_core_cells_path, description="baseline core_cells batch (control cells)") + + # Compute distributions for each position across ALL control cells in the test loader + # Strategy: determine a 2D vector key from the first batch, then aggregate all control rows + vector_key_candidates = ["ctrl_cell_emb", "pert_cell_emb", "X"] + dist_source_key = None + # Find key by peeking one batch + for b in _scan_batches(): + for cand in vector_key_candidates: + if cand in b and isinstance(b[cand], torch.Tensor) and b[cand].dim() == 2: + dist_source_key = cand + break + if dist_source_key is None: + # fallback: any 2D tensor + for k, v in b.items(): + if isinstance(v, torch.Tensor) and v.dim() == 2: + dist_source_key = k + break + # break after first batch inspected + break + if dist_source_key is None: + raise 
RuntimeError("Could not find a 2D tensor in test loader batches to compute per-dimension distributions.") + + # Aggregate all control rows for the chosen key + control_rows = [] + for batch in _scan_batches(): + names = _to_list(batch.get("pert_name", [])) + if len(names) == 0: + continue + mask = torch.tensor([str(x) == str(control_pert) for x in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + vec = batch.get(dist_source_key, None) + if isinstance(vec, torch.Tensor) and vec.dim() == 2: + try: + control_rows.append(vec[mask].detach().cpu().float()) + except Exception: + # fallback: take leading rows equal to mask sum + take = int(mask.sum().item()) + if take > 0: + control_rows.append(vec[:take].detach().cpu().float()) + + if len(control_rows) == 0: + raise RuntimeError("No control rows found to compute distributions.") + + control_vectors_all = torch.cat(control_rows, dim=0) # [Nc, D] + D = control_vectors_all.shape[1] + if D != 2000: + logger.warning(f"Expected vector dimension 2000; found {D}. Proceeding with {D} dimensions.") + + control_mean = control_vectors_all.mean(dim=0) + control_std = control_vectors_all.std(dim=0, unbiased=False).clamp_min(1e-8) + + # Save distributions to results directory later; keep in scope for optional shifting + distributions = { + "key": dist_source_key, + "mean": control_mean.numpy(), + "std": control_std.numpy(), + "dim": int(D), + "num_cells": int(control_vectors_all.shape[0]), + } + + def apply_shift_to_core_cells(index: int, upregulate: bool): + """Apply ±2σ shift at a single index across all vectors in core_cells. + + - index: integer in [0, D) + - upregulate: True for +2σ, False for -2σ + Operates in-place on the tensor stored at distributions['key'] inside core_cells. + """ + nonlocal core_cells, distributions + if index < 0 or index >= distributions["dim"]: + raise ValueError(f"Index {index} is out of bounds for dimension {distributions['dim']}") + shift_value = (2.0 if upregulate else -2.0) * float(distributions["std"][index]) + key = distributions["key"] + tensor = core_cells[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + tensor[:, index] = tensor[:, index] + shift_value + core_cells[key] = tensor + + # Optionally apply shift based on CLI flags before running inference + if args.shift_index is not None: + if args.shift_direction is None: + raise ValueError("--shift-direction is required when --shift-index is provided") + apply_shift_to_core_cells(index=int(args.shift_index), upregulate=(args.shift_direction == "up")) + logger.info(f"Applied 2σ {'up' if args.shift_direction=='up' else 'down'} shift at index {int(args.shift_index)} across core_cells") + + # Prepare perturbation ordering and, if needed, buffers for forward passes + perts_order = list(unique_perts) + + if snapshots_only: + logger.info("Heatmap snapshots flag set; skipping phase-one forward passes through the model.") + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + final_preds = None + final_reals = None + final_X_hvg = None + final_pert_cell_counts_preds = None + normal_preds_per_pert = {} + real_preds_per_pert = {} + store_raw_expression = False + else: + logger.info("Generating predictions: one forward pass per perturbation on core_cells...") + num_cells = len(perts_order) * target_core_n + output_dim = var_dims["output_dim"] + gene_dim = var_dims["gene_dim"] + hvg_dim = var_dims["hvg_dim"] + + # Phase 1: Normal 
inference on all perturbations + final_preds = np.empty((num_cells, output_dim), dtype=np.float32) + final_reals = np.empty((num_cells, output_dim), dtype=np.float32) + + # Phase 2: Store normal predictions for distance computation + normal_preds_per_pert = {} # pert_name -> [64, output_dim] array + real_preds_per_pert = {} # pert_name -> [64, output_dim] array + + store_raw_expression = ( + data_module.embed_key is not None + and data_module.embed_key != "X_hvg" + and cfg["data"]["kwargs"]["output_space"] == "gene" + ) or (data_module.embed_key is not None and cfg["data"]["kwargs"]["output_space"] == "all") + + final_X_hvg = None + final_pert_cell_counts_preds = None + if store_raw_expression: + # Preallocate matrices of shape (num_cells, gene_dim) for decoded predictions. + if cfg["data"]["kwargs"]["output_space"] == "gene": + final_X_hvg = np.empty((num_cells, hvg_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, hvg_dim), dtype=np.float32) + if cfg["data"]["kwargs"]["output_space"] == "all": + final_X_hvg = np.empty((num_cells, gene_dim), dtype=np.float32) + final_pert_cell_counts_preds = np.empty((num_cells, gene_dim), dtype=np.float32) + device = next(model.parameters()).device + + # Prepare perturbation one-hot/embedding map for the pert encoder + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + try: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + else: + logger.warning(f"pert_onehot_map.pt not found at {map_path}; proceeding without explicit pert_emb overrides") + pert_onehot_map = {} + except Exception as e: + logger.warning(f"Failed to load pert_onehot_map.pt: {e}") + pert_onehot_map = {} + + def _prepare_pert_emb(pert_name: str, length: int, device: torch.device): + vec = None + try: + vec = pert_onehot_map.get(pert_name, None) + if vec is None and control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + except Exception: + vec = None + if vec is None: + # Fallback to zeros with model.pert_dim if mapping is unavailable + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("Could not determine pert_dim to build pert_emb") + vec = torch.zeros(pert_dim) + return vec.float().unsqueeze(0).repeat(length, 1).to(device) + + current_idx = 0 + + # Initialize aggregation variables directly + all_pert_names = [] + all_celltypes = [] + all_gem_groups = [] + all_pert_barcodes = [] + all_ctrl_barcodes = [] + + if not snapshots_only: + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="Predicting", unit="pert")): + # Build a batch by copying core_cells and swapping perturbation + batch = {} + for k, v in core_cells.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields to target pert + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + # Best-effort: update any index fields if present and mapping exists + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + batch_preds = 
model.predict_step(batch, p_idx, padded=False) + + # Extract metadata and data directly from batch_preds + # Handle pert_name + batch_pert_names = [] + if isinstance(batch_preds["pert_name"], list): + all_pert_names.extend(batch_preds["pert_name"]) + batch_pert_names = batch_preds["pert_name"] + else: + all_pert_names.append(batch_preds["pert_name"]) + batch_pert_names = [batch_preds["pert_name"]] + + if "pert_cell_barcode" in batch_preds: + if isinstance(batch_preds["pert_cell_barcode"], list): + all_pert_barcodes.extend(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.extend(batch_preds.get("ctrl_cell_barcode", [None] * len(batch_preds["pert_cell_barcode"])) ) + else: + all_pert_barcodes.append(batch_preds["pert_cell_barcode"]) + all_ctrl_barcodes.append(batch_preds.get("ctrl_cell_barcode", None)) + + # Handle celltype_name + if isinstance(batch_preds["celltype_name"], list): + all_celltypes.extend(batch_preds["celltype_name"]) + else: + all_celltypes.append(batch_preds["celltype_name"]) + + # Handle gem_group + if isinstance(batch_preds["batch"], list): + all_gem_groups.extend([str(x) for x in batch_preds["batch"]]) + elif isinstance(batch_preds["batch"], torch.Tensor): + all_gem_groups.extend([str(x) for x in batch_preds["batch"].cpu().numpy()]) + else: + all_gem_groups.append(str(batch_preds["batch"])) + + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + batch_size = batch_pred_np.shape[0] + final_preds[current_idx : current_idx + batch_size, :] = batch_pred_np + final_reals[current_idx : current_idx + batch_size, :] = batch_real_np + + # Store normal predictions for this perturbation for distance computation + normal_preds_per_pert[pert] = batch_pred_np.copy() + real_preds_per_pert[pert] = batch_real_np.copy() + + current_idx += batch_size + + # Handle X_hvg for HVG space ground truth + if final_X_hvg is not None: + batch_real_gene_np = batch_preds["pert_cell_counts"].cpu().numpy().astype(np.float32) + final_X_hvg[current_idx - batch_size : current_idx, :] = batch_real_gene_np + + # Handle decoded gene predictions if available + if final_pert_cell_counts_preds is not None: + batch_gene_pred_np = batch_preds["pert_cell_counts_preds"].cpu().numpy().astype(np.float32) + final_pert_cell_counts_preds[current_idx - batch_size : current_idx, :] = batch_gene_pred_np + + logger.info("Phase 1 complete: Normal inference on all perturbations.") + + # Phase 2: Run inference with GO MF pathway groups upregulated (only if requested) + run_phase_one_only = phase_one_only or getattr(args, "phase_one_only", False) + + if run_phase_one_only and not snapshots_only: + os.makedirs(results_dir_default, exist_ok=True) + if not snapshots_only: + real_preds_path = os.path.join(results_dir_default, "core_cells_real_preds_per_pert.npy") + try: + np.save(real_preds_path, real_preds_per_pert, allow_pickle=True) + logger.info( + "Saved real perturbed core cell embeddings for %d perturbations to %s", + len(real_preds_per_pert), + real_preds_path, + ) + except Exception as e: + logger.error("Failed to save core cell real predictions to %s: %s", real_preds_path, e) + raise + return + + if args.test_time_heat_map or snapshots_only: + logger.info("Phase 2: Loading GO MF pathway annotations and running pathway-based upregulation...") + + results_dir = results_dir_default + + # Ensure unique heatmap directory per invocation to avoid overwriting prior outputs + timestamp = 
datetime.utcnow().strftime("%Y%m%dT%H%M%S") + unique_suffix = f"{timestamp}_{uuid.uuid4().hex[:8]}" + heatmap_results_dir = os.path.join(results_dir, "heatmap_runs", unique_suffix) + + os.makedirs(heatmap_results_dir, exist_ok=True) + annotation_ext = os.path.splitext(args.annotation_path)[1].lower() + annotation_source_type = "unknown" + annotation_label = args.annotation_field + field_suffix = ( + (annotation_label or "pathways").replace('_', '').lower() + if (annotation_label or "").strip() + else "pathways" + ) + + # Load gene annotations + import pickle + from collections import defaultdict + + pathway_to_genes = defaultdict(list) + gene_names = var_dims.get("gene_names") + gene_name_to_index = {str(name): idx for idx, name in enumerate(gene_names)} if gene_names is not None else {} + + if annotation_ext == ".json": + annotation_source_type = "json" + annotation_label = os.path.splitext(os.path.basename(args.annotation_path))[0] + field_suffix = annotation_label.replace('_', '').lower() or "pathways" + + with open(args.annotation_path, 'r') as f: + gene_annotations = json.load(f) + + if not isinstance(gene_annotations, dict): + raise ValueError( + f"Expected JSON annotation file {args.annotation_path} to map gene names to pathway collections." + ) + + missing_genes = set() + for gene_name, pathway_data in gene_annotations.items(): + if not pathway_data: + continue + + idx = gene_name_to_index.get(str(gene_name)) + if idx is None: + missing_genes.add(str(gene_name)) + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [] + for entry in pathway_data: + if entry is None: + continue + pathways.append(str(entry).strip()) + pathways = [p for p in pathways if p] + + for pathway in pathways: + pathway_to_genes[pathway].append(idx) + + if missing_genes: + sample_missing = ", ".join(sorted(missing_genes)[:5]) + logger.warning( + "Skipped %d gene(s) from annotation file not present in model gene names (e.g., %s)", + len(missing_genes), + sample_missing, + ) + elif annotation_ext in {".pkl", ".pickle"}: + annotation_source_type = "pickle" + field_suffix = ( + (args.annotation_field or "pathways").replace('_', '').lower() + if (args.annotation_field or "").strip() + else "pathways" + ) + + with open(args.annotation_path, 'rb') as f: + gene_annotations = pickle.load(f) + + if not args.annotation_field: + raise ValueError( + "--annotation-field must be provided when loading pickle annotation files." + ) + + for idx, data in gene_annotations.items(): + pathway_data = None + if isinstance(data, dict): + pathway_data = data.get(args.annotation_field) + else: + try: + pathway_data = data[args.annotation_field] + except (KeyError, TypeError): + pathway_data = getattr(data, args.annotation_field, None) + + if not pathway_data: + continue + + if isinstance(pathway_data, str): + pathways = [p.strip() for p in pathway_data.split(';') if p.strip()] + else: + pathways = [str(p).strip() for p in pathway_data if str(p).strip()] + + try: + gene_index = int(idx) - 1 + except (TypeError, ValueError): + gene_index = gene_name_to_index.get(str(idx)) + + if gene_index is None or gene_index < 0: + continue + + for pathway in pathways: + pathway_to_genes[pathway].append(gene_index) + else: + raise ValueError( + f"Unsupported annotation file extension '{annotation_ext}' for {args.annotation_path}." 
+ ) + + # Filter out pathways with too few genes (less than 3) to avoid noise + filtered_pathways = {pathway: genes for pathway, genes in pathway_to_genes.items() if len(genes) >= 3} + + logger.info( + "Found %d total pathways from annotation source '%s' (%s)", + len(pathway_to_genes), + annotation_label, + annotation_source_type, + ) + logger.info(f"Using {len(filtered_pathways)} pathways with 3+ genes for upregulation") + + # Initialize heatmap array: [num_pathways, num_perturbations] + num_pathways = len(filtered_pathways) + heatmap_distances = np.zeros((num_pathways, len(perts_order)), dtype=np.float32) + pathway_names = list(filtered_pathways.keys()) + + annotation_label_pretty = (annotation_label or "Annotation").replace('_', ' ').strip() + if annotation_label_pretty: + annotation_label_pretty = annotation_label_pretty.title() + else: + annotation_label_pretty = "Annotation" + + upregulated_preds_path = None + upregulated_preds_memmap = None + if not snapshots_only: + if num_pathways == 0: + logger.warning("No pathways passed filtering; skipping upregulated prediction storage.") + else: + try: + upregulated_preds_path = os.path.join( + heatmap_results_dir, + f"{field_suffix}_pathway_upregulated_preds.npy", + ) + upregulated_preds_memmap = np.memmap( + upregulated_preds_path, + dtype=np.float32, + mode="w+", + shape=(num_pathways, len(perts_order), target_core_n, output_dim), + ) + except Exception as e: + logger.warning("Failed to initialize storage for upregulated predictions: %s", e) + upregulated_preds_path = None + upregulated_preds_memmap = None + + # Create a copy of core_cells for upregulation experiments + original_core_cells = _clone_core_cells(core_cells) + + def apply_pathway_shift_to_core_cells(cell_batch: dict, gene_indices: list, upregulate: bool, target_norm: float = 2.0): + """Apply shift to multiple gene indices with equivalent euclidean norm across pathways. + + This function ensures that all pathways receive the same euclidean norm perturbation: + 1. Compute individual shifts based on 2σ for each gene + 2. Calculate the euclidean norm of the shift vector + 3. Rescale the entire shift vector to match the target euclidean norm + + - gene_indices: list of 0-indexed gene positions + - upregulate: True for positive shift, False for negative shift + - target_norm: target euclidean norm for the perturbation (default: 2.0) + Operates in-place on the tensor stored at distributions['key'] inside the provided cell_batch. 
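+ If every selected gene has zero control standard deviation, the fallback below applies a uniform shift of
+ target_norm / sqrt(k) to each of the k genes, which again yields an overall Euclidean norm of target_norm.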
+ """ + nonlocal distributions + key = distributions['key'] + tensor = cell_batch[key] + if not isinstance(tensor, torch.Tensor) or tensor.dim() != 2: + raise RuntimeError(f"Core cell field '{key}' is not a 2D tensor") + + if len(gene_indices) == 0: + return + + # Step 1: Compute raw shift values based on 2σ for each gene + raw_shifts = {} + for idx in gene_indices: + if 0 <= idx < distributions["dim"]: + base_shift = 2.0 * float(distributions["std"][idx]) + raw_shifts[idx] = base_shift if upregulate else -base_shift + + if len(raw_shifts) == 0: + return + + # Step 2: Calculate euclidean norm of the raw shift vector + shift_values = np.array(list(raw_shifts.values())) + current_norm = np.linalg.norm(shift_values) + + # Step 3: Rescale to target norm if current norm > 0 + if current_norm > 1e-8: # Avoid division by zero + scale_factor = target_norm / current_norm + + # Apply rescaled shifts + for idx, raw_shift in raw_shifts.items(): + scaled_shift = raw_shift * scale_factor + tensor[:, idx] = tensor[:, idx] + scaled_shift + else: + # Fallback: if all std deviations are zero, apply uniform shift + uniform_shift = target_norm / np.sqrt(len(raw_shifts)) + for idx in raw_shifts.keys(): + shift_value = uniform_shift if upregulate else -uniform_shift + tensor[:, idx] = tensor[:, idx] + shift_value + + def compute_pathway_core_cell_snapshots(base_core_cells: dict, pathways: dict) -> list: + snapshots = [] + for pathway_name, gene_indices in pathways.items(): + shifted_cells = _clone_core_cells(base_core_cells) + apply_pathway_shift_to_core_cells(shifted_cells, gene_indices, upregulate=True) + snapshots.append( + { + "pathway_name": pathway_name, + "gene_indices": list(gene_indices), + "core_cells": shifted_cells, + } + ) + return snapshots + + shifted_core_cells_path = os.path.join(heatmap_results_dir, f"{field_suffix}_core_cells_upregulated.npy") + pathway_core_cells_snapshots = compute_pathway_core_cell_snapshots(original_core_cells, filtered_pathways) + _save_numpy_snapshot( + pathway_core_cells_snapshots, + shifted_core_cells_path, + description=f"core_cells upregulated snapshots ({len(pathway_core_cells_snapshots)} pathways)", + ) + if len(pathway_core_cells_snapshots) == 0: + logger.warning("No pathway core cell snapshots generated (0 pathways passed filtering).") + + if not snapshots_only: + with torch.no_grad(): + for pathway_idx, snapshot in enumerate( + tqdm(pathway_core_cells_snapshots, desc="Upregulating pathways", unit="pathway") + ): + core_cells_upregulated = snapshot["core_cells"] + # Run inference for all perturbations with this pathway upregulated + for p_idx, pert in enumerate(perts_order): + # Build batch by copying upregulated core_cells + batch = {} + for k, v in core_cells_upregulated.items(): + if isinstance(v, torch.Tensor): + batch[k] = v.clone().to(device) + else: + batch[k] = list(v) + + # Overwrite perturbation fields + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + try: + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + # Ensure perturbation embedding is set for the encoder + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n, device) + + # Get predictions with upregulated pathway + batch_preds = model.predict_step(batch, p_idx, padded=False) + upregulated_preds = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + + if 
upregulated_preds_memmap is not None: + upregulated_preds_memmap[pathway_idx, p_idx, :, :] = upregulated_preds + + # Compute euclidean distance between normal and upregulated predictions + normal_preds = normal_preds_per_pert[pert] # [64, output_dim] + distance = np.linalg.norm(upregulated_preds - normal_preds, axis=1).mean() # Mean across 64 cells + heatmap_distances[pathway_idx, p_idx] = distance + + logger.info( + "Phase 2 core cell snapshots ready for %d pathways from annotation source '%s'.", + len(pathway_core_cells_snapshots), + annotation_label or annotation_source_type, + ) + + if snapshots_only: + logger.info("Snapshots-only mode: skipping distance computations, heatmap data, and visualization generation.") + return + + logger.info( + "Phase 2 complete: Upregulated inference for all pathways from annotation source '%s'.", + annotation_label or annotation_source_type, + ) + + if upregulated_preds_memmap is not None: + upregulated_preds_memmap.flush() + logger.info(f"Saved upregulated prediction tensors to {upregulated_preds_path}") + + # Create filename based on annotation field + # Save heatmap data + try: + heatmap_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.npy") + np.save(heatmap_path, heatmap_distances) + + # Save pathway information + pathway_info_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathways_info.json") + pathway_info = { + "pathway_names": pathway_names, + "pathway_to_genes": {pathway: genes for pathway, genes in filtered_pathways.items()}, + "total_pathways": len(pathway_to_genes), + "filtered_pathways": len(filtered_pathways), + "min_genes_per_pathway": 3 + } + with open(pathway_info_path, "w") as f: + json.dump(pathway_info, f, indent=2) + + # Save metadata for the heatmap + heatmap_meta = { + "shape": [num_pathways, len(perts_order)], + "description": ( + f"Euclidean distance heatmap: rows={annotation_label_pretty} pathways, cols=perturbations" + ), + "perturbations": perts_order, + "pathway_names": pathway_names, + "distance_type": "mean_euclidean_norm_across_64_cells", + "upregulation": "equivalent_euclidean_norm_perturbation_rescaled_from_2std_per_gene", + "annotation_field": annotation_label if annotation_source_type != "json" else None, + "annotation_source_type": annotation_source_type, + "upregulated_preds_path": upregulated_preds_path, + } + heatmap_meta_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.meta.json") + with open(heatmap_meta_path, "w") as f: + json.dump(heatmap_meta, f, indent=2) + + logger.info( + "Saved %s pathway upregulation heatmap to %s", + annotation_label_pretty, + heatmap_path, + ) + logger.info(f"Heatmap shape: {heatmap_distances.shape} (pathways x perturbations)") + except Exception as e: + logger.warning(f"Failed to save heatmap data: {e}") + + # Create and save matplotlib heatmap visualization + try: + # Determine output path for heatmap image + if args.heatmap_output_path is not None: + # If user provided a path, make it unique per run as well + base, ext = os.path.splitext(args.heatmap_output_path) + heatmap_img_path = f"{base}_{unique_suffix}{ext or '.png'}" + else: + heatmap_img_path = os.path.join(heatmap_results_dir, f"{field_suffix}_pathway_upregulation_heatmap.png") + + # Ensure directory exists + os.makedirs(os.path.dirname(heatmap_img_path), exist_ok=True) + + # Create the heatmap with appropriate size + fig_width = max(12, len(perts_order) * 0.3) + fig_height = max(8, num_pathways * 0.05) # Smaller height per pathway since 
we have fewer rows + fig, ax = plt.subplots(figsize=(fig_width, fig_height)) + + # Create heatmap with proper labels + im = ax.imshow(heatmap_distances, cmap='viridis', aspect='auto') + + # Set labels and title + ax.set_xlabel('Perturbations') + ax.set_ylabel(f'{args.annotation_field.replace("_", " ").title()} Pathways') + ax.set_title(f'{args.annotation_field.replace("_", " ").title()} Pathway Upregulation Impact Heatmap\n(Euclidean Distance from Normal Predictions)') + + # Set x-axis labels (perturbations) + ax.set_xticks(range(len(perts_order))) + ax.set_xticklabels(perts_order, rotation=45, ha='right', fontsize=8) + + # Set y-axis labels (pathways) - show pathway names, truncated if too long + ax.set_yticks(range(num_pathways)) + truncated_pathway_names = [] + for pathway_name in pathway_names: + # Remove common prefixes and truncate long names + clean_name = pathway_name + # Remove common GO prefixes + for prefix in ['GOMF_', 'GOCC_', 'GOBP_']: + clean_name = clean_name.replace(prefix, '') + if len(clean_name) > 30: + clean_name = clean_name[:27] + '...' + truncated_pathway_names.append(clean_name) + ax.set_yticklabels(truncated_pathway_names, fontsize=6) + + # Add colorbar + cbar = plt.colorbar(im, ax=ax) + cbar.set_label('Mean Euclidean Distance', rotation=270, labelpad=20) + + # Adjust layout to prevent label cutoff + plt.tight_layout() + + # Save the figure + plt.savefig(heatmap_img_path, dpi=300, bbox_inches='tight') + plt.close(fig) # Close to free memory + + logger.info(f"Saved {args.annotation_field} pathway heatmap visualization to {heatmap_img_path}") + + except Exception as e: + logger.warning(f"Failed to create heatmap visualization: {e}") + else: + logger.info("Skipping heatmap analysis (--test-time-heat-map not set)") + + logger.info("Creating anndatas from predictions from manual loop...") + + # Build pandas DataFrame for obs and var + df_dict = { + data_module.pert_col: all_pert_names, + data_module.cell_type_key: all_celltypes, + data_module.batch_col: all_gem_groups, + } + + if len(all_pert_barcodes) > 0: + df_dict["pert_cell_barcode"] = all_pert_barcodes + df_dict["ctrl_cell_barcode"] = all_ctrl_barcodes + + obs = pd.DataFrame(df_dict) + + gene_names = var_dims["gene_names"] + var = pd.DataFrame({"gene_names": gene_names}) + + if final_X_hvg is not None: + if len(gene_names) != final_pert_cell_counts_preds.shape[1]: + gene_names = np.load( + "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + ) + var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - using the decoded gene expression values + adata_pred = anndata.AnnData(X=final_pert_cell_counts_preds, obs=obs, var=var) + # Create adata for real - using the true gene expression values + adata_real = anndata.AnnData(X=final_X_hvg, obs=obs, var=var) + + # add the embedding predictions + adata_pred.obsm[data_module.embed_key] = final_preds + adata_real.obsm[data_module.embed_key] = final_reals + logger.info(f"Added predicted embeddings to adata.obsm['{data_module.embed_key}']") + else: + # if len(gene_names) != final_preds.shape[1]: + # gene_names = np.load( + # "/large_storage/ctc/userspace/aadduri/datasets/tahoe_19k_to_2k_names.npy", allow_pickle=True + # ) + # var = pd.DataFrame({"gene_names": gene_names}) + + # Create adata for predictions - model was trained on gene expression space already + # adata_pred = anndata.AnnData(X=final_preds, obs=obs, var=var) + adata_pred = anndata.AnnData(X=final_preds, obs=obs) + # Create adata for real - using 
the true gene expression values + # adata_real = anndata.AnnData(X=final_reals, obs=obs, var=var) + adata_real = anndata.AnnData(X=final_reals, obs=obs) + + # Optionally filter to perturbations seen in at least one training context + if args.shared_only: + try: + shared_perts = data_module.get_shared_perturbations() + if len(shared_perts) == 0: + logger.warning("No shared perturbations between train and test; skipping filtering.") + else: + logger.info( + "Filtering to %d shared perturbations present in train ∩ test.", + len(shared_perts), + ) + mask = adata_pred.obs[data_module.pert_col].isin(shared_perts) + before_n = adata_pred.n_obs + adata_pred = adata_pred[mask].copy() + adata_real = adata_real[mask].copy() + logger.info( + "Filtered cells: %d -> %d (kept only seen perturbations)", + before_n, + adata_pred.n_obs, + ) + except Exception as e: + logger.warning( + "Failed to filter by shared perturbations (%s). Proceeding without filter.", + str(e), + ) + + # Save the AnnData objects + results_dir = results_dir_default + os.makedirs(results_dir, exist_ok=True) + adata_pred_path = os.path.join(results_dir, "adata_pred.h5ad") + adata_real_path = os.path.join(results_dir, "adata_real.h5ad") + + adata_pred.write_h5ad(adata_pred_path) + adata_real.write_h5ad(adata_real_path) + + logger.info(f"Saved adata_pred to {adata_pred_path}") + logger.info(f"Saved adata_real to {adata_real_path}") + + # Save per-dimension control-cell distributions for reproducibility + try: + dist_out = { + "key": distributions["key"], + "dim": distributions["dim"], + "num_cells": distributions["num_cells"], + } + dist_out_path = os.path.join(results_dir, "control_distributions.meta.json") + with open(dist_out_path, "w") as f: + json.dump(dist_out, f) + np.save(os.path.join(results_dir, "control_mean.npy"), distributions["mean"]) # [D] + np.save(os.path.join(results_dir, "control_std.npy"), distributions["std"]) # [D] + logger.info("Saved control-cell per-dimension mean/std distributions") + except Exception as e: + logger.warning(f"Failed to save control-cell distributions: {e}") + + if not args.predict_only: + # 6. 
Compute metrics using cell-eval + logger.info("Computing metrics using cell-eval...") + + control_pert = data_module.get_control_pert() + + ct_split_real = split_anndata_on_celltype(adata=adata_real, celltype_col=data_module.cell_type_key) + ct_split_pred = split_anndata_on_celltype(adata=adata_pred, celltype_col=data_module.cell_type_key) + + assert len(ct_split_real) == len(ct_split_pred), ( + f"Number of celltypes in real and pred anndata must match: {len(ct_split_real)} != {len(ct_split_pred)}" + ) + + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for ct in ct_split_real.keys(): + real_ct = ct_split_real[ct] + pred_ct = ct_split_pred[ct] + + evaluator = MetricsEvaluator( + adata_pred=pred_ct, + adata_real=real_ct, + control_pert=control_pert, + pert_col=data_module.pert_col, + outdir=results_dir, + prefix=ct, + pdex_kwargs=pdex_kwargs, + batch_size=2048, + ) + + evaluator.compute( + profile=args.profile, + metric_configs={ + "discrimination_score": { + "embed_key": data_module.embed_key, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + "pearson_edistance": { + "embed_key": data_module.embed_key, + "n_jobs": -1, # set to all available cores + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else { + "n_jobs": -1, + }, + } + if data_module.embed_key and data_module.embed_key != "X_hvg" + else {}, + skip_metrics=["pearson_edistance", "clustering_agreement"], + ) + + +def save_core_cells_real_preds(args: ap.ArgumentParser): + """Run only phase one of the heatmap pipeline and persist real core-cell embeddings per perturbation.""" + return run_tx_heatmap(args, phase_one_only=True) diff --git a/src/state/_cli/_tx/test.py b/src/state/_cli/_tx/test.py new file mode 100644 index 00000000..1b8383d0 --- /dev/null +++ b/src/state/_cli/_tx/test.py @@ -0,0 +1,42 @@ +import sys +from pathlib import Path + +import torch +from transformers import LlamaConfig + +REPO_ROOT = Path(__file__).resolve().parents[4] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +# NOTE: keep the import local so we don’t trigger heavy deps for other modules. +from src.state.tx.models.utils import LlamaBidirectionalModel + +# Build a tiny LLaMA config locally so we don’t hit gated HuggingFace weights. +config = LlamaConfig( + hidden_size=128, + intermediate_size=256, + num_hidden_layers=2, + num_attention_heads=4, + num_key_value_heads=4, + max_position_embeddings=128, + vocab_size=52000, +) + +# Force eager attention so we can request attention weights from the forward pass. +config._attn_implementation = "eager" +config.attn_implementation = "eager" + +model = LlamaBidirectionalModel(config) + +# Sanity check: run a tiny forward pass and print one head’s attention. +input_ids = torch.tensor([[1, 2, 3, 4, 5]]) +outputs = model(input_ids, output_attentions=True, return_dict=True) +if outputs.attentions is None: + raise RuntimeError( + "Attention weights were not returned; ensure attention implementation supports `output_attentions`." + ) + +# `outputs.attentions` is a tuple of length n_layers; each is [batch, heads, seq_len, seq_len]. +attn0 = outputs.attentions[0][0, 0] # layer 0, batch 0, head 0 +print("Layer 0, Head 0 attention matrix:\n", attn0) +# You should see non‐zero entries in BOTH upper and lower triangles. 
\ No newline at end of file diff --git a/src/state/_cli/_tx/test_gpt.py b/src/state/_cli/_tx/test_gpt.py new file mode 100644 index 00000000..1e2e89d7 --- /dev/null +++ b/src/state/_cli/_tx/test_gpt.py @@ -0,0 +1,28 @@ +import sys +from pathlib import Path + +import torch +from transformers import GPT2Config + +REPO_ROOT = Path(__file__).resolve().parents[4] +if str(REPO_ROOT) not in sys.path: + sys.path.insert(0, str(REPO_ROOT)) + +# NOTE: keep the import local so we don’t trigger heavy deps for other modules. +from src.state.tx.models.utils import GPT2BidirectionalModel + +config = GPT2Config.from_pretrained("gpt2") # or load from your own checkpoint +config._attn_implementation = "eager" +config.attn_implementation = "eager" + +model = GPT2BidirectionalModel(config) + +# (or, if you’re using HF’s `from_pretrained` pattern, you can do:) +model = GPT2BidirectionalModel.from_pretrained("gpt2", config=config) + +# Sanity check: run a tiny forward pass and print one head’s attention. +input_ids = torch.tensor([[50256, 314, 617, 198, 198]]) # “Hello” +outputs = model(input_ids, output_attentions=True) +# `outputs.attentions` is a tuple of length n_layers; each is [batch, heads, seq_len, seq_len]. +attn0 = outputs.attentions[0][0, 0] # layer 0, batch 0, head 0 +print("Layer 0, Head 0 attention matrix:\n", attn0) \ No newline at end of file From cb4f4053657118b448efd1f5c986d6b6741ade0d Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Thu, 2 Oct 2025 00:15:11 +0000 Subject: [PATCH 6/9] single --- src/state/__main__.py | 4 + src/state/_cli/__init__.py | 2 + src/state/_cli/_tx/__init__.py | 5 + src/state/_cli/_tx/_double.py | 220 +++++++--- src/state/_cli/_tx/_single.py | 726 +++++++++++++++++++++++++++++++++ 5 files changed, 897 insertions(+), 60 deletions(-) create mode 100644 src/state/_cli/_tx/_single.py diff --git a/src/state/__main__.py b/src/state/__main__.py index 532320b9..7c454036 100755 --- a/src/state/__main__.py +++ b/src/state/__main__.py @@ -12,6 +12,7 @@ run_emb_preprocess, run_emb_eval, run_tx_double, + run_tx_single, run_tx_heatmap, run_tx_infer, run_tx_predict, @@ -129,6 +130,9 @@ def main(): case "double": # Run double perturbation analysis using argparse run_tx_double(args) + case "single": + # Run single perturbation analysis using argparse + run_tx_single(args) case "infer": # Run inference using argparse, similar to predict run_tx_infer(args) diff --git a/src/state/_cli/__init__.py b/src/state/_cli/__init__.py index d1eb55dd..937d304b 100755 --- a/src/state/_cli/__init__.py +++ b/src/state/_cli/__init__.py @@ -2,6 +2,7 @@ from ._tx import ( add_arguments_tx, run_tx_double, + run_tx_single, run_tx_heatmap, run_tx_infer, run_tx_predict, @@ -16,6 +17,7 @@ "run_tx_train", "run_tx_predict", "run_tx_double", + "run_tx_single", "run_tx_heatmap", "run_tx_infer", "run_tx_preprocess_train", diff --git a/src/state/_cli/_tx/__init__.py b/src/state/_cli/_tx/__init__.py index 3e76062f..cb499361 100755 --- a/src/state/_cli/_tx/__init__.py +++ b/src/state/_cli/_tx/__init__.py @@ -6,6 +6,7 @@ from ._predict import add_arguments_predict, run_tx_predict from ._preprocess_infer import add_arguments_preprocess_infer, run_tx_preprocess_infer from ._preprocess_train import add_arguments_preprocess_train, run_tx_preprocess_train +from ._single import add_arguments_single, run_tx_single from ._train import add_arguments_train, run_tx_train __all__ = [ @@ -13,10 +14,13 @@ "run_tx_predict", "run_tx_heatmap", "run_tx_double", + "run_tx_single", "run_tx_infer", "run_tx_preprocess_train", 
"run_tx_preprocess_infer", "add_arguments_tx", + "add_arguments_double", + "add_arguments_single", ] @@ -27,6 +31,7 @@ def add_arguments_tx(parser: ap.ArgumentParser): add_arguments_predict(subparsers.add_parser("predict")) add_arguments_heatmap(subparsers.add_parser("heatmap")) add_arguments_double(subparsers.add_parser("double")) + add_arguments_single(subparsers.add_parser("single")) add_arguments_infer(subparsers.add_parser("infer")) add_arguments_preprocess_train(subparsers.add_parser("preprocess_train")) add_arguments_preprocess_infer(subparsers.add_parser("preprocess_infer")) diff --git a/src/state/_cli/_tx/_double.py b/src/state/_cli/_tx/_double.py index cf934cd8..926120cc 100644 --- a/src/state/_cli/_tx/_double.py +++ b/src/state/_cli/_tx/_double.py @@ -58,6 +58,20 @@ def add_arguments_double(parser: ap.ArgumentParser) -> None: default=None, help="Directory to save results. Defaults to /eval_.", ) + parser.add_argument( + "--first-pass-only", + action="store_true", + help="Run only the first round of inference and save per-perturbation predictions to a NumPy file.", + ) + parser.add_argument( + "--core-cells-path", + type=str, + default=None, + help=( + "Path to a NumPy .npy file containing a serialized core_cells dictionary to use instead of" + " constructing a new control batch." + ), + ) def run_tx_double(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> None: @@ -81,6 +95,8 @@ def run_tx_double(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> N logging.basicConfig(level=logging.INFO) logger = logging.getLogger(__name__) + phase_one_only = phase_one_only or getattr(args, "first_pass_only", False) + def _prepare_for_serialization(obj): if isinstance(obj, torch.Tensor): return obj.detach().cpu().numpy().copy() @@ -239,15 +255,24 @@ def _resolve_celltype_key(batch, module): else: raise ValueError(f"Unknown model class: {model_name}") + model_init_kwargs = { + "input_dim": var_dims["input_dim"], + "output_dim": var_dims["output_dim"], + "pert_dim": var_dims["pert_dim"], + **model_kwargs, + } + + for optional_key in ("gene_dim", "hvg_dim"): + optional_value = var_dims.get(optional_key) + if optional_value is not None: + model_init_kwargs[optional_key] = optional_value + + if "hidden_dim" in var_dims and "hidden_dim" not in model_init_kwargs: + model_init_kwargs["hidden_dim"] = var_dims["hidden_dim"] + model = ModelClass.load_from_checkpoint( checkpoint_path, - input_dim=var_dims["input_dim"], - hidden_dim=model_kwargs["hidden_dim"], - gene_dim=var_dims.get("gene_dim"), - hvg_dim=var_dims.get("hvg_dim"), - output_dim=var_dims["output_dim"], - pert_dim=var_dims["pert_dim"], - **model_kwargs, + **model_init_kwargs, ) model.eval() logger.info("Model loaded successfully.") @@ -340,7 +365,7 @@ def _generator(): eval_loader = _create_filtered_loader(data_module) logger.info("Test-time fine-tuning complete.") - logger.info("Preparing a fixed batch of 64 control cells and enumerating perturbations...") + logger.info("Preparing core cells batch and enumerating perturbations...") control_pert = data_module.get_control_pert() unique_perts = [] @@ -355,65 +380,140 @@ def _generator(): if not unique_perts: raise RuntimeError("No perturbations found in the provided dataloader.") - eval_loader = _create_filtered_loader(data_module) + target_core_n_default = 256 + custom_core_cells_path = getattr(args, "core_cells_path", None) + + def _load_core_cells_from_path(path): + if not os.path.exists(path): + raise FileNotFoundError(f"Core cells file not found: {path}") + + loaded = 
np.load(path, allow_pickle=True) + if isinstance(loaded, np.lib.npyio.NpzFile): + raise ValueError("Expected a .npy file containing a serialized dictionary; received an .npz archive.") + if isinstance(loaded, np.ndarray) and loaded.dtype == object: + if loaded.shape == (): + loaded = loaded.item() + if not isinstance(loaded, dict): + raise ValueError( + "Serialized core cells must be a dictionary mapping field names to arrays/tensors;" + f" received type {type(loaded).__name__}." + ) - target_core_n = 64 - accum = {} + converted = {} + inferred_length = None + for key, value in loaded.items(): + if isinstance(value, torch.Tensor): + tensor = value.clone().detach() + elif isinstance(value, np.ndarray): + tensor = torch.from_numpy(value) + else: + tensor = None - def _append_field(store, key, value): - if key not in store: - store[key] = [] - store[key].append(value) + if tensor is not None: + if tensor.dim() == 0: + converted[key] = tensor.item() + continue + if inferred_length is None: + inferred_length = tensor.shape[0] + else: + if tensor.shape[0] != inferred_length: + raise ValueError( + f"Mismatched leading dimensions in core cells: expected {inferred_length}," + f" received {tensor.shape[0]} for key '{key}'." + ) + converted[key] = tensor + else: + values_list = _to_list(value) + if inferred_length is None: + inferred_length = len(values_list) + elif len(values_list) != inferred_length: + raise ValueError( + f"Mismatched list length in core cells: expected {inferred_length}," + f" received {len(values_list)} for key '{key}'." + ) + converted[key] = values_list - for batch in eval_loader: - names = _to_list(batch.get("pert_name", [])) - if not names: - continue - mask = torch.tensor([str(item) == str(control_pert) for item in names], dtype=torch.bool) - if mask.sum().item() == 0: - continue - - current_count = accum.get("_count", 0) - take = min(target_core_n - current_count, int(mask.sum().item())) - if take <= 0: - break - - for key, value in batch.items(): + if inferred_length is None or inferred_length == 0: + raise ValueError("Loaded core cells did not contain any batched entries.") + + for key, value in list(converted.items()): if isinstance(value, torch.Tensor): - mask_device = mask.to(value.device) - selected = value[mask_device][:take].detach().clone() - _append_field(accum, key, selected) + converted[key] = value[:inferred_length].clone() else: - vals = _to_list(value) - selected_vals = [vals[idx] for idx, keep in enumerate(mask.tolist()) if keep][:take] - _append_field(accum, key, selected_vals) + converted[key] = value[:inferred_length] - accum["_count"] = current_count + take - if accum["_count"] >= target_core_n: - break + return converted, inferred_length - if accum.get("_count", 0) < target_core_n: - raise RuntimeError( - f"Could not assemble {target_core_n} control cells; gathered only {accum.get('_count', 0)}." 
+ if custom_core_cells_path: + core_cells, target_core_n = _load_core_cells_from_path(os.path.abspath(custom_core_cells_path)) + logger.info( + "Loaded custom core_cells batch with size %d from %s.", + target_core_n, + custom_core_cells_path, ) + else: + eval_loader = _create_filtered_loader(data_module) + target_core_n = target_core_n_default + accum = {} - core_cells = {} - for key, parts in accum.items(): - if key == "_count": - continue - if len(parts) == 1: - val = parts[0] - else: - if isinstance(parts[0], torch.Tensor): - val = torch.cat(parts, dim=0) + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + if not names: + continue + mask = torch.tensor([str(item) == str(control_pert) for item in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + + current_count = accum.get("_count", 0) + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + mask_device = mask.to(value.device) + selected = value[mask_device][:take].detach().clone() + _append_field(accum, key, selected) + else: + vals = _to_list(value) + selected_vals = [vals[idx] for idx, keep in enumerate(mask.tolist()) if keep][:take] + _append_field(accum, key, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError( + f"Could not assemble {target_core_n} control cells; gathered only {accum.get('_count', 0)}." + ) + + core_cells = {} + for key, parts in accum.items(): + if key == "_count": + continue + if len(parts) == 1: + val = parts[0] else: - merged = [] - for part in parts: - merged.extend(_to_list(part)) - val = merged - core_cells[key] = val[:target_core_n] if isinstance(val, torch.Tensor) else _to_list(val)[:target_core_n] + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for part in parts: + merged.extend(_to_list(part)) + val = merged + core_cells[key] = ( + val[:target_core_n] + if isinstance(val, torch.Tensor) + else _to_list(val)[:target_core_n] + ) - logger.info("Constructed core_cells batch with size %d.", target_core_n) + logger.info("Constructed core_cells batch with size %d.", target_core_n) os.makedirs(results_dir_default, exist_ok=True) baseline_path = os.path.join(results_dir_default, "core_cells_baseline.npy") @@ -532,12 +632,12 @@ def _prepare_pert_emb(pert_name, length): logger.info("First pass complete across %d perturbations.", num_perts) if phase_one_only: - real_preds_path = os.path.join(results_dir_default, "core_cells_real_preds_per_pert.npy") - np.save(real_preds_path, first_pass_real, allow_pickle=True) + preds_path = os.path.join(results_dir_default, "first_pass_preds.npy") + np.save(preds_path, first_pass_preds, allow_pickle=True) logger.info( - "Saved real perturbed embeddings for %d perturbations to %s", + "Saved first-pass predictions for %d perturbations to %s", num_perts, - real_preds_path, + preds_path, ) return diff --git a/src/state/_cli/_tx/_single.py b/src/state/_cli/_tx/_single.py new file mode 100644 index 00000000..941c80de --- /dev/null +++ b/src/state/_cli/_tx/_single.py @@ -0,0 +1,726 @@ +import argparse as ap + + +def add_arguments_single(parser: ap.ArgumentParser) -> None: + """CLI for single-pass perturbation analysis on a target cell line.""" + + 
parser.add_argument( + "--output-dir", + type=str, + required=True, + help=( + "Path to the output_dir containing the config.yaml file that was saved during training." + ), + ) + parser.add_argument( + "--checkpoint", + type=str, + default="last.ckpt", + help="Checkpoint filename relative to the output directory (default: last.ckpt).", + ) + parser.add_argument( + "--profile", + type=str, + default="full", + choices=["full", "minimal", "de", "anndata"], + help="Evaluation profile to run after inference.", + ) + parser.add_argument( + "--predict-only", + action="store_true", + help="Skip metric computation and only run inference.", + ) + parser.add_argument( + "--shared-only", + action="store_true", + help="Restrict outputs to perturbations present in both train and test sets.", + ) + parser.add_argument( + "--eval-train-data", + action="store_true", + help="Evaluate the model on the training data instead of the test data.", + ) + parser.add_argument( + "--target-cell-type", + type=str, + required=True, + help="Cell type to construct the base core cells for single perturbations.", + ) + parser.add_argument( + "--results-dir", + type=str, + default=None, + help="Directory to save results. Defaults to /eval_.", + ) + parser.add_argument( + "--first-pass-only", + action="store_true", + help="Run only the first round of inference and save per-perturbation predictions to a NumPy file.", + ) + parser.add_argument( + "--core-cells-path", + type=str, + default=None, + help=( + "Path to a NumPy .npy file containing a serialized core_cells dictionary to use instead of" + " constructing a new control batch." + ), + ) + + +def run_tx_single(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> None: + import logging + import os + import sys + + import anndata + import lightning.pytorch as pl + import numpy as np + import pandas as pd + import torch + import yaml + from tqdm import tqdm + + from cell_eval import MetricsEvaluator + from cell_eval.utils import split_anndata_on_celltype + from cell_load.data_modules import PerturbationDataModule + + logging.basicConfig(level=logging.INFO) + logger = logging.getLogger(__name__) + + phase_one_only = phase_one_only or getattr(args, "first_pass_only", False) + + def _prepare_for_serialization(obj): + if isinstance(obj, torch.Tensor): + return obj.detach().cpu().numpy().copy() + if isinstance(obj, dict): + return {k: _prepare_for_serialization(v) for k, v in obj.items()} + if isinstance(obj, (list, tuple)): + return [_prepare_for_serialization(v) for v in obj] + return obj + + def _save_numpy_snapshot(obj, path, description=None): + serializable = _prepare_for_serialization(obj) + try: + np.save(path, serializable, allow_pickle=True) + if description: + logger.info("Saved %s to %s", description, path) + else: + logger.info("Saved snapshot to %s", path) + except Exception as exc: + logger.warning("Failed to save %s to %s: %s", description or "snapshot", path, exc) + + def _to_list(value): + if isinstance(value, list): + return value + if isinstance(value, torch.Tensor): + try: + return [x.item() if x.dim() == 0 else x for x in value] + except Exception: + return value.tolist() + if isinstance(value, (tuple, set)): + return list(value) + if value is None: + return [] + return [value] + + def _normalize_field(values, length, filler=None): + items = list(_to_list(values)) + if len(items) == 1 and length > 1: + items = items * length + if len(items) < length: + items.extend([filler] * (length - len(items))) + elif len(items) > length: + items = items[:length] 
+ return items + + def _resolve_celltype_key(batch, module): + candidate_keys = [] + base_key = getattr(module, "cell_type_key", None) + if base_key: + candidate_keys.append(base_key) + alias_keys = getattr(module, "cell_type_key_aliases", None) + if isinstance(alias_keys, (list, tuple)): + candidate_keys.extend(alias_keys) + alias_keys_alt = getattr(module, "celltype_key_aliases", None) + if isinstance(alias_keys_alt, (list, tuple)): + candidate_keys.extend(alias_keys_alt) + candidate_keys.extend([ + "celltype_name", + "cell_type", + "celltype", + "cell_line", + ]) + seen = set() + ordered_candidates = [] + for key in candidate_keys: + if not key or key in seen: + continue + seen.add(key) + ordered_candidates.append(key) + if key in batch: + return key, ordered_candidates + return None, ordered_candidates + + torch.multiprocessing.set_sharing_strategy("file_system") + + config_path = os.path.join(args.output_dir, "config.yaml") + if not os.path.exists(config_path): + raise FileNotFoundError(f"Could not find config file: {config_path}") + with open(config_path, "r", encoding="utf-8") as file: + cfg = yaml.safe_load(file) + logger.info("Loaded config from %s", config_path) + + run_output_dir = os.path.join(cfg["output_dir"], cfg["name"]) + if not os.path.isabs(run_output_dir): + run_output_dir = os.path.abspath(run_output_dir) + if not os.path.exists(run_output_dir): + inferred_run_dir = args.output_dir + if os.path.exists(inferred_run_dir): + logger.warning( + "Run directory %s not found; falling back to config directory %s", + run_output_dir, + inferred_run_dir, + ) + run_output_dir = inferred_run_dir + else: + raise FileNotFoundError( + "Could not resolve run directory. Checked: %s and %s" % (run_output_dir, inferred_run_dir) + ) + + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}") + data_module = PerturbationDataModule.load_state(data_module_path) + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + + pl.seed_everything(cfg["training"]["train_seed"]) + + checkpoint_dir = os.path.join(run_output_dir, "checkpoints") + checkpoint_path = os.path.join(checkpoint_dir, args.checkpoint) + if not os.path.exists(checkpoint_path): + raise FileNotFoundError( + f"Could not find checkpoint at {checkpoint_path}. Specify --checkpoint with a valid file." 
+ ) + logger.info("Loading model from %s", checkpoint_path) + + model_name = cfg["model"]["name"] + model_kwargs = cfg["model"]["kwargs"] + var_dims = data_module.get_var_dims() + + if model_name.lower() == "embedsum": + from ...tx.models.embed_sum import EmbedSumPerturbationModel + + ModelClass = EmbedSumPerturbationModel + elif model_name.lower() == "old_neuralot": + from ...tx.models.old_neural_ot import OldNeuralOTPerturbationModel + + ModelClass = OldNeuralOTPerturbationModel + elif model_name.lower() in {"neuralot", "pertsets", "state"}: + from ...tx.models.state_transition import StateTransitionPerturbationModel + + ModelClass = StateTransitionPerturbationModel + elif model_name.lower() in {"globalsimplesum", "perturb_mean"}: + from ...tx.models.perturb_mean import PerturbMeanPerturbationModel + + ModelClass = PerturbMeanPerturbationModel + elif model_name.lower() in {"celltypemean", "context_mean"}: + from ...tx.models.context_mean import ContextMeanPerturbationModel + + ModelClass = ContextMeanPerturbationModel + elif model_name.lower() == "decoder_only": + from ...tx.models.decoder_only import DecoderOnlyPerturbationModel + + ModelClass = DecoderOnlyPerturbationModel + else: + raise ValueError(f"Unknown model class: {model_name}") + + model_init_kwargs = { + "input_dim": var_dims["input_dim"], + "output_dim": var_dims["output_dim"], + "pert_dim": var_dims["pert_dim"], + **model_kwargs, + } + + for optional_key in ("gene_dim", "hvg_dim"): + optional_value = var_dims.get(optional_key) + if optional_value is not None: + model_init_kwargs[optional_key] = optional_value + + if "hidden_dim" in var_dims and "hidden_dim" not in model_init_kwargs: + model_init_kwargs["hidden_dim"] = var_dims["hidden_dim"] + + model = ModelClass.load_from_checkpoint( + checkpoint_path, + **model_init_kwargs, + ) + model.eval() + logger.info("Model loaded successfully.") + + results_dir_default = ( + args.results_dir + if args.results_dir is not None + else os.path.join(args.output_dir, f"eval_{os.path.basename(args.checkpoint)}") + ) + + data_module.batch_size = 1 + target_celltype = getattr(args, "target_cell_type") + + def _create_filtered_loader(module): + base_loader = ( + module.train_dataloader(test=True) + if args.eval_train_data + else module.test_dataloader() + ) + + celltype_key, attempted = _resolve_celltype_key({}, module) + + def _generator(): + found_target = False + for batch in base_loader: + if target_celltype is None: + found_target = True + yield batch + continue + + key = celltype_key + if key is None: + key, attempted_keys = _resolve_celltype_key(batch, module) + if key is None: + available_keys = [k for k in batch.keys() if isinstance(k, str)] + available_preview = ", ".join(sorted(available_keys)[:10]) + raise ValueError( + "--target-cell-type requested filtering but none of the expected keys (%s) were present." + " Available batch keys: %s%s" + % ( + ", ".join(attempted_keys) if attempted_keys else "none", + available_preview, + "..." 
if len(available_keys) > 10 else "", + ) + ) + + celltypes = _to_list(batch[key]) + mask_values = [str(ct).lower() == target_celltype.lower() for ct in celltypes] + if not mask_values or not any(mask_values): + continue + + mask = torch.tensor(mask_values, dtype=torch.bool) + filtered = {} + for batch_key, value in batch.items(): + if isinstance(value, torch.Tensor): + mask_device = mask.to(value.device) + selected = value[mask_device] + if selected.shape[0] == 0: + continue + filtered[batch_key] = selected + else: + vals = _to_list(value) + selected = [vals[idx] for idx, keep in enumerate(mask_values) if keep] + if not selected: + continue + filtered[batch_key] = selected + if filtered: + found_target = True + yield filtered + + if target_celltype and not found_target: + raise ValueError( + f"Target cell type '{target_celltype}' not found in any batches for evaluation." + ) + + return _generator() + + eval_loader = _create_filtered_loader(data_module) + + logger.info("Preparing core cells batch and enumerating perturbations...") + + control_pert = data_module.get_control_pert() + unique_perts = [] + seen_perts = set() + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + for name in names: + name_value = name.item() if isinstance(name, torch.Tensor) else str(name) + if name_value not in seen_perts: + seen_perts.add(name_value) + unique_perts.append(name_value) + if not unique_perts: + raise RuntimeError("No perturbations found in the provided dataloader.") + + target_core_n_default = 256 + custom_core_cells_path = getattr(args, "core_cells_path", None) + + def _load_core_cells_from_path(path): + if not os.path.exists(path): + raise FileNotFoundError(f"Core cells file not found: {path}") + + loaded = np.load(path, allow_pickle=True) + if isinstance(loaded, np.lib.npyio.NpzFile): + raise ValueError("Expected a .npy file containing a serialized dictionary; received an .npz archive.") + if isinstance(loaded, np.ndarray) and loaded.dtype == object: + if loaded.shape == (): + loaded = loaded.item() + if not isinstance(loaded, dict): + raise ValueError( + "Serialized core cells must be a dictionary mapping field names to arrays/tensors;" + f" received type {type(loaded).__name__}." + ) + + converted = {} + inferred_length = None + for key, value in loaded.items(): + if isinstance(value, torch.Tensor): + tensor = value.clone().detach() + elif isinstance(value, np.ndarray): + tensor = torch.from_numpy(value) + else: + tensor = None + + if tensor is not None: + if tensor.dim() == 0: + converted[key] = tensor.item() + continue + if inferred_length is None: + inferred_length = tensor.shape[0] + else: + if tensor.shape[0] != inferred_length: + raise ValueError( + f"Mismatched leading dimensions in core cells: expected {inferred_length}," + f" received {tensor.shape[0]} for key '{key}'." + ) + converted[key] = tensor + else: + values_list = _to_list(value) + if inferred_length is None: + inferred_length = len(values_list) + elif len(values_list) != inferred_length: + raise ValueError( + f"Mismatched list length in core cells: expected {inferred_length}," + f" received {len(values_list)} for key '{key}'." 
+ ) + converted[key] = values_list + + if inferred_length is None or inferred_length == 0: + raise ValueError("Loaded core cells did not contain any batched entries.") + + for key, value in list(converted.items()): + if isinstance(value, torch.Tensor): + converted[key] = value[:inferred_length].clone() + else: + converted[key] = value[:inferred_length] + + return converted, inferred_length + + if custom_core_cells_path: + core_cells, target_core_n = _load_core_cells_from_path(os.path.abspath(custom_core_cells_path)) + logger.info( + "Loaded custom core_cells batch with size %d from %s.", + target_core_n, + custom_core_cells_path, + ) + else: + eval_loader = _create_filtered_loader(data_module) + target_core_n = target_core_n_default + accum = {} + + def _append_field(store, key, value): + if key not in store: + store[key] = [] + store[key].append(value) + + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + if not names: + continue + mask = torch.tensor([str(item) == str(control_pert) for item in names], dtype=torch.bool) + if mask.sum().item() == 0: + continue + + current_count = accum.get("_count", 0) + take = min(target_core_n - current_count, int(mask.sum().item())) + if take <= 0: + break + + for key, value in batch.items(): + if isinstance(value, torch.Tensor): + mask_device = mask.to(value.device) + selected = value[mask_device][:take].detach().clone() + _append_field(accum, key, selected) + else: + vals = _to_list(value) + selected_vals = [vals[idx] for idx, keep in enumerate(mask.tolist()) if keep][:take] + _append_field(accum, key, selected_vals) + + accum["_count"] = current_count + take + if accum["_count"] >= target_core_n: + break + + if accum.get("_count", 0) < target_core_n: + raise RuntimeError( + f"Could not assemble {target_core_n} control cells; gathered only {accum.get('_count', 0)}." 
+ ) + + core_cells = {} + for key, parts in accum.items(): + if key == "_count": + continue + if len(parts) == 1: + val = parts[0] + else: + if isinstance(parts[0], torch.Tensor): + val = torch.cat(parts, dim=0) + else: + merged = [] + for part in parts: + merged.extend(_to_list(part)) + val = merged + core_cells[key] = ( + val[:target_core_n] + if isinstance(val, torch.Tensor) + else _to_list(val)[:target_core_n] + ) + + logger.info("Constructed core_cells batch with size %d.", target_core_n) + + os.makedirs(results_dir_default, exist_ok=True) + baseline_path = os.path.join(results_dir_default, "core_cells_baseline.npy") + _save_numpy_snapshot(core_cells, baseline_path, "baseline core_cells batch") + + perts_order = list(unique_perts) + num_perts = len(perts_order) + output_dim = var_dims["output_dim"] + gene_dim = var_dims.get("gene_dim", 0) + hvg_dim = var_dims.get("hvg_dim", 0) + + logger.info("Running first-pass predictions across %d perturbations...", num_perts) + + first_pass_preds = np.empty((num_perts, target_core_n, output_dim), dtype=np.float32) + first_pass_real = np.empty((num_perts, target_core_n, output_dim), dtype=np.float32) + + embed_key = getattr(data_module, "embed_key", None) or "latent_embedding" + output_space = cfg["data"]["kwargs"].get("output_space", "embedding") + store_counts = output_space in {"gene", "all"} + + first_pass_counts = None + first_pass_counts_pred = None + if store_counts: + feature_dim = hvg_dim if output_space == "gene" and hvg_dim else gene_dim + if feature_dim > 0: + first_pass_counts = np.empty((num_perts, target_core_n, feature_dim), dtype=np.float32) + first_pass_counts_pred = np.empty((num_perts, target_core_n, feature_dim), dtype=np.float32) + else: + store_counts = False + + metadata = { + "pert_name": [], + "celltype_name": [], + "batch": [], + "pert_cell_barcode": [], + "ctrl_cell_barcode": [], + } + + device = next(model.parameters()).device + + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + else: + logger.warning("pert_onehot_map.pt not found at %s; proceeding with zero embeddings", map_path) + pert_onehot_map = {} + + def _prepare_pert_emb(pert_name, length): + vec = pert_onehot_map.get(pert_name) + if vec is None and control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + if vec is None: + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("pert_dim is undefined; cannot create perturbation embedding") + vec = torch.zeros(pert_dim) + return vec.float().unsqueeze(0).repeat(length, 1).to(device) + + with torch.no_grad(): + for p_idx, pert in enumerate(tqdm(perts_order, desc="First pass", unit="pert")): + batch = {} + for key, value in core_cells.items(): + if isinstance(value, torch.Tensor): + batch[key] = value.clone().to(device) + else: + batch[key] = list(value) + + if "pert_name" in batch: + batch["pert_name"] = [pert for _ in range(target_core_n)] + if "pert_idx" in batch and hasattr(data_module, "get_pert_index"): + try: + idx_val = int(data_module.get_pert_index(pert)) + batch["pert_idx"] = torch.tensor([idx_val] * target_core_n, device=device) + except Exception: + pass + + batch["pert_emb"] = _prepare_pert_emb(pert, target_core_n) + + batch_preds = model.predict_step(batch, p_idx, padded=False) + + batch_size = batch_preds["preds"].shape[0] + 
metadata["pert_name"].extend(_normalize_field(batch_preds.get("pert_name", pert), batch_size, pert)) + metadata["celltype_name"].extend( + _normalize_field(batch_preds.get("celltype_name"), batch_size, target_celltype) + ) + metadata["batch"].extend( + [None if b is None else str(b) for b in _normalize_field(batch_preds.get("batch"), batch_size)] + ) + metadata["pert_cell_barcode"].extend( + _normalize_field(batch_preds.get("pert_cell_barcode"), batch_size) + ) + metadata["ctrl_cell_barcode"].extend( + _normalize_field(batch_preds.get("ctrl_cell_barcode"), batch_size) + ) + + batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + + first_pass_preds[p_idx, :, :] = batch_pred_np + first_pass_real[p_idx, :, :] = batch_real_np + + if store_counts and first_pass_counts is not None and batch_preds.get("pert_cell_counts") is not None: + counts_np = batch_preds["pert_cell_counts"].detach().cpu().numpy().astype(np.float32) + first_pass_counts[p_idx, :, :] = counts_np + + if ( + store_counts + and first_pass_counts_pred is not None + and batch_preds.get("pert_cell_counts_preds") is not None + ): + counts_pred_np = batch_preds["pert_cell_counts_preds"].detach().cpu().numpy().astype(np.float32) + first_pass_counts_pred[p_idx, :, :] = counts_pred_np + + logger.info("First pass complete across %d perturbations.", num_perts) + + metadata_df = pd.DataFrame(metadata) + if metadata_df.empty: + raise RuntimeError("No metadata collected during first pass; cannot proceed.") + + pert_col = getattr(data_module, "pert_col", None) or "perturbation" + cell_type_col = getattr(data_module, "cell_type_key", None) or "cell_type" + batch_col = getattr(data_module, "batch_col", None) or "batch" + + obs_df = pd.DataFrame( + { + pert_col: metadata_df["pert_name"], + cell_type_col: metadata_df["celltype_name"], + batch_col: metadata_df["batch"], + } + ) + if metadata_df["pert_cell_barcode"].notna().any(): + obs_df["pert_cell_barcode"] = metadata_df["pert_cell_barcode"] + if metadata_df["ctrl_cell_barcode"].notna().any(): + obs_df["ctrl_cell_barcode"] = metadata_df["ctrl_cell_barcode"] + + first_pass_pred_flat = first_pass_preds.reshape(num_perts * target_core_n, output_dim) + first_pass_real_flat = first_pass_real.reshape(num_perts * target_core_n, output_dim) + + if store_counts and first_pass_counts is not None and first_pass_counts_pred is not None: + feature_dim = first_pass_counts.shape[-1] + gene_names = var_dims.get("gene_names") + if gene_names is not None and len(gene_names) == feature_dim: + var_index = pd.Index([str(name) for name in gene_names], name="gene") + else: + var_index = pd.Index([f"feature_{idx}" for idx in range(feature_dim)], name="feature") + var_df = pd.DataFrame(index=var_index) + + pred_X = first_pass_counts_pred.reshape(num_perts * target_core_n, feature_dim) + real_X = first_pass_counts.reshape(num_perts * target_core_n, feature_dim) + else: + var_index = pd.Index([f"embedding_{idx}" for idx in range(output_dim)], name="embedding") + var_df = pd.DataFrame(index=var_index) + pred_X = first_pass_pred_flat + real_X = first_pass_real_flat + + first_pass_pred_adata = anndata.AnnData(X=pred_X, obs=obs_df.copy(), var=var_df.copy()) + first_pass_real_adata = anndata.AnnData(X=real_X, obs=obs_df.copy(), var=var_df.copy()) + first_pass_pred_adata.obsm[embed_key] = first_pass_pred_flat + first_pass_real_adata.obsm[embed_key] = first_pass_real_flat + + first_pass_pred_path = 
os.path.join(results_dir_default, "first_pass_preds.h5ad") + first_pass_real_path = os.path.join(results_dir_default, "first_pass_real.h5ad") + first_pass_pred_adata.write_h5ad(first_pass_pred_path) + first_pass_real_adata.write_h5ad(first_pass_real_path) + logger.info("Saved first-pass predicted adata to %s", first_pass_pred_path) + logger.info("Saved first-pass real adata to %s", first_pass_real_path) + + np.save(os.path.join(results_dir_default, "first_pass_preds.npy"), first_pass_preds) + np.save(os.path.join(results_dir_default, "first_pass_real.npy"), first_pass_real) + if first_pass_counts is not None: + np.save(os.path.join(results_dir_default, "first_pass_counts.npy"), first_pass_counts) + if first_pass_counts_pred is not None: + np.save(os.path.join(results_dir_default, "first_pass_counts_pred.npy"), first_pass_counts_pred) + + if phase_one_only: + logger.info("Phase one complete; skipping metrics as requested.") + return + + if args.predict_only: + return + + if cell_type_col not in first_pass_real_adata.obs.columns: + logger.warning( + "Cell type column '%s' not found in observations; skipping metric computation.", + cell_type_col, + ) + return + + control_pert = data_module.get_control_pert() + ct_split_real = split_anndata_on_celltype( + adata=first_pass_real_adata, + celltype_col=cell_type_col, + ) + ct_split_pred = split_anndata_on_celltype( + adata=first_pass_pred_adata, + celltype_col=cell_type_col, + ) + + if len(ct_split_real) != len(ct_split_pred): + logger.warning( + "Number of celltypes in real and predicted AnnData objects differ (%d vs %d); skipping metrics.", + len(ct_split_real), + len(ct_split_pred), + ) + return + + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) + for celltype in ct_split_real.keys(): + real_ct = ct_split_real[celltype] + pred_ct = ct_split_pred[celltype] + + metric_configs = {} + if data_module.embed_key and data_module.embed_key != "X_hvg": + metric_configs = { + "discrimination_score": {"embed_key": embed_key}, + "pearson_edistance": {"embed_key": embed_key, "n_jobs": -1}, + } + else: + metric_configs = {"pearson_edistance": {"n_jobs": -1}} + + evaluator = MetricsEvaluator( + adata_pred=pred_ct, + adata_real=real_ct, + control_pert=control_pert, + pert_col=pert_col, + outdir=results_dir_default, + prefix=str(celltype), + pdex_kwargs=pdex_kwargs, + batch_size=2048, + ) + evaluator.compute( + profile=args.profile, + metric_configs=metric_configs, + skip_metrics=["pearson_edistance", "clustering_agreement"], + ) + + +def save_core_cells_real_preds(args: ap.ArgumentParser) -> None: + """Run only phase one of the pipeline and persist real core-cell embeddings per perturbation.""" + return run_tx_single(args, phase_one_only=True) From e182430279f34b0ffeb1bcfcedd2aef41ec4c8d9 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Thu, 2 Oct 2025 02:55:39 +0000 Subject: [PATCH 7/9] single fix --- src/state/_cli/_tx/_single.py | 234 +++++++++++++++++++++++++++++----- 1 file changed, 199 insertions(+), 35 deletions(-) diff --git a/src/state/_cli/_tx/_single.py b/src/state/_cli/_tx/_single.py index 941c80de..53c39cf3 100644 --- a/src/state/_cli/_tx/_single.py +++ b/src/state/_cli/_tx/_single.py @@ -38,13 +38,16 @@ def add_arguments_single(parser: ap.ArgumentParser) -> None: parser.add_argument( "--eval-train-data", action="store_true", - help="Evaluate the model on the training data instead of the test data.", + help="Evaluate the model on the training data instead of the test data (ignored with --core-cells-path).", ) parser.add_argument( 
"--target-cell-type", type=str, - required=True, - help="Cell type to construct the base core cells for single perturbations.", + default=None, + help=( + "Optional cell type to construct the base core cells for single perturbations." + " Ignored when --core-cells-path is provided." + ), ) parser.add_argument( "--results-dir", @@ -66,6 +69,14 @@ def add_arguments_single(parser: ap.ArgumentParser) -> None: " constructing a new control batch." ), ) + parser.add_argument( + "--data-config", + type=str, + default=None, + help=( + "Path to a TOML data configuration file to override the data paths in the loaded data module." + ), + ) def run_tx_single(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> None: @@ -188,12 +199,48 @@ def _resolve_celltype_key(batch, module): "Could not resolve run directory. Checked: %s and %s" % (run_output_dir, inferred_run_dir) ) - data_module_path = os.path.join(run_output_dir, "data_module.torch") - if not os.path.exists(data_module_path): - raise FileNotFoundError(f"Could not find data module at {data_module_path}") - data_module = PerturbationDataModule.load_state(data_module_path) - data_module.setup(stage="test") - logger.info("Loaded data module from %s", data_module_path) + # Override data paths if --data-config is provided + # We need to modify the saved state before loading the data module + if getattr(args, "data_config", None) is not None: + import tempfile + + logger.info("Overriding data paths with config from %s", args.data_config) + + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}") + + # Load the raw state dict + state_dict = torch.load(data_module_path, weights_only=False) + + # Override the toml_config_path in the state dict + state_dict['toml_config_path'] = os.path.abspath(args.data_config) + logger.info("Updated toml_config_path to: %s", state_dict['toml_config_path']) + + # Save to a temporary location in a writable directory + with tempfile.NamedTemporaryFile(mode='wb', suffix='.torch', delete=False) as tmp: + temp_path = tmp.name + torch.save(state_dict, temp_path) + logger.info("Saved modified state to temp file: %s", temp_path) + + # Load the data module from the modified state + data_module = PerturbationDataModule.load_state(temp_path) + + # Clean up temp file + os.remove(temp_path) + else: + data_module_path = os.path.join(run_output_dir, "data_module.torch") + if not os.path.exists(data_module_path): + raise FileNotFoundError(f"Could not find data module at {data_module_path}") + data_module = PerturbationDataModule.load_state(data_module_path) + + # Only setup data module if we need to load cell data (not using preassembled core cells) + custom_core_cells_path = getattr(args, "core_cells_path", None) + if custom_core_cells_path is None: + data_module.setup(stage="test") + logger.info("Loaded data module from %s", data_module_path) + else: + logger.info("Loaded data module configuration from %s (skipping data setup for preassembled core cells)", data_module_path) pl.seed_everything(cfg["training"]["train_seed"]) @@ -207,7 +254,43 @@ def _resolve_celltype_key(batch, module): model_name = cfg["model"]["name"] model_kwargs = cfg["model"]["kwargs"] - var_dims = data_module.get_var_dims() + + # Get var_dims from checkpoint when using preassembled core cells to avoid loading data + if custom_core_cells_path is not None: + logger.info("Reading var_dims from checkpoint to avoid loading data") 
+ ckpt_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu') + var_dims = {} + + # Extract dimensions from model state dict + if 'state_dict' in ckpt_state: + state_dict = ckpt_state['state_dict'] + # Try to infer dimensions from model weights + for key in state_dict.keys(): + if 'pert_emb' in key or 'perturbation_encoder' in key: + if 'weight' in key: + var_dims['pert_dim'] = state_dict[key].shape[0] if state_dict[key].dim() > 1 else state_dict[key].shape[0] + if 'cell_encoder' in key or 'encoder.0' in key: + if 'weight' in key and 'input_dim' not in var_dims: + var_dims['input_dim'] = state_dict[key].shape[1] if state_dict[key].dim() > 1 else state_dict[key].shape[0] + if 'decoder' in key or 'cell_decoder' in key: + if 'weight' in key and 'output_dim' not in var_dims: + # Output dim is usually the last layer of decoder + if 'decoder.weight' in key or 'cell_decoder.weight' in key: + var_dims['output_dim'] = state_dict[key].shape[0] + + # Also check hyper_parameters in checkpoint + if 'hyper_parameters' in ckpt_state: + hp = ckpt_state['hyper_parameters'] + for dim_key in ['input_dim', 'output_dim', 'pert_dim', 'gene_dim', 'hvg_dim', 'hidden_dim']: + if dim_key in hp: + var_dims[dim_key] = hp[dim_key] + + logger.info("Extracted var_dims from checkpoint: %s", var_dims) + + if not var_dims or 'input_dim' not in var_dims or 'output_dim' not in var_dims or 'pert_dim' not in var_dims: + raise RuntimeError(f"Could not extract required dimensions from checkpoint. Got: {var_dims}") + else: + var_dims = data_module.get_var_dims() if model_name.lower() == "embedsum": from ...tx.models.embed_sum import EmbedSumPerturbationModel @@ -266,11 +349,17 @@ def _resolve_celltype_key(batch, module): data_module.batch_size = 1 target_celltype = getattr(args, "target_cell_type") + # custom_core_cells_path was already loaded earlier to skip data setup + if custom_core_cells_path is not None: + target_celltype = None def _create_filtered_loader(module): + use_preassembled_core = custom_core_cells_path is not None + eval_train = bool(getattr(args, "eval_train_data", False)) and not use_preassembled_core + base_loader = ( module.train_dataloader(test=True) - if args.eval_train_data + if eval_train else module.test_dataloader() ) @@ -279,6 +368,11 @@ def _create_filtered_loader(module): def _generator(): found_target = False for batch in base_loader: + if use_preassembled_core: + found_target = True + yield batch + continue + if target_celltype is None: found_target = True yield batch @@ -331,25 +425,53 @@ def _generator(): return _generator() - eval_loader = _create_filtered_loader(data_module) + # Only create eval_loader if we're not using preassembled core cells + if custom_core_cells_path is None: + eval_loader = _create_filtered_loader(data_module) + else: + eval_loader = None logger.info("Preparing core cells batch and enumerating perturbations...") - control_pert = data_module.get_control_pert() - unique_perts = [] - seen_perts = set() - for batch in eval_loader: - names = _to_list(batch.get("pert_name", [])) - for name in names: - name_value = name.item() if isinstance(name, torch.Tensor) else str(name) - if name_value not in seen_perts: - seen_perts.add(name_value) - unique_perts.append(name_value) - if not unique_perts: - raise RuntimeError("No perturbations found in the provided dataloader.") + # Get control_pert - when using preassembled core cells, get it from the data_module attributes + if custom_core_cells_path is not None: + control_pert = data_module.control_pert + 
logger.info("Using control_pert from data_module config: %s", control_pert) + else: + control_pert = data_module.get_control_pert() + + # Load pert_onehot_map early to get perturbations without loading data + pert_onehot_map = getattr(data_module, "pert_onehot_map", None) + if pert_onehot_map is None: + map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") + if os.path.exists(map_path): + pert_onehot_map = torch.load(map_path, weights_only=False) + logger.info("Loaded pert_onehot_map from %s", map_path) + else: + logger.warning("pert_onehot_map.pt not found at %s", map_path) + pert_onehot_map = {} + + # When using preassembled core cells, get perturbations from pert_onehot_map (no data loading!) + # Otherwise enumerate from the filtered eval_loader + if custom_core_cells_path is not None and pert_onehot_map: + # Use pert_onehot_map to get all perturbations without loading data + unique_perts = list(pert_onehot_map.keys()) + logger.info("Enumerating %d perturbations from pert_onehot_map (using preassembled core cells)", len(unique_perts)) + else: + # Enumerate from dataloader (original behavior) + unique_perts = [] + seen_perts = set() + for batch in eval_loader: + names = _to_list(batch.get("pert_name", [])) + for name in names: + name_value = name.item() if isinstance(name, torch.Tensor) else str(name) + if name_value not in seen_perts: + seen_perts.add(name_value) + unique_perts.append(name_value) + if not unique_perts: + raise RuntimeError("No perturbations found in the provided dataloader.") target_core_n_default = 256 - custom_core_cells_path = getattr(args, "core_cells_path", None) def _load_core_cells_from_path(path): if not os.path.exists(path): @@ -414,6 +536,40 @@ def _load_core_cells_from_path(path): if custom_core_cells_path: core_cells, target_core_n = _load_core_cells_from_path(os.path.abspath(custom_core_cells_path)) + + # Map latent_embedding to ctrl_cell_emb if needed for model compatibility + if 'latent_embedding' in core_cells and 'ctrl_cell_emb' not in core_cells: + core_cells['ctrl_cell_emb'] = core_cells['latent_embedding'] + logger.info("Mapped latent_embedding to ctrl_cell_emb for model compatibility") + + # Map plate to batch if needed for model compatibility + if 'plate' in core_cells and 'batch' not in core_cells: + plate_data = core_cells['plate'] + # Check if plate is a list of strings (tensor representations) + if isinstance(plate_data, list) and len(plate_data) > 0: + if isinstance(plate_data[0], str) and 'tensor' in plate_data[0].lower(): + # Parse string tensor representations + import re + parsed_tensors = [] + for plate_str in plate_data: + # Extract the numbers from the string representation + numbers = re.findall(r'[-+]?\d*\.?\d+', plate_str) + parsed_tensors.append(torch.tensor([float(n) for n in numbers])) + core_cells['batch'] = torch.stack(parsed_tensors) + logger.info("Parsed %d plate strings to batch tensor with shape %s", len(parsed_tensors), core_cells['batch'].shape) + elif isinstance(plate_data[0], torch.Tensor): + core_cells['batch'] = torch.stack(plate_data) + logger.info("Stacked %d plate tensors to batch tensor", len(plate_data)) + else: + # Assume it's batch indices + core_cells['batch'] = torch.tensor(plate_data) + logger.info("Converted plate list to batch tensor") + elif isinstance(plate_data, torch.Tensor): + core_cells['batch'] = plate_data + logger.info("Using plate tensor as batch") + else: + logger.warning("Could not convert plate to batch, type: %s", type(plate_data)) + logger.info( "Loaded custom core_cells batch with 
size %d from %s.", target_core_n, @@ -522,14 +678,10 @@ def _append_field(store, key, value): device = next(model.parameters()).device - pert_onehot_map = getattr(data_module, "pert_onehot_map", None) - if pert_onehot_map is None: - map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") - if os.path.exists(map_path): - pert_onehot_map = torch.load(map_path, weights_only=False) - else: - logger.warning("pert_onehot_map.pt not found at %s; proceeding with zero embeddings", map_path) - pert_onehot_map = {} + # pert_onehot_map was already loaded earlier for perturbation enumeration + if not pert_onehot_map: + logger.warning("No pert_onehot_map available; will use zero embeddings for perturbations") + pert_onehot_map = {} def _prepare_pert_emb(pert_name, length): vec = pert_onehot_map.get(pert_name) @@ -567,7 +719,11 @@ def _prepare_pert_emb(pert_name, length): batch_size = batch_preds["preds"].shape[0] metadata["pert_name"].extend(_normalize_field(batch_preds.get("pert_name", pert), batch_size, pert)) metadata["celltype_name"].extend( - _normalize_field(batch_preds.get("celltype_name"), batch_size, target_celltype) + _normalize_field( + batch_preds.get("celltype_name"), + batch_size, + target_celltype, + ) ) metadata["batch"].extend( [None if b is None else str(b) for b in _normalize_field(batch_preds.get("batch"), batch_size)] @@ -580,7 +736,15 @@ def _prepare_pert_emb(pert_name, length): ) batch_pred_np = batch_preds["preds"].detach().cpu().numpy().astype(np.float32) - batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + + # When using preassembled core cells, pert_cell_emb may not exist (double perturbation case) + # Use the input ctrl_cell_emb as the "real" baseline in that case + if batch_preds.get("pert_cell_emb") is not None: + batch_real_np = batch_preds["pert_cell_emb"].detach().cpu().numpy().astype(np.float32) + else: + # Use ctrl_cell_emb (the input embeddings) as the baseline + batch_real_np = batch["ctrl_cell_emb"].detach().cpu().numpy().astype(np.float32) + logger.info("Using ctrl_cell_emb as baseline (pert_cell_emb not available)") first_pass_preds[p_idx, :, :] = batch_pred_np first_pass_real[p_idx, :, :] = batch_real_np From 15bb8ed00e7ec79863eb0f5e8a435947c26e69f7 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Tue, 14 Oct 2025 23:48:10 +0000 Subject: [PATCH 8/9] single fix --- src/state/_cli/_tx/_single.py | 205 +++++++++++++++++++++++++++++++--- 1 file changed, 188 insertions(+), 17 deletions(-) diff --git a/src/state/_cli/_tx/_single.py b/src/state/_cli/_tx/_single.py index 53c39cf3..34f2976c 100644 --- a/src/state/_cli/_tx/_single.py +++ b/src/state/_cli/_tx/_single.py @@ -326,7 +326,7 @@ def _resolve_celltype_key(batch, module): **model_kwargs, } - for optional_key in ("gene_dim", "hvg_dim"): + for optional_key in ("gene_dim", "hvg_dim", "batch_dim"): optional_value = var_dims.get(optional_key) if optional_value is not None: model_init_kwargs[optional_key] = optional_value @@ -334,12 +334,115 @@ def _resolve_celltype_key(batch, module): if "hidden_dim" in var_dims and "hidden_dim" not in model_init_kwargs: model_init_kwargs["hidden_dim"] = var_dims["hidden_dim"] - model = ModelClass.load_from_checkpoint( - checkpoint_path, - **model_init_kwargs, - ) - model.eval() - logger.info("Model loaded successfully.") + # Add embed_key from data module if not already present + if "embed_key" not in model_init_kwargs: + embed_key = getattr(data_module, "embed_key", None) or "latent_embedding" + model_init_kwargs["embed_key"] = embed_key + 
+ # Add output_space from config if not already present + if "output_space" not in model_init_kwargs: + output_space = cfg["data"]["kwargs"].get("output_space", "embedding") + model_init_kwargs["output_space"] = output_space + + # Load checkpoint and handle dimension mismatches + checkpoint_state = torch.load(checkpoint_path, weights_only=False, map_location='cpu') + + # Handle pert_encoder dimension mismatch + pert_encoder_weight_key = "pert_encoder.0.weight" + if pert_encoder_weight_key in checkpoint_state["state_dict"]: + checkpoint_pert_dim = checkpoint_state["state_dict"][pert_encoder_weight_key].shape[1] + current_pert_dim = model_init_kwargs.get("pert_dim", var_dims["pert_dim"]) + + if checkpoint_pert_dim != current_pert_dim: + logger.warning( + "pert_encoder dimension mismatch: checkpoint has %d dims, current model needs %d dims. " + "Using first %d dimensions from checkpoint and zeroing the rest.", + checkpoint_pert_dim, current_pert_dim, min(checkpoint_pert_dim, current_pert_dim) + ) + + # Get checkpoint weights + checkpoint_weight = checkpoint_state["state_dict"][pert_encoder_weight_key] + + # Create new weight tensor with current model dimensions + new_weight = torch.zeros(checkpoint_weight.shape[0], current_pert_dim) + + # Copy available dimensions (use min of both dimensions) + min_dim = min(checkpoint_pert_dim, current_pert_dim) + new_weight[:, :min_dim] = checkpoint_weight[:, :min_dim] + + # Update the checkpoint state dict + checkpoint_state["state_dict"][pert_encoder_weight_key] = new_weight + + logger.info( + "Updated pert_encoder.0.weight: copied %d/%d input dimensions", + min_dim, current_pert_dim + ) + + # Update the model_init_kwargs to match what's actually in the checkpoint + # but keep our adjusted pert_dim + model_init_kwargs["pert_dim"] = current_pert_dim + + # Handle batch_encoder dimension mismatch + batch_encoder_weight_key = "batch_encoder.weight" + if batch_encoder_weight_key in checkpoint_state["state_dict"]: + checkpoint_batch_dim = checkpoint_state["state_dict"][batch_encoder_weight_key].shape[0] + current_batch_dim = model_init_kwargs.get("batch_dim") + + if current_batch_dim is not None and checkpoint_batch_dim != current_batch_dim: + logger.warning( + "batch_encoder dimension mismatch: checkpoint has %d batch categories, current model needs %d. 
" + "Using first %d categories from checkpoint and zeroing the rest.", + checkpoint_batch_dim, current_batch_dim, min(checkpoint_batch_dim, current_batch_dim) + ) + + # Get checkpoint weights + checkpoint_weight = checkpoint_state["state_dict"][batch_encoder_weight_key] + + # Create new weight tensor with current model dimensions + new_weight = torch.zeros(current_batch_dim, checkpoint_weight.shape[1]) + + # Copy available dimensions (use min of both dimensions) + min_dim = min(checkpoint_batch_dim, current_batch_dim) + new_weight[:min_dim, :] = checkpoint_weight[:min_dim, :] + + # Update the checkpoint state dict + checkpoint_state["state_dict"][batch_encoder_weight_key] = new_weight + + logger.info( + "Updated batch_encoder.weight: copied %d/%d batch categories", + min_dim, current_batch_dim + ) + elif current_batch_dim is None: + # If current model doesn't expect batch_encoder, use checkpoint's batch_dim + model_init_kwargs["batch_dim"] = checkpoint_batch_dim + logger.info("Using checkpoint's batch_dim: %d", checkpoint_batch_dim) + + # Extract additional parameters from checkpoint hyperparameters if available + if "hyper_parameters" in checkpoint_state: + hp = checkpoint_state["hyper_parameters"] + for param_key in ["batch_dim", "cell_sentence_len", "batch_encoder", "predict_mean"]: + if param_key in hp and param_key not in model_init_kwargs: + model_init_kwargs[param_key] = hp[param_key] + logger.debug("Added %s=%s from checkpoint hyperparameters", param_key, hp[param_key]) + + # Save the modified checkpoint to a temporary file and load from there + import tempfile + with tempfile.NamedTemporaryFile(mode='wb', suffix='.ckpt', delete=False) as tmp: + temp_checkpoint_path = tmp.name + torch.save(checkpoint_state, temp_checkpoint_path) + + try: + # Load model using Lightning's checkpoint loading mechanism + model = ModelClass.load_from_checkpoint( + temp_checkpoint_path, + **model_init_kwargs, + ) + model.eval() + logger.info("Model loaded successfully.") + finally: + # Clean up temporary file + import os + os.remove(temp_checkpoint_path) results_dir_default = ( args.results_dir @@ -451,14 +554,13 @@ def _generator(): logger.warning("pert_onehot_map.pt not found at %s", map_path) pert_onehot_map = {} - # When using preassembled core cells, get perturbations from pert_onehot_map (no data loading!) - # Otherwise enumerate from the filtered eval_loader - if custom_core_cells_path is not None and pert_onehot_map: - # Use pert_onehot_map to get all perturbations without loading data + # Use pert_onehot_map to get all perturbations if available, otherwise enumerate from dataloader + if pert_onehot_map: + # Use pert_onehot_map to get all perturbations (preferred approach) unique_perts = list(pert_onehot_map.keys()) - logger.info("Enumerating %d perturbations from pert_onehot_map (using preassembled core cells)", len(unique_perts)) + logger.info("Enumerating %d perturbations from pert_onehot_map", len(unique_perts)) else: - # Enumerate from dataloader (original behavior) + # Fallback: enumerate from dataloader (may be limited by cell type filtering) unique_perts = [] seen_perts = set() for batch in eval_loader: @@ -470,8 +572,14 @@ def _generator(): unique_perts.append(name_value) if not unique_perts: raise RuntimeError("No perturbations found in the provided dataloader.") + logger.warning( + "Using perturbations from filtered dataloader (%d found). " + "This may be limited by --target-cell-type filtering. 
" + "Consider using pert_onehot_map for complete perturbation coverage.", + len(unique_perts) + ) - target_core_n_default = 256 + target_core_n_default = 64 def _load_core_cells_from_path(path): if not os.path.exists(path): @@ -570,6 +678,40 @@ def _load_core_cells_from_path(path): else: logger.warning("Could not convert plate to batch, type: %s", type(plate_data)) + # Ensure batch tensor matches model's expected batch_dim + if 'batch' in core_cells and isinstance(core_cells['batch'], torch.Tensor): + current_batch_dim = core_cells['batch'].shape[-1] if core_cells['batch'].dim() > 1 else core_cells['batch'].max().item() + 1 + expected_batch_dim = model_init_kwargs.get('batch_dim') + + if expected_batch_dim is not None and current_batch_dim != expected_batch_dim: + logger.warning( + "Batch dimension mismatch: core cells have %d batch categories, model expects %d. " + "Adjusting batch tensor to match model expectations.", + current_batch_dim, expected_batch_dim + ) + + if core_cells['batch'].dim() > 1: + # One-hot encoded batch tensor + if current_batch_dim < expected_batch_dim: + # Pad with zeros + padding_size = expected_batch_dim - current_batch_dim + padding = torch.zeros(core_cells['batch'].shape[0], padding_size) + core_cells['batch'] = torch.cat([core_cells['batch'], padding], dim=1) + logger.info("Padded batch tensor from %d to %d dimensions", current_batch_dim, expected_batch_dim) + elif current_batch_dim > expected_batch_dim: + # Truncate + core_cells['batch'] = core_cells['batch'][:, :expected_batch_dim] + logger.info("Truncated batch tensor from %d to %d dimensions", current_batch_dim, expected_batch_dim) + else: + # Index-based batch tensor - convert to one-hot + batch_indices = core_cells['batch'].long() + # Ensure indices are within expected range + batch_indices = torch.clamp(batch_indices, 0, expected_batch_dim - 1) + # Convert to one-hot + core_cells['batch'] = torch.zeros(batch_indices.shape[0], expected_batch_dim) + core_cells['batch'].scatter_(1, batch_indices.unsqueeze(1), 1) + logger.info("Converted batch indices to one-hot encoding with %d categories", expected_batch_dim) + logger.info( "Loaded custom core_cells batch with size %d from %s.", target_core_n, @@ -687,11 +829,28 @@ def _prepare_pert_emb(pert_name, length): vec = pert_onehot_map.get(pert_name) if vec is None and control_pert in pert_onehot_map: vec = pert_onehot_map[control_pert] + + pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) + if pert_dim <= 0: + raise RuntimeError("pert_dim is undefined; cannot create perturbation embedding") + if vec is None: - pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) - if pert_dim <= 0: - raise RuntimeError("pert_dim is undefined; cannot create perturbation embedding") + # Create zero vector if perturbation not found vec = torch.zeros(pert_dim) + logger.debug("Created zero perturbation vector for %s (not found in pert_onehot_map)", pert_name) + else: + # Handle dimension mismatch between pert_onehot_map and model's pert_dim + if vec.shape[0] != pert_dim: + if vec.shape[0] > pert_dim: + # Truncate if vector is longer than expected (use first pert_dim dimensions) + logger.debug("Truncating perturbation vector for %s from %d to %d", pert_name, vec.shape[0], pert_dim) + vec = vec[:pert_dim] + else: + # Pad with zeros if vector is shorter than expected + logger.debug("Padding perturbation vector for %s from %d to %d", pert_name, vec.shape[0], pert_dim) + padding = torch.zeros(pert_dim - vec.shape[0]) + vec = torch.cat([vec, padding]) + 
return vec.float().unsqueeze(0).repeat(length, 1).to(device) with torch.no_grad(): @@ -854,6 +1013,17 @@ def _prepare_pert_emb(pert_name, length): ) return + # Check if we have enough perturbations for meaningful evaluation + total_unique_perts = first_pass_real_adata.obs[pert_col].nunique() + if total_unique_perts < 2: + logger.warning( + "Insufficient perturbations for evaluation (%d found). " + "Differential expression analysis requires at least 2 perturbations. " + "This may be due to --target-cell-type filtering or missing pert_onehot_map.", + total_unique_perts + ) + return + pdex_kwargs = dict(exp_post_agg=True, is_log1p=True) for celltype in ct_split_real.keys(): real_ct = ct_split_real[celltype] @@ -888,3 +1058,4 @@ def _prepare_pert_emb(pert_name, length): def save_core_cells_real_preds(args: ap.ArgumentParser) -> None: """Run only phase one of the pipeline and persist real core-cell embeddings per perturbation.""" return run_tx_single(args, phase_one_only=True) + From cad4bd3389cd299503111499a63dedc29ed774d8 Mon Sep 17 00:00:00 2001 From: Dhruv Gautam Date: Wed, 22 Oct 2025 19:40:48 +0000 Subject: [PATCH 9/9] pert map --- scripts/plot_combination_heatmaps.py | 679 +++++++++++++++++++++++++++ src/state/_cli/_tx/_single.py | 332 ++++++++++++- 2 files changed, 995 insertions(+), 16 deletions(-) create mode 100644 scripts/plot_combination_heatmaps.py diff --git a/scripts/plot_combination_heatmaps.py b/scripts/plot_combination_heatmaps.py new file mode 100644 index 00000000..eb270304 --- /dev/null +++ b/scripts/plot_combination_heatmaps.py @@ -0,0 +1,679 @@ +#!/usr/bin/env python +"""Generate per-max-drug heatmaps summarizing Tahoe combination experiments. + +The script expects the directory layout produced by `single_tahoe_combination_average.sh`: + + /data/new_heatmaps/comb_tahoe/ + / + max_drugs_/ + first_pass_preds.npy + core_cells_baseline.npy + first_pass_preds.h5ad + +For every max-drug slice it aggregates scalar summaries of the predicted embeddings +and their baseline-subtracted perturbation effects, assembling one heatmap per slice. +It also emits a companion heatmap that visualises only the perturbation effect +(`predicted - baseline`). +""" + +from __future__ import annotations + +import argparse +import json +import math +import os +from pathlib import Path +from typing import Dict, Iterable, List, Mapping, Optional, Sequence, Tuple + +import numpy as np +import pandas as pd +import seaborn as sns +from matplotlib import pyplot as plt +from tqdm import tqdm +from scipy.cluster import hierarchy + +try: + import anndata as ad # type: ignore +except ImportError as exc: # pragma: no cover - handled at runtime + raise SystemExit( + "The 'anndata' package is required to inspect perturbation metadata. " + "Install it or activate the suitable environment before running this script." + ) from exc + + +MAX_DRUGS_DEFAULT = (1, 2, 4, 8, 16, 32) +PERT_NAME_COLUMNS = ( + "pert_name", + "pert", + "perturbation", + "drugname_drugconc", + "drug_name", +) + + +def parse_args() -> argparse.Namespace: + parser = argparse.ArgumentParser(description=__doc__) + parser.add_argument( + "--input-root", + type=Path, + default=Path("/data/new_heatmaps/comb_tahoe"), + help="Root directory containing per-cell subdirectories (default: /data/new_heatmaps/comb_tahoe).", + ) + parser.add_argument( + "--output-dir", + type=Path, + default=None, + help="Directory to write figures. 
Defaults to /plots.", + ) + parser.add_argument( + "--max-drugs", + type=int, + nargs="*", + default=MAX_DRUGS_DEFAULT, + help="Max-drug bins to plot (default: %(default)s).", + ) + parser.add_argument( + "--cache-pert-order", + type=Path, + default=None, + help=( + "Optional JSON file to cache the perturbation ordering. " + "Speeds up repeated runs by avoiding re-reading .h5ad files." + ), + ) + parser.add_argument( + "--limit-perts", + type=int, + default=None, + help=( + "Optional cap on the number of perturbations to display (top-N by effect magnitude across cells). " + "Leave unset to plot all perturbations." + ), + ) + parser.add_argument( + "--fig-dpi", + type=int, + default=200, + help="Output DPI for saved figures (default: 200).", + ) + parser.add_argument( + "--cmap", + type=str, + default="viridis", + help="Matplotlib colormap for raw magnitude heatmaps (default: viridis).", + ) + parser.add_argument( + "--diff-cmap", + type=str, + default="coolwarm", + help="Matplotlib colormap for baseline-subtracted heatmaps (default: coolwarm).", + ) + parser.add_argument( + "--column-normalize", + action="store_true", + help=( + "Additionally render column-normalized perturbation-effect heatmaps " + "(each cell type scaled independently to [0, 1])." + ), + ) + parser.add_argument( + "--cluster-summary-dir", + type=Path, + default=None, + help=( + "Optional directory to write CSV summaries of early dendrogram splits per heatmap. " + "If omitted, no summaries are produced." + ), + ) + parser.add_argument( + "--cluster-summary-depth", + type=int, + default=2, + help=( + "Maximum depth (levels from root) of dendrogram splits to summarise when writing " + "cluster information CSVs (default: 2)." + ), + ) + args = parser.parse_args() + if args.limit_perts is not None and args.limit_perts <= 0: + parser.error("--limit-perts must be a positive integer when specified.") + return args + + +def discover_cell_dirs(root: Path, max_drugs: Sequence[int]) -> List[Path]: + if not root.exists(): + raise FileNotFoundError(f"Input root {root} does not exist.") + cell_dirs: List[Path] = [] + for path in sorted(root.iterdir()): + if not path.is_dir() or path.name.startswith("."): + continue + found = False + for bin_value in max_drugs: + candidate = path / f"max_drugs_{bin_value}" + if candidate.is_dir() and (candidate / "first_pass_preds.npy").exists(): + found = True + break + if found: + cell_dirs.append(path) + return cell_dirs + + +def load_perturbation_order( + cell_dirs: Sequence[Path], + max_drugs: Sequence[int], + cache_file: Optional[Path] = None, +) -> List[str]: + if cache_file and cache_file.exists(): + with cache_file.open("r", encoding="utf-8") as handle: + cached = json.load(handle) + if isinstance(cached, list) and all(isinstance(item, str) for item in cached): + return cached # type: ignore[return-value] + + candidates = [] + for cell_dir in cell_dirs: + for max_drug in max_drugs: + h5_path = cell_dir / f"max_drugs_{max_drug}" / "first_pass_preds.h5ad" + if h5_path.exists(): + candidates.append(h5_path) + if candidates: + break + + if not candidates: + raise RuntimeError( + "Could not locate any 'first_pass_preds.h5ad' files. " + "Ensure the combination experiments have completed." 
+ ) + + # Use the first available file to derive perturbation ordering + pert_names: List[str] = [] + for h5_file in tqdm(candidates, desc="Scanning AnnData files", unit="file"): + adata = ad.read_h5ad(h5_file, backed="r") + available_cols = list(adata.obs.columns) + target_col = None + for col in PERT_NAME_COLUMNS: + if col in adata.obs: + target_col = col + break + if target_col is None and available_cols: + # Heuristic fallback: use the first column with string-like data + for col in available_cols: + series = adata.obs[col] + sample = series.iloc[0] if len(series) else None + if isinstance(sample, (str, bytes)): + target_col = col + break + + if target_col is None: + adata.file.close() + continue + + obs_names = pd.Index(adata.obs[target_col].astype(str)) + pert_names = obs_names.drop_duplicates().tolist() + adata.file.close() + if pert_names: + break + + if not pert_names: + raise RuntimeError( + "Unable to extract perturbation names from available AnnData files. " + "Checked columns: %s" % ", ".join(PERT_NAME_COLUMNS) + ) + + if cache_file: + cache_file.parent.mkdir(parents=True, exist_ok=True) + with cache_file.open("w", encoding="utf-8") as handle: + json.dump(pert_names, handle) + + return pert_names + + +def compute_metrics_for_cell( + cell_dir: Path, + max_drug: int, + baseline_key: str = "ctrl_cell_emb", +) -> Optional[Tuple[np.ndarray, np.ndarray]]: + """Return (raw_norm, baseline_sub_norm) arrays for every perturbation.""" + base_path = cell_dir / f"max_drugs_{max_drug}" + preds_path = base_path / "first_pass_preds.npy" + baseline_path = base_path / "core_cells_baseline.npy" + + if not preds_path.exists() or not baseline_path.exists(): + return None + + preds = np.load(preds_path, mmap_mode="r") + # baseline file stores dictionary + baseline_payload = np.load(baseline_path, allow_pickle=True).item() + if baseline_key not in baseline_payload: + raise KeyError( + f"Baseline dictionary at {baseline_path} does not contain key '{baseline_key}'. " + f"Available keys: {list(baseline_payload.keys())}" + ) + baseline = np.asarray(baseline_payload[baseline_key], dtype=np.float32) + if baseline.ndim != 2: + raise ValueError( + f"Expected baseline tensor to be 2D (, ); got shape {baseline.shape}." + ) + + num_perts = preds.shape[0] + raw_norm = np.empty(num_perts, dtype=np.float32) + effect_norm = np.empty(num_perts, dtype=np.float32) + + # Precompute baseline norms once + for idx in range(num_perts): + pert_block = np.asarray(preds[idx], dtype=np.float32) + if pert_block.shape != baseline.shape: + raise ValueError( + f"Perturbation block shape {pert_block.shape} does not match baseline shape {baseline.shape} " + f"for {cell_dir.name} max_drugs_{max_drug} index {idx}." 
+ ) + raw_norm[idx] = np.linalg.norm(pert_block, axis=1).mean() + diff = pert_block - baseline + effect_norm[idx] = np.linalg.norm(diff, axis=1).mean() + + return raw_norm, effect_norm + + +def column_normalize(matrix: np.ndarray) -> np.ndarray: + """Scale each column to [0, 1] using per-column min/max (ignoring NaNs).""" + normalized = matrix.copy() + with np.errstate(invalid="ignore"): + col_min = np.nanmin(normalized, axis=0) + col_max = np.nanmax(normalized, axis=0) + denom = col_max - col_min + valid = denom > 1e-12 + if np.any(valid): + normalized[:, valid] = (normalized[:, valid] - col_min[valid]) / denom[valid] + if np.any(~valid): + normalized[:, ~valid] = 0.0 + normalized = np.clip(normalized, 0.0, 1.0) + return normalized + + +def assemble_heatmap_data( + cell_dirs: Sequence[Path], + max_drugs: Sequence[int], + pert_order: Sequence[str], + normalize_columns: bool = False, +) -> Dict[int, Dict[str, np.ndarray]]: + """Return nested dict: {max_drug: {'raw': matrix, 'effect': matrix}}.""" + num_perts = len(pert_order) + cell_names = [cell_dir.name for cell_dir in cell_dirs] + name_to_index = {name: idx for idx, name in enumerate(cell_names)} + + heatmaps: Dict[int, Dict[str, np.ndarray]] = {} + for max_drug in tqdm(max_drugs, desc="Aggregating heatmaps", unit="bin"): + raw_matrix = np.full((num_perts, len(cell_dirs)), np.nan, dtype=np.float32) + effect_matrix = np.full_like(raw_matrix, np.nan) + + for cell_dir in tqdm( + cell_dirs, + desc=f"Cells for max_drugs={max_drug}", + unit="cell", + leave=False, + ): + metrics = compute_metrics_for_cell(cell_dir, max_drug) + if metrics is None: + continue + raw_norm, effect_norm = metrics + if raw_norm.shape[0] != num_perts: + raise ValueError( + f"Perturbation count mismatch for {cell_dir.name} max_drugs_{max_drug}: " + f"{raw_norm.shape[0]} vs expected {num_perts}." 
+ ) + col_idx = name_to_index[cell_dir.name] + raw_matrix[:, col_idx] = raw_norm + effect_matrix[:, col_idx] = effect_norm + + payload: Dict[str, np.ndarray] = {"raw": raw_matrix, "effect": effect_matrix} + if normalize_columns: + payload["effect_colnorm"] = column_normalize(effect_matrix) + heatmaps[max_drug] = payload + + return heatmaps + + +def maybe_limit_perturbations( + matrices: Mapping[int, Mapping[str, np.ndarray]], + pert_names: Sequence[str], + limit: Optional[int], +) -> Tuple[Sequence[str], Dict[int, Dict[str, np.ndarray]]]: + if limit is None or limit >= len(pert_names): + return pert_names, matrices # type: ignore[return-value] + + # Rank perturbations by global effect magnitude (mean across cells and max_drug bins) + combined_effects = [] + for max_drug, payload in matrices.items(): + effect_matrix = payload["effect"] + combined_effects.append(effect_matrix) + stacked = np.stack(combined_effects, axis=0) # shape (num_bins, num_perts, num_cells) + global_scores = np.nanmean(stacked, axis=(0, 2)) + top_indices = np.argsort(global_scores)[::-1][:limit] + top_indices = np.sort(top_indices) + + reduced_names = [pert_names[idx] for idx in top_indices] + reduced: Dict[int, Dict[str, np.ndarray]] = {} + for max_drug, payload in matrices.items(): + reduced[max_drug] = { + key: value[top_indices, :] + for key, value in payload.items() + } + + return reduced_names, reduced + + +def _collect_leaf_indices(node: hierarchy.ClusterNode) -> List[int]: + """Return sorted leaf indices beneath a SciPy cluster node.""" + if node.is_leaf(): + return [node.id] + leaves: List[int] = [] + if node.left is not None: + leaves.extend(_collect_leaf_indices(node.left)) + if node.right is not None: + leaves.extend(_collect_leaf_indices(node.right)) + return sorted(leaves) + + +def _summarize_branch( + node: hierarchy.ClusterNode, + labels: Sequence[str], + data: np.ndarray, + axis: str, + split_depth: int, +) -> Dict[str, object]: + indices = _collect_leaf_indices(node) + members = [labels[idx] for idx in indices] + if axis == "perturbation": + subset = data[np.ix_(indices, range(data.shape[1]))] + column_means = np.nanmean(subset, axis=0) + mean_range = float(np.nanmax(column_means) - np.nanmin(column_means)) + axis_range = mean_range + axis_min = float(np.nanmin(column_means)) if column_means.size else np.nan + axis_max = float(np.nanmax(column_means)) if column_means.size else np.nan + else: + subset = data[np.ix_(range(data.shape[0]), indices)] + row_means = np.nanmean(subset, axis=1) + mean_range = float(np.nanmax(row_means) - np.nanmin(row_means)) + axis_range = mean_range + axis_min = float(np.nanmin(row_means)) if row_means.size else np.nan + axis_max = float(np.nanmax(row_means)) if row_means.size else np.nan + + summary: Dict[str, object] = { + "axis": axis, + "split_depth": split_depth, + "node_height": float(node.dist), + "member_count": len(indices), + "mean_effect": float(np.nanmean(subset)), + "effect_std": float(np.nanstd(subset)), + "axis_mean_min": axis_min, + "axis_mean_max": axis_max, + "axis_mean_range": axis_range, + "members": ";".join(members), + } + return summary + + +def summarize_dendrogram( + linkage: np.ndarray, + labels: Sequence[str], + data: np.ndarray, + axis: str, + max_depth: int, +) -> List[Dict[str, object]]: + """Extract summaries for early dendrogram splits up to `max_depth` levels.""" + if linkage.size == 0 or len(labels) <= 1: + return [] + + root = hierarchy.to_tree(linkage, rd=False) + summaries: List[Dict[str, object]] = [] + queue: 
List[Tuple[hierarchy.ClusterNode, int]] = [(root, 0)] + + while queue: + node, depth = queue.pop(0) + if node.is_leaf() or depth >= max_depth: + continue + # Record both child branches for this split + for child in (node.left, node.right): + if child is None: + continue + summaries.append( + _summarize_branch(child, labels, data, axis=axis, split_depth=depth + 1) + ) + if not child.is_leaf(): + queue.append((child, depth + 1)) + + # Order summaries by descending split depth then node height (largest first) + summaries.sort(key=lambda item: (-item["split_depth"], -item["node_height"])) + return summaries + + +def write_cluster_summary( + records: Sequence[Dict[str, object]], + output_path: Path, +) -> None: + if not records: + return + df = pd.DataFrame(records) + df["axis_rank"] = df["axis"].map({"cell": 0, "perturbation": 1}).fillna(99) + df.sort_values( + by=["axis_rank", "axis", "split_depth", "node_height"], + ascending=[True, True, False, False], + inplace=True, + ) + df.drop(columns=["axis_rank"], inplace=True) + output_path.parent.mkdir(parents=True, exist_ok=True) + df.to_csv(output_path, index=False) + + +def plot_heatmap( + data: np.ndarray, + pert_names: Sequence[str], + cell_names: Sequence[str], + title: str, + cmap: str, + output_path: Path, + dpi: int = 200, + vmin: Optional[float] = None, + vmax: Optional[float] = None, + summary_depth: Optional[int] = None, +) -> Optional[List[Dict[str, object]]]: + """Render and save a single heatmap.""" + df = pd.DataFrame(data, index=pert_names, columns=cell_names) + height = max(6.0, min(0.02 * len(pert_names), 30.0)) + width = max(8.0, min(0.35 * len(cell_names), 40.0)) + cluster_grid = sns.clustermap( + df, + cmap=cmap, + vmin=vmin, + vmax=vmax, + figsize=(width, height), + cbar_kws={"label": "Mean L2 norm"}, + ) + + if len(cell_names) > 1: + reordered = cluster_grid.dendrogram_col.reordered_ind + col_means = df.iloc[:, reordered].mean(axis=0, skipna=True) + if col_means.iloc[0] > col_means.iloc[-1]: + cluster_grid.ax_heatmap.invert_xaxis() + cluster_grid.ax_col_dendrogram.invert_xaxis() + + cluster_grid.ax_heatmap.set_xlabel("Cell Type in Tahoe", fontsize=24) + cluster_grid.ax_heatmap.set_ylabel("Genetic Perturbation Reconstructed Through Drug Combinations", fontsize=24) + cluster_grid.ax_heatmap.set_title(title) + cluster_grid.ax_heatmap.tick_params(axis="x", labelrotation=90) + cluster_grid.fig.tight_layout() + output_path.parent.mkdir(parents=True, exist_ok=True) + cluster_grid.savefig(output_path, dpi=dpi) + plt.close(cluster_grid.fig) + + if summary_depth is not None and summary_depth > 0: + summaries: List[Dict[str, object]] = [] + row_linkage = cluster_grid.dendrogram_row.linkage + col_linkage = cluster_grid.dendrogram_col.linkage + if row_linkage is not None and len(pert_names) > 1: + summaries.extend( + summarize_dendrogram( + row_linkage, list(pert_names), df.values, axis="perturbation", max_depth=summary_depth + ) + ) + if col_linkage is not None and len(cell_names) > 1: + summaries.extend( + summarize_dendrogram( + col_linkage, list(cell_names), df.values, axis="cell", max_depth=summary_depth + ) + ) + return summaries + return None + + +def main() -> None: + args = parse_args() + cell_dirs = discover_cell_dirs(args.input_root, args.max_drugs) + if not cell_dirs: + raise RuntimeError(f"No cell directories found under {args.input_root}") + + output_dir = args.output_dir or (args.input_root / "plots") + output_dir.mkdir(parents=True, exist_ok=True) + summary_dir = args.cluster_summary_dir + summary_depth: Optional[int] = 
None + if summary_dir is not None: + summary_dir.mkdir(parents=True, exist_ok=True) + summary_depth = max(0, args.cluster_summary_depth) + + pert_order = load_perturbation_order( + cell_dirs, + args.max_drugs, + cache_file=args.cache_pert_order, + ) + + heatmap_payload = assemble_heatmap_data( + cell_dirs, + args.max_drugs, + pert_order, + normalize_columns=args.column_normalize, + ) + + cell_names = [cell_dir.name for cell_dir in cell_dirs] + pert_names, reduced_payload = maybe_limit_perturbations( + heatmap_payload, + pert_order, + args.limit_perts, + ) + + # Determine shared color limits for comparability + effect_values = [ + payload["effect"] for payload in reduced_payload.values() + ] + combined_effect = np.concatenate([arr.flatten() for arr in effect_values]) + combined_effect = combined_effect[~np.isnan(combined_effect)] + effect_max = float(combined_effect.max()) if combined_effect.size else None + + raw_values = [ + payload["raw"] for payload in reduced_payload.values() + ] + combined_raw = np.concatenate([arr.flatten() for arr in raw_values]) + combined_raw = combined_raw[~np.isnan(combined_raw)] + raw_max = float(combined_raw.max()) if combined_raw.size else None + + for max_drug, matrices in sorted(reduced_payload.items()): + if raw_max: + raw_path = output_dir / f"heatmap_max_drugs_{max_drug}_raw.png" + plot_heatmap( + matrices["raw"], + pert_names, + cell_names, + title=f"Predicted embedding norm – max_drugs={max_drug}", + cmap=args.cmap, + output_path=raw_path, + dpi=args.fig_dpi, + vmin=0.0, + vmax=raw_max, + ) + effect_path = output_dir / f"heatmap_max_drugs_{max_drug}_effect.png" + effect_summaries = plot_heatmap( + matrices["effect"], + pert_names, + cell_names, + title=f"Perturbation effect (predicted - baseline) – max_drugs={max_drug}", + cmap=args.diff_cmap, + output_path=effect_path, + dpi=args.fig_dpi, + vmin=0.0, + vmax=effect_max, + summary_depth=summary_depth, + ) + if summary_dir is not None and effect_summaries: + summary_path = summary_dir / f"heatmap_max_drugs_{max_drug}_effect_clusters.csv" + write_cluster_summary(effect_summaries, summary_path) + if args.column_normalize and "effect_colnorm" in matrices: + effect_norm_path = output_dir / f"heatmap_max_drugs_{max_drug}_effect_colnorm.png" + plot_heatmap( + matrices["effect_colnorm"], + pert_names, + cell_names, + title=( + "Perturbation effect (baseline subtracted, column-normalized) " + f"– max_drugs={max_drug}" + ), + cmap=args.diff_cmap, + output_path=effect_norm_path, + dpi=args.fig_dpi, + vmin=0.0, + vmax=1.0, + ) + + # Combined overview for perturbation effects across all bins + combined_fig = output_dir / "perturbation_effect_overview.png" + stacked_effects = [] + stacked_columns = [] + for max_drug, matrices in sorted(reduced_payload.items()): + stacked_effects.append(matrices["effect"]) + stacked_columns.extend( + [f"{cell}-max{max_drug}" for cell in cell_names] + ) + if stacked_effects: + merged = np.concatenate(stacked_effects, axis=1) + overview_summaries = plot_heatmap( + merged, + pert_names, + stacked_columns, + title="Perturbation effect overview (baseline subtracted)", + cmap=args.diff_cmap, + output_path=combined_fig, + dpi=args.fig_dpi, + vmin=0.0, + vmax=effect_max, + summary_depth=summary_depth, + ) + if summary_dir is not None and overview_summaries: + overview_path = summary_dir / "perturbation_effect_overview_clusters.csv" + write_cluster_summary(overview_summaries, overview_path) + if args.column_normalize: + stacked_norm = [] + stacked_norm_cols = [] + for max_drug, matrices in 
sorted(reduced_payload.items()): + if "effect_colnorm" not in matrices: + continue + stacked_norm.append(matrices["effect_colnorm"]) + stacked_norm_cols.extend( + [f"{cell}-max{max_drug}" for cell in cell_names] + ) + if stacked_norm: + merged_norm = np.concatenate(stacked_norm, axis=1) + combined_norm_fig = output_dir / "perturbation_effect_overview_colnorm.png" + plot_heatmap( + merged_norm, + pert_names, + stacked_norm_cols, + title="Perturbation effect overview (column-normalized)", + cmap=args.diff_cmap, + output_path=combined_norm_fig, + dpi=args.fig_dpi, + vmin=0.0, + vmax=1.0, + ) + + print(f"Saved heatmaps to {output_dir}") + + +if __name__ == "__main__": + main() diff --git a/src/state/_cli/_tx/_single.py b/src/state/_cli/_tx/_single.py index 34f2976c..328dc3fa 100644 --- a/src/state/_cli/_tx/_single.py +++ b/src/state/_cli/_tx/_single.py @@ -77,12 +77,22 @@ def add_arguments_single(parser: ap.ArgumentParser) -> None: "Path to a TOML data configuration file to override the data paths in the loaded data module." ), ) + parser.add_argument( + "--perturbation-npy", + type=str, + default=None, + help=( + "Optional path to a NumPy .npy/.npz file (or convertible CSV specification) that maps perturbation" + " names to explicit encoder vectors. When provided, these vectors override the default pert_onehot_map." + ), + ) def run_tx_single(args: ap.ArgumentParser, *, phase_one_only: bool = False) -> None: import logging import os import sys + import re import anndata import lightning.pytorch as pl @@ -173,6 +183,226 @@ def _resolve_celltype_key(batch, module): return key, ordered_candidates return None, ordered_candidates + def _ensure_tensor_vector(value, *, expected_dim=None, context="perturbation"): + if isinstance(value, torch.Tensor): + tensor = value.detach().clone().to(dtype=torch.float32, device="cpu") + else: + try: + tensor = torch.as_tensor(value, dtype=torch.float32) + except Exception as exc: + raise TypeError(f"Could not convert {context} value to tensor: {exc}") from exc + if tensor.ndim > 1: + tensor = tensor.reshape(-1) + if tensor.ndim != 1: + raise ValueError( + f"Expected a 1D tensor for {context}, but received shape {tuple(tensor.shape)}." + ) + if expected_dim is not None and tensor.numel() != expected_dim: + raise ValueError( + f"Dimension mismatch for {context}: expected {expected_dim}, received {tensor.numel()}." 
+ ) + return tensor.contiguous() + + def _normalize_pert_map(raw_map, *, expected_dim=None, label="perturbation map"): + if raw_map is None: + return {} + normalized = {} + for raw_key, raw_value in raw_map.items(): + key = str(raw_key) + try: + tensor = _ensure_tensor_vector( + raw_value, + expected_dim=expected_dim, + context=f"{label}:{key}", + ) + except Exception as exc: + raise ValueError(f"Failed to normalize {label} entry '{key}': {exc}") from exc + normalized[key] = tensor + return normalized + + def _infer_combination_csv(path): + directory, filename = os.path.split(path) + match = re.search(r"max[_-]?drugs[_-]?(\d+)", filename, flags=re.IGNORECASE) + if not match: + return None + suffix = match.group(1) + candidates = [ + os.path.join(directory, f"average_to_genetic_reconstruction_maxdrugs{suffix}.csv"), + os.path.join(directory, f"average_to_genetic_reconstruction_maxdrugs{suffix}.CSV"), + ] + for candidate in candidates: + if os.path.exists(candidate): + return candidate + return None + + def _parse_combination_spec(spec): + if spec is None: + return {} + if isinstance(spec, (float, np.floating)) and np.isnan(spec): + return {} + if not isinstance(spec, str): + spec = str(spec) + spec = spec.strip() + if not spec: + return {} + components = {} + for part in spec.split(";"): + piece = part.strip() + if not piece: + continue + if ":" not in piece: + raise ValueError(f"Invalid combination component '{piece}' (missing weight separator).") + combo_key, weight_str = piece.rsplit(":", 1) + combo_key = combo_key.strip() + if not combo_key: + raise ValueError(f"Invalid combination component '{piece}' (empty key).") + try: + weight = float(weight_str.strip()) + except Exception as exc: + raise ValueError(f"Invalid weight value '{weight_str}' in component '{piece}': {exc}") from exc + components[combo_key] = components.get(combo_key, 0.0) + weight + return components + + def _build_map_from_combination_table(csv_path, base_map, expected_dim): + if expected_dim is None: + raise ValueError( + "Cannot construct perturbation vectors from combination table without a known pert_dim." + ) + if not base_map: + raise ValueError( + "Base perturbation map is empty; cannot expand combinations without reference encodings." + ) + df = pd.read_csv(csv_path) + constructed = {} + missing_components = set() + zero_combo_genes = [] + for row in df.itertuples(index=False): + gene_name = str(getattr(row, "gene")) + combo_spec = getattr(row, "combination", "") + components = _parse_combination_spec(combo_spec) + if not components: + zero_combo_genes.append(gene_name) + continue + vector = torch.zeros(expected_dim, dtype=torch.float32) + for combo_key, weight in components.items(): + reference = base_map.get(combo_key) + if reference is None: + missing_components.add(combo_key) + continue + vector = vector + reference.to(dtype=torch.float32) * float(weight) + constructed[gene_name] = vector + if missing_components: + preview = ", ".join(sorted(missing_components)[:10]) + raise KeyError( + "Encountered %d combination components that were not present in the base perturbation map. 
" + "Examples: %s" % (len(missing_components), preview) + ) + if zero_combo_genes: + logger.warning( + "Skipped %d genes with empty combination specifications when building perturbation map from %s.", + len(zero_combo_genes), + csv_path, + ) + logger.info( + "Constructed %d custom perturbation vectors from %s.", + len(constructed), + csv_path, + ) + return constructed + + def _serialize_pert_map_for_save(pert_map): + serialized = {} + for key, tensor in pert_map.items(): + serialized[key] = tensor.detach().cpu().numpy() + return serialized + + def _load_numpy_mapping(path): + loaded = np.load(path, allow_pickle=True) + if isinstance(loaded, np.lib.npyio.NpzFile): + return {k: loaded[k] for k in loaded.files} + if isinstance(loaded, np.ndarray): + if loaded.dtype == object: + if loaded.shape == (): + return loaded.item() + mapping = {} + for entry in loaded.tolist(): + if isinstance(entry, (tuple, list)) and len(entry) == 2: + mapping[str(entry[0])] = entry[1] + if mapping: + return mapping + raise ValueError( + f"Unsupported array format when loading perturbation map from {path}." + ) + if isinstance(loaded, dict): + return loaded + raise TypeError(f"Unsupported data type {type(loaded).__name__} in {path}.") + + def _load_custom_perturbation_map( + path, + base_map, + expected_dim, + ): + if path is None: + return None, None + + resolved_path = path + data = None + source_description = None + extension = os.path.splitext(resolved_path)[1].lower() + + if os.path.exists(resolved_path) and extension in {".npy", ".npz"}: + data = _load_numpy_mapping(resolved_path) + source_description = resolved_path + elif os.path.exists(resolved_path) and extension == ".csv": + data = _build_map_from_combination_table(resolved_path, base_map, expected_dim) + source_description = resolved_path + else: + candidate_csv = _infer_combination_csv(resolved_path) + if candidate_csv and os.path.exists(candidate_csv): + data = _build_map_from_combination_table(candidate_csv, base_map, expected_dim) + source_description = candidate_csv + if extension in {".npy", ".npz"} or extension == "": + try: + serializable = _serialize_pert_map_for_save(data) + if not os.path.exists(resolved_path): + import tempfile + with tempfile.NamedTemporaryFile(suffix=".npy", delete=False) as tmp_file: + tmp_path = tmp_file.name + try: + np.save(tmp_path, serializable, allow_pickle=True) + os.replace(tmp_path, resolved_path) + logger.info("Saved converted perturbation map to %s", resolved_path) + finally: + if os.path.exists(tmp_path): + try: + os.remove(tmp_path) + except OSError: + pass + else: + logger.debug("Perturbation map already exists at %s; skipping save", resolved_path) + except Exception as exc: + logger.warning( + "Failed to save converted perturbation map to %s: %s", + resolved_path, + exc, + ) + else: + raise FileNotFoundError( + f"Custom perturbation specification {resolved_path} not found. " + "Provide an existing .npy/.npz file or a CSV with combination specifications." 
+ ) + + if isinstance(data, dict): + normalized = _normalize_pert_map( + data, + expected_dim=expected_dim, + label="custom perturbation map", + ) + else: + normalized = data + + return normalized, source_description + torch.multiprocessing.set_sharing_strategy("file_system") config_path = os.path.join(args.output_dir, "config.yaml") @@ -543,24 +773,87 @@ def _generator(): else: control_pert = data_module.get_control_pert() - # Load pert_onehot_map early to get perturbations without loading data - pert_onehot_map = getattr(data_module, "pert_onehot_map", None) - if pert_onehot_map is None: + custom_perturbation_path = getattr(args, "perturbation_npy", None) + + base_perturbation_map_raw = getattr(data_module, "pert_onehot_map", None) + if base_perturbation_map_raw is None: map_path = os.path.join(run_output_dir, "pert_onehot_map.pt") if os.path.exists(map_path): - pert_onehot_map = torch.load(map_path, weights_only=False) - logger.info("Loaded pert_onehot_map from %s", map_path) + base_perturbation_map_raw = torch.load(map_path, weights_only=False) + logger.info("Loaded base pert_onehot_map from %s", map_path) else: logger.warning("pert_onehot_map.pt not found at %s", map_path) - pert_onehot_map = {} - + base_perturbation_map_raw = {} + + base_perturbation_map = _normalize_pert_map( + base_perturbation_map_raw, + label="base perturbation map", + ) + if base_perturbation_map: + logger.info("Base perturbation map contains %d entries.", len(base_perturbation_map)) + else: + logger.warning("Base perturbation map is empty; custom perturbations may be required.") + + base_expected_dim = None + if base_perturbation_map: + base_expected_dim = next(iter(base_perturbation_map.values())).numel() + + expected_pert_dim = base_expected_dim + var_dims_preview = None + if custom_perturbation_path or base_expected_dim is None: + try: + var_dims_preview = data_module.get_var_dims() + except Exception as exc: + logger.debug("Unable to preview var_dims prior to perturbation setup: %s", exc) + var_dims_preview = None + else: + if isinstance(var_dims_preview, dict) and var_dims_preview.get("pert_dim") is not None: + expected_pert_dim = var_dims_preview["pert_dim"] + + custom_perturbation_map = {} + custom_map_source = None + if custom_perturbation_path: + logger.info("Attempting to load custom perturbation vectors from %s", custom_perturbation_path) + custom_perturbation_map, custom_map_source = _load_custom_perturbation_map( + custom_perturbation_path, + base_perturbation_map, + expected_pert_dim, + ) + if not custom_perturbation_map: + raise ValueError( + f"No perturbation vectors were loaded from {custom_map_source or custom_perturbation_path}." 
+ ) + expected_pert_dim = next(iter(custom_perturbation_map.values())).numel() + if control_pert and control_pert not in custom_perturbation_map: + base_control = base_perturbation_map.get(control_pert) + if base_control is not None: + custom_perturbation_map[control_pert] = base_control.clone() + logger.info( + "Added control perturbation '%s' to custom perturbation map using base encoding.", + control_pert, + ) + logger.info( + "Using custom perturbation map with %d entries (source: %s).", + len(custom_perturbation_map), + custom_map_source or custom_perturbation_path, + ) + + if custom_perturbation_map: + pert_onehot_map = custom_perturbation_map + fallback_perturbation_map = base_perturbation_map + else: + pert_onehot_map = base_perturbation_map + fallback_perturbation_map = None + # Use pert_onehot_map to get all perturbations if available, otherwise enumerate from dataloader if pert_onehot_map: - # Use pert_onehot_map to get all perturbations (preferred approach) unique_perts = list(pert_onehot_map.keys()) - logger.info("Enumerating %d perturbations from pert_onehot_map", len(unique_perts)) + logger.info( + "Enumerating %d perturbations from %s perturbation map", + len(unique_perts), + "custom" if custom_perturbation_map else "base", + ) else: - # Fallback: enumerate from dataloader (may be limited by cell type filtering) unique_perts = [] seen_perts = set() for batch in eval_loader: @@ -575,8 +868,8 @@ def _generator(): logger.warning( "Using perturbations from filtered dataloader (%d found). " "This may be limited by --target-cell-type filtering. " - "Consider using pert_onehot_map for complete perturbation coverage.", - len(unique_perts) + "Consider supplying a perturbation map for complete perturbation coverage.", + len(unique_perts), ) target_core_n_default = 64 @@ -822,13 +1115,20 @@ def _append_field(store, key, value): # pert_onehot_map was already loaded earlier for perturbation enumeration if not pert_onehot_map: - logger.warning("No pert_onehot_map available; will use zero embeddings for perturbations") + logger.warning("No perturbation map available; will use zero embeddings for perturbations") pert_onehot_map = {} def _prepare_pert_emb(pert_name, length): vec = pert_onehot_map.get(pert_name) - if vec is None and control_pert in pert_onehot_map: - vec = pert_onehot_map[control_pert] + if vec is None and fallback_perturbation_map: + vec = fallback_perturbation_map.get(pert_name) + if vec is not None: + logger.debug("Using fallback perturbation vector for %s from base map.", pert_name) + if vec is None and control_pert: + if control_pert in pert_onehot_map: + vec = pert_onehot_map[control_pert] + elif fallback_perturbation_map and control_pert in fallback_perturbation_map: + vec = fallback_perturbation_map[control_pert] pert_dim = getattr(model, "pert_dim", var_dims.get("pert_dim", 0)) if pert_dim <= 0: @@ -839,6 +1139,7 @@ def _prepare_pert_emb(pert_name, length): vec = torch.zeros(pert_dim) logger.debug("Created zero perturbation vector for %s (not found in pert_onehot_map)", pert_name) else: + vec = vec.clone().detach().cpu() # Handle dimension mismatch between pert_onehot_map and model's pert_dim if vec.shape[0] != pert_dim: if vec.shape[0] > pert_dim: @@ -1058,4 +1359,3 @@ def _prepare_pert_emb(pert_name, length): def save_core_cells_real_preds(args: ap.ArgumentParser) -> None: """Run only phase one of the pipeline and persist real core-cell embeddings per perturbation.""" return run_tx_single(args, phase_one_only=True) -
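A usage note on the --perturbation-npy option introduced in the last patch: _load_numpy_mapping accepts either an .npz archive with one array per perturbation name or a dict pickled via np.save, and _normalize_pert_map then converts each value to a float32 vector. The sketch below shows one way to produce a compatible file; the perturbation names, the control label, and the 512-dimensional vector size are illustrative assumptions rather than values taken from these patches.

    import numpy as np

    # Hypothetical vectors; real entries should have length equal to the model's pert_dim.
    pert_map = {
        "CONTROL": np.zeros(512, dtype=np.float32),       # assumed control perturbation name
        "pertA": np.random.rand(512).astype(np.float32),
        "pertB": np.random.rand(512).astype(np.float32),
    }

    # Either format is readable by _load_numpy_mapping:
    np.savez("custom_pert_map.npz", **pert_map)                    # one array per key
    np.save("custom_pert_map.npy", pert_map, allow_pickle=True)    # 0-d object array holding the dict

The resulting path is then passed via --perturbation-npy; when the expected pert_dim can be determined from the base map or var_dims, each vector's length is validated at load time, so the sizes should match the trained model.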