56 changes: 46 additions & 10 deletions src/state/_cli/_tx/_infer.py
@@ -95,7 +95,7 @@ def run_tx_infer(args: argparse.Namespace):
import yaml
from tqdm import tqdm

from ...tx.models.state_transition import StateTransitionPerturbationModel
from ...tx.utils import get_lightning_module

# -----------------------
# Helpers
@@ -157,6 +157,7 @@ def prepare_batch(
batch_indices: Optional[torch.Tensor],
pert_names: List[str],
device: torch.device,
treated_np: Optional[np.ndarray] = None,
) -> Dict[str, torch.Tensor | List[str]]:
"""
Construct a model batch with variable-length sentence (B=1, S=T, ...).
@@ -170,6 +171,8 @@
}
if batch_indices is not None:
batch["batch"] = batch_indices.to(device) # [T]
if treated_np is not None:
batch["pert_cell_emb"] = torch.tensor(treated_np, dtype=torch.float32, device=device)
return batch

def pad_adata_with_tsv(
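
Note on the `prepare_batch` change above: inverse-mode callers now pass the treated expression window through `treated_np`, and it surfaces in the batch as `pert_cell_emb`. A minimal sketch of the new call shape (shapes, names, and the standalone use of the helper are illustrative; in the file the helper is invoked from the inference loop shown further down):

```python
import numpy as np
import torch

T, G, P = 32, 2000, 128  # window size, genes, perturbation vocab (illustrative)
ctrl_basal = np.zeros((T, G), dtype=np.float32)  # control basal expression window
pert_oh = torch.zeros(T, P)                      # homogeneous perturbation one-hots
pert_oh[:, 5] = 1.0
treated = np.zeros((T, G), dtype=np.float32)     # treated cells (inverse mode only)

batch = prepare_batch(
    ctrl_basal_np=ctrl_basal,
    pert_onehots=pert_oh,
    batch_indices=None,                          # no batch encoder in this sketch
    pert_names=["PERT_5"] * T,                   # hypothetical perturbation name
    device=torch.device("cpu"),
    treated_np=treated,                          # omit (or pass None) in forward mode
)
assert batch["pert_cell_emb"].shape == (T, G)
```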
@@ -308,6 +311,13 @@ def pad_adata_with_tsv(
if not args.quiet:
print(f"Loaded config: {config_path}")

# dimensionalities
var_dims_path = os.path.join(args.model_dir, "var_dims.pkl")
if not os.path.exists(var_dims_path):
raise FileNotFoundError(f"Missing var_dims.pkl at {var_dims_path}")
with open(var_dims_path, "rb") as f:
var_dims = pickle.load(f)

# control_pert
control_pert = args.control_pert
if control_pert is None:
@@ -346,13 +356,6 @@ def pad_adata_with_tsv(
except Exception:
args.batch_col = None

# dimensionalities
var_dims_path = os.path.join(args.model_dir, "var_dims.pkl")
if not os.path.exists(var_dims_path):
raise FileNotFoundError(f"Missing var_dims.pkl at {var_dims_path}")
with open(var_dims_path, "rb") as f:
var_dims = pickle.load(f)

pert_dim = var_dims.get("pert_dim")
batch_dim = var_dims.get("batch_dim", None)

@@ -378,9 +381,32 @@ def pad_adata_with_tsv(
if not args.quiet:
print(f"No --checkpoint given, using {checkpoint_path}")

model = StateTransitionPerturbationModel.load_from_checkpoint(checkpoint_path)
model_cfg = cfg.get("model", {})
model_name = model_cfg.get("name", "state")
model_kwargs = dict(model_cfg.get("kwargs", {}))
training_cfg = dict(cfg.get("training", {}))
data_kwargs = dict(cfg.get("data", {}).get("kwargs", {}))

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = get_lightning_module(
model_name,
data_kwargs,
model_kwargs,
training_cfg,
var_dims,
)

checkpoint = torch.load(checkpoint_path, map_location=device, weights_only=False)
state_dict = checkpoint.get("state_dict", checkpoint)
missing, unexpected = model.load_state_dict(state_dict, strict=False)
if missing and not args.quiet:
print(f"Warning: missing parameters when loading checkpoint: {sorted(missing)}")
if unexpected and not args.quiet:
print(f"Warning: unexpected parameters when loading checkpoint: {sorted(unexpected)}")

model.to(device)
model.eval()
device = next(model.parameters()).device
cell_set_len = args.max_set_len if args.max_set_len is not None else getattr(model, "cell_sentence_len", 256)
uses_batch_encoder = getattr(model, "batch_encoder", None) is not None
output_space = getattr(model, "output_space", cfg.get("data", {}).get("kwargs", {}).get("output_space", "gene"))
@@ -506,6 +532,14 @@ def pad_adata_with_tsv(
sim_obsm = X_in.copy()
out_target = f"obsm['{writes_to[1]}']"

model_mode = getattr(model, "mode", None)
is_inverse = model_mode == "inverse"
if is_inverse:
target_dim = getattr(model, "num_nodes", X_in.shape[1])
sim_obsm = np.zeros((n_total, target_dim), dtype=np.float32)
writes_to = (".obsm", "pdgrapher_inverse")
out_target = f"obsm['{writes_to[1]}']"

# Group labels for set-to-set behavior
if args.celltype_col and args.celltype_col in adata.obs:
group_labels = adata.obs[args.celltype_col].astype(str).values
@@ -594,12 +628,14 @@ def group_control_indices(group_name: str) -> np.ndarray:
bi = None

# 4) Forward pass (homogeneous pert in this window)
treated_np = X_in[idx_window, :] if is_inverse else None
batch = prepare_batch(
ctrl_basal_np=ctrl_basal,
pert_onehots=pert_oh,
batch_indices=bi,
pert_names=[p] * win_size,
device=model_device,
treated_np=treated_np,
)
batch_out = model.predict_step(batch, batch_idx=0, padded=False)

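Side note: the loading flow introduced above is self-contained enough to reuse outside the CLI. A condensed sketch, assuming a trained model directory laid out as this diff expects (`var_dims.pkl` plus a YAML config; the config filename and the package import root `state` are assumptions here):

```python
import os
import pickle

import torch
import yaml

from state.tx.utils import get_lightning_module

model_dir = "path/to/model_dir"  # hypothetical
with open(os.path.join(model_dir, "config.yaml")) as f:  # filename is an assumption
    cfg = yaml.safe_load(f)
with open(os.path.join(model_dir, "var_dims.pkl"), "rb") as f:
    var_dims = pickle.load(f)

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = get_lightning_module(
    cfg.get("model", {}).get("name", "state"),
    dict(cfg.get("data", {}).get("kwargs", {})),
    dict(cfg.get("model", {}).get("kwargs", {})),
    dict(cfg.get("training", {})),
    var_dims,
)

# strict=False tolerates renamed or extra heads; report mismatches instead of failing.
ckpt = torch.load("path/to/checkpoint.ckpt", map_location=device, weights_only=False)
missing, unexpected = model.load_state_dict(ckpt.get("state_dict", ckpt), strict=False)
model.to(device)
model.eval()
```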
14 changes: 14 additions & 0 deletions src/state/configs/model/pdgrapher.yaml
@@ -0,0 +1,14 @@
name: pdgrapher
checkpoint: null
device: cuda

kwargs:
mode: forward
edge_index_path: null
positional_features_dims: 16
embedding_layer_dim: 64
dim_gnn: 64
n_layers_gnn: 2
n_layers_nn: 2
dropout: 0.1
cell_set_len: 1
14 changes: 14 additions & 0 deletions src/state/configs/model/pdgrapher_inverse.yaml
@@ -0,0 +1,14 @@
name: pdgrapher_inverse
checkpoint: null
device: cuda

kwargs:
mode: inverse
edge_index_path: null
positional_features_dims: 16
embedding_layer_dim: 64
dim_gnn: 64
n_layers_gnn: 2
n_layers_nn: 2
dropout: 0.1
cell_set_len: 1
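
The two configs are intentionally identical except for `name` and `kwargs.mode`: the forward variant predicts treated expression, while the inverse variant feeds the treated window in as `pert_cell_emb` and writes predictions to `obsm['pdgrapher_inverse']` (see the `_infer.py` hunks above). A quick sanity check of that invariant (paths assume the repo root):

```python
import yaml

with open("src/state/configs/model/pdgrapher.yaml") as f:
    fwd = yaml.safe_load(f)
with open("src/state/configs/model/pdgrapher_inverse.yaml") as f:
    inv = yaml.safe_load(f)

assert fwd["kwargs"]["mode"] == "forward"
assert inv["kwargs"]["mode"] == "inverse"
# Apart from mode (and the top-level name), the two variants stay in lockstep:
assert {k: v for k, v in fwd["kwargs"].items() if k != "mode"} == \
       {k: v for k, v in inv["kwargs"].items() if k != "mode"}
```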
2 changes: 2 additions & 0 deletions src/state/tx/models/__init__.py
@@ -6,6 +6,7 @@
from .old_neural_ot import OldNeuralOTPerturbationModel
from .state_transition import StateTransitionPerturbationModel
from .pseudobulk import PseudobulkPerturbationModel
from .pdgrapher import PDGrapherLightningModule

__all__ = [
"PerturbationModel",
@@ -16,4 +17,5 @@
"OldNeuralOTPerturbationModel",
"DecoderOnlyPerturbationModel",
"PseudobulkPerturbationModel",
"PDGrapherLightningModule",
]
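
With the export added here, the new module is importable alongside the existing models (assuming the package installs under the name `state`):

```python
# Via the package-level re-export:
from state.tx.models import PDGrapherLightningModule

# Or directly from the subpackage added below:
from state.tx.models.pdgrapher import PDGrapherLightningModule
```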
5 changes: 5 additions & 0 deletions src/state/tx/models/pdgrapher/__init__.py
@@ -0,0 +1,5 @@
"""PDGrapher-inspired perturbation models."""

from .module import PDGrapherLightningModule

__all__ = ["PDGrapherLightningModule"]