114 changes: 65 additions & 49 deletions src/state/tx/models/state_transition.py
@@ -369,6 +369,62 @@ def encode_basal_expression(self, expr: torch.Tensor) -> torch.Tensor:
"""Define how we embed basal state input, if needed."""
return self.basal_encoder(expr)

def _compute_batch_token_loss(self, batch: Dict[str, torch.Tensor], padded: bool) -> Optional[torch.Tensor]:
"""Compute CE loss for the optional batch token from cached token features.

Returns None if batch token training is disabled or cache is unavailable.
"""
if not (self.use_batch_token and self.batch_classifier is not None and self._batch_token_cache is not None):
return None

logits = self.batch_classifier(self._batch_token_cache) # [B, 1, C]
batch_token_targets = batch["batch"]

B = logits.shape[0]
C = logits.size(-1)

# Prepare one label per sequence (all S cells share the same batch)
if batch_token_targets.dim() > 1 and batch_token_targets.size(-1) == C:
# One-hot labels; reshape to [B, S, C]
if padded:
target_oh = batch_token_targets.reshape(-1, self.cell_sentence_len, C)
else:
target_oh = batch_token_targets.reshape(1, -1, C)
sentence_batch_labels = target_oh.argmax(-1)
else:
# Integer labels; reshape to [B, S]
if padded:
sentence_batch_labels = batch_token_targets.reshape(-1, self.cell_sentence_len)
else:
sentence_batch_labels = batch_token_targets.reshape(1, -1)

if sentence_batch_labels.shape[0] != B:
sentence_batch_labels = sentence_batch_labels.reshape(B, -1)

if self.basal_mapping_strategy == "batch":
uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
if not torch.all(uniform_mask):
bad_indices = torch.where(~uniform_mask)[0]
label_strings = []
for idx in bad_indices:
labels = sentence_batch_labels[idx].detach().cpu().tolist()
logger.error("Batch labels for sentence %d: %s", idx.item(), labels)
label_strings.append(f"sentence {idx.item()}: {labels}")
raise ValueError(
"Expected all cells in a sentence to share the same batch when "
"basal_mapping_strategy is 'batch'. "
f"Found mixed batch labels: {', '.join(label_strings)}"
)

target_idx = sentence_batch_labels[:, 0]

# Safety: ensure exactly one target per sequence
if target_idx.numel() != B:
target_idx = target_idx.reshape(-1)[:B]

ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long())
return ce_loss

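    # Illustrative sketch of the label handling in _compute_batch_token_loss,
    # using dummy tensors and a throwaway linear classifier; the shapes, names,
    # and input format below are assumptions, not taken from the module.
    #
    # import torch
    # import torch.nn.functional as F
    #
    # B, S, C, D = 4, 8, 3, 16                      # sentences, cells per sentence, batch classes, feature dim
    # batch_token_cache = torch.randn(B, 1, D)      # stand-in for self._batch_token_cache
    # classifier = torch.nn.Linear(D, C)            # stand-in for self.batch_classifier
    # logits = classifier(batch_token_cache)        # [B, 1, C]
    #
    # # One integer label per cell, flattened across the batch of sentences (assumed input format).
    # flat_labels = torch.randint(0, C, (B,)).repeat_interleave(S)   # [B * S]
    #
    # # padded=True path: regroup per-cell labels into sentences, then keep the first
    # # label of each sentence as the single target (all S cells share the batch).
    # sentence_labels = flat_labels.reshape(-1, S)  # [B, S]
    # target_idx = sentence_labels[:, 0]            # [B]
    # ce_loss = F.cross_entropy(logits.squeeze(1), target_idx.long())
    #
    # # One-hot variant: [B * S, C] labels collapse back to indices via argmax.
    # one_hot = F.one_hot(flat_labels, C).float()                     # [B * S, C]
    # assert torch.equal(one_hot.reshape(-1, S, C).argmax(-1), sentence_labels)
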
def forward(self, batch: dict, padded=True) -> torch.Tensor:
"""
The main forward call. Batch is a flattened sequence of cell sentences,
@@ -431,11 +487,11 @@ def forward(self, batch: dict, padded=True) -> torch.Tensor:
if self.hparams.get("mask_attn", False):
batch_size, seq_length, _ = seq_input.shape
device = seq_input.device
self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]
self.transformer_backbone._attn_implementation = "eager" # pyright: ignore[reportAttributeAccessIssue, reportArgumentType]

# create a [1,1,S,S] mask (now S+1 if confidence token is used)
base = torch.eye(seq_length, device=device, dtype=torch.bool).view(1, 1, seq_length, seq_length)

# Get number of attention heads from model config
num_heads = self.transformer_backbone.config.num_attention_heads

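            # Illustrative only: broadcasting an identity mask like `base` above
            # across attention heads. num_heads and seq_length are made-up values,
            # and the real mask may also account for a confidence token.
            #
            # import torch
            # seq_length, num_heads = 4, 2
            # base = torch.eye(seq_length, dtype=torch.bool).view(1, 1, seq_length, seq_length)
            # attn_mask = base.expand(1, num_heads, seq_length, seq_length)
            # print(attn_mask.shape)  # torch.Size([1, 2, 4, 4])
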
@@ -529,53 +585,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
decoder_loss = None
total_loss = main_loss

if self.use_batch_token and self.batch_classifier is not None and self._batch_token_cache is not None:
logits = self.batch_classifier(self._batch_token_cache) # [B, 1, C]
batch_token_targets = batch["batch"]

B = logits.shape[0]
C = logits.size(-1)

# Prepare one label per sequence (all S cells share the same batch)
if batch_token_targets.dim() > 1 and batch_token_targets.size(-1) == C:
# One-hot labels; reshape to [B, S, C]
if padded:
target_oh = batch_token_targets.reshape(-1, self.cell_sentence_len, C)
else:
target_oh = batch_token_targets.reshape(1, -1, C)
sentence_batch_labels = target_oh.argmax(-1)
else:
# Integer labels; reshape to [B, S]
if padded:
sentence_batch_labels = batch_token_targets.reshape(-1, self.cell_sentence_len)
else:
sentence_batch_labels = batch_token_targets.reshape(1, -1)

if sentence_batch_labels.shape[0] != B:
sentence_batch_labels = sentence_batch_labels.reshape(B, -1)

if self.basal_mapping_strategy == "batch":
uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
if not torch.all(uniform_mask):
bad_indices = torch.where(~uniform_mask)[0]
label_strings = []
for idx in bad_indices:
labels = sentence_batch_labels[idx].detach().cpu().tolist()
logger.error("Batch labels for sentence %d: %s", idx.item(), labels)
label_strings.append(f"sentence {idx.item()}: {labels}")
raise ValueError(
"Expected all cells in a sentence to share the same batch when "
"basal_mapping_strategy is 'batch'. "
f"Found mixed batch labels: {', '.join(label_strings)}"
)

target_idx = sentence_batch_labels[:, 0]

# Safety: ensure exactly one target per sequence
if target_idx.numel() != B:
target_idx = target_idx.reshape(-1)[:B]

ce_loss = F.cross_entropy(logits.reshape(B, -1, C).squeeze(1), target_idx.long())
ce_loss = self._compute_batch_token_loss(batch, padded=padded)
if ce_loss is not None:
self.log("train/batch_token_loss", ce_loss)
total_loss = total_loss + self.batch_token_weight * ce_loss

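        # Illustrative sketch of the uniformity check _compute_batch_token_loss
        # applies when basal_mapping_strategy == "batch": every cell in a sentence
        # must carry the same batch label. The values below are made up.
        #
        # import torch
        # sentence_batch_labels = torch.tensor([[0, 0, 0, 0],
        #                                       [1, 2, 1, 1]])  # second sentence mixes batches
        # uniform_mask = sentence_batch_labels.eq(sentence_batch_labels[:, :1]).all(dim=1)
        # print(uniform_mask)                            # tensor([ True, False])
        # print(torch.where(~uniform_mask)[0].tolist())  # [1] -> this sentence would raise ValueError
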
@@ -668,6 +679,11 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
self.log("val/sinkhorn_loss", sinkhorn_component)
self.log("val/energy_loss", energy_component)

# Log the batch token loss during validation without adding it to the validation loss
ce_loss_val = self._compute_batch_token_loss(batch, padded=True)
if ce_loss_val is not None:
self.log("val/batch_token_loss", ce_loss_val)

if self.gene_decoder is not None and "pert_cell_counts" in batch:
gene_targets = batch["pert_cell_counts"]
