From ce05c47bcbb0f893e5f31dad8ff95f26d9236c5c Mon Sep 17 00:00:00 2001
From: Francis Chalissery <45127389+fctb12@users.noreply.github.com>
Date: Wed, 15 Oct 2025 10:44:26 -0700
Subject: [PATCH] Add option to zero perturbation encoder in state transition

---
 src/state/configs/model/state.yaml      | 1 +
 src/state/configs/model/state_lg.yaml   | 1 +
 src/state/configs/model/state_sm.yaml   | 1 +
 src/state/tx/models/state_transition.py | 3 +++
 4 files changed, 6 insertions(+)

diff --git a/src/state/configs/model/state.yaml b/src/state/configs/model/state.yaml
index e9b3e34d..b9507cf3 100644
--- a/src/state/configs/model/state.yaml
+++ b/src/state/configs/model/state.yaml
@@ -21,6 +21,7 @@ kwargs:
   nb_decoder: False
   mask_attn: False
   use_effect_gating_token: False
+  zero_perturbation_encoder: False
   distributional_loss: energy
   init_from: null
   transformer_backbone_key: llama
diff --git a/src/state/configs/model/state_lg.yaml b/src/state/configs/model/state_lg.yaml
index 89b875ab..598e1f73 100644
--- a/src/state/configs/model/state_lg.yaml
+++ b/src/state/configs/model/state_lg.yaml
@@ -21,6 +21,7 @@ kwargs:
   nb_decoder: False
   mask_attn: False
   use_effect_gating_token: False
+  zero_perturbation_encoder: False
   use_basal_projection: False
   distributional_loss: energy
   init_from: null
diff --git a/src/state/configs/model/state_sm.yaml b/src/state/configs/model/state_sm.yaml
index 77ddfd1f..eb359619 100644
--- a/src/state/configs/model/state_sm.yaml
+++ b/src/state/configs/model/state_sm.yaml
@@ -20,6 +20,7 @@ kwargs:
   nb_decoder: False
   mask_attn: False
   use_effect_gating_token: False
+  zero_perturbation_encoder: False
   use_basal_projection: False
   distributional_loss: energy
   gene_decoder_bool: False
diff --git a/src/state/tx/models/state_transition.py b/src/state/tx/models/state_transition.py
index ecce5e29..1f201ab4 100644
--- a/src/state/tx/models/state_transition.py
+++ b/src/state/tx/models/state_transition.py
@@ -182,6 +182,7 @@ def __init__(
             raise ValueError(f"Unknown loss function: {loss_name}")
 
         self.use_basal_projection = kwargs.get("use_basal_projection", True)
+        self.zero_perturbation_encoder = kwargs.get("zero_perturbation_encoder", False)
 
         # Build the underlying neural OT network
         self._build_networks(lora_cfg=kwargs.get("lora", None))
@@ -392,6 +393,8 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor:
 
         # Shape: [B, S, input_dim]
         pert_embedding = self.encode_perturbation(pert)
+        if self.zero_perturbation_encoder:
+            pert_embedding = torch.zeros_like(pert_embedding)
         control_cells = self.encode_basal_expression(basal)
 
         # Add encodings in input_dim space, then project to hidden_dim
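
Note (not part of the patch): a minimal, self-contained sketch of the ablation the new flag enables. When zero_perturbation_encoder is True, the perturbation embedding is replaced with zeros before it is combined with the basal-cell encoding, so the model's prediction is driven by basal expression alone. TinyTransition and its linear encoders below are hypothetical stand-ins for the model's encode_perturbation / encode_basal_expression; only the flag handling mirrors the patched code.

import torch
import torch.nn as nn

class TinyTransition(nn.Module):
    def __init__(self, input_dim: int, hidden_dim: int, **kwargs):
        super().__init__()
        # Mirrors the patch: read the flag from kwargs, defaulting to False.
        self.zero_perturbation_encoder = kwargs.get("zero_perturbation_encoder", False)
        self.pert_encoder = nn.Linear(input_dim, input_dim)   # stand-in for encode_perturbation
        self.basal_encoder = nn.Linear(input_dim, input_dim)  # stand-in for encode_basal_expression
        self.project = nn.Linear(input_dim, hidden_dim)

    def forward(self, pert: torch.Tensor, basal: torch.Tensor) -> torch.Tensor:
        pert_embedding = self.pert_encoder(pert)
        if self.zero_perturbation_encoder:
            # Ablation: drop all perturbation signal, as in the patched forward().
            pert_embedding = torch.zeros_like(pert_embedding)
        control_cells = self.basal_encoder(basal)
        # Add encodings in input_dim space, then project to hidden_dim.
        return self.project(pert_embedding + control_cells)

# With the flag on, the perturbation input no longer affects the output,
# which makes this a convenient no-perturbation baseline:
model = TinyTransition(input_dim=8, hidden_dim=4, zero_perturbation_encoder=True)
basal = torch.randn(2, 3, 8)
out_a = model(torch.randn(2, 3, 8), basal)
out_b = model(torch.randn(2, 3, 8), basal)
assert torch.allclose(out_a, out_b)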