
Commit 19933ed

updated cell-eval to 0.6
1 parent: 0df3dbc

File tree

3 files changed (+44, -50 lines)

pyproject.toml (1 addition, 2 deletions)

@@ -27,14 +27,13 @@ dependencies = [
     "geomloss>=0.2.6",
     "transformers>=4.52.3",
     "peft>=0.11.0",
-    "cell-eval>=0.5.46",
+    "cell-eval>=0.6.0",
     "ipykernel>=6.30.1",
     "scipy>=1.15.0",
 ]

 [tool.uv.sources]
 cell-load = {path = "/home/aadduri/cell-load"}
-cell-eval = {path = "/home/aadduri/cell-eval"}

 [project.optional-dependencies]
 vectordb = [
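
This bump pins cell-eval at or above 0.6.0 and drops the local path override from [tool.uv.sources], so the dependency now resolves from the package index. A quick post-sync sanity check might look like this (a sketch; the `packaging` import and the `cell-eval` distribution name are assumptions about the environment):

```python
# Sketch: confirm the resolved cell-eval satisfies the new >=0.6.0 pin.
# Assumes the environment was re-synced (e.g. with `uv sync`) after this commit.
from importlib.metadata import version
from packaging.version import Version

installed = Version(version("cell-eval"))
assert installed >= Version("0.6.0"), f"cell-eval {installed} predates the 0.6.0 pin"
```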

src/state/configs/model/state.yaml (1 addition, 1 deletion)

@@ -7,7 +7,7 @@ kwargs:
   blur: 0.05
   hidden_dim: 768 # hidden dimension going into the transformer backbone
   loss: energy
-  confidence_head: False
+  confidence_token: False
   n_encoder_layers: 1
   n_decoder_layers: 1
   predict_residual: True
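
The rename matters because the model constructor checks the `confidence_token` key (see the state_transition.py diff below), so a config still using `confidence_head` would leave the token disabled. A toy illustration:

```python
# Toy illustration of the key rename; the dicts stand in for the model kwargs.
old_cfg = {"confidence_head": True}
new_cfg = {"confidence_token": True}

old_cfg.get("confidence_token", False)  # False: the old key is silently ignored
new_cfg.get("confidence_token", False)  # True: the confidence token is enabled
```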

src/state/tx/models/state_transition.py (42 additions, 47 deletions)

@@ -20,9 +20,7 @@


 class CombinedLoss(nn.Module):
-    """
-    Combined Sinkhorn + Energy loss
-    """
+    """Combined Sinkhorn + Energy loss."""

     def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
         super().__init__()
@@ -173,7 +171,7 @@ def __init__(
         elif loss_name == "mse":
             self.loss_fn = nn.MSELoss()
         elif loss_name == "se":
-            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)  # 1/100 = 0.01
+            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)
             energy_weight = kwargs.get("energy_weight", 1.0)
             self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur)
         elif loss_name == "sinkhorn":
@@ -288,6 +286,11 @@ def __init__(
         if kwargs.get("confidence_token", False):
             self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout)
             self.confidence_loss_fn = nn.MSELoss()
+            self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))
+            self.confidence_weight = float(kwargs.get("confidence_weight", 0.01))
+        else:
+            self.confidence_target_scale = None
+            self.confidence_weight = 0.0

         # Backward-compat: accept legacy key `freeze_pert`
         self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False))
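
These two attributes replace the hard-coded ×10 target scale and 0.1 weight that the old training_step used inline (see the hunks below). A minimal sketch of how the knobs resolve, with an illustrative kwargs dict:

```python
# Keys and defaults match the constructor diff above; the dict is illustrative.
kwargs = {"confidence_token": True}

if kwargs.get("confidence_token", False):
    confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))  # default 10.0
    confidence_weight = float(kwargs.get("confidence_weight", 0.01))              # default 0.01
else:
    confidence_target_scale, confidence_weight = None, 0.0
```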
@@ -544,7 +547,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)

-        main_loss = self.loss_fn(pred, target).nanmean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        main_loss = torch.nanmean(per_set_main_losses)
         self.log("train_loss", main_loss)

         # Log individual loss components if using combined loss
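
Splitting the reduction into two steps keeps the unreduced per-set losses available for the confidence targets later in the step, instead of discarding them inside an inline .nanmean(). A shape sketch, assuming the loss function returns one value per set (as geomloss's batched SamplesLoss does):

```python
import torch

# Stand-in for self.loss_fn(pred, target): one loss value per set.
per_set_main_losses = torch.tensor([0.4, float("nan"), 0.7])

# NaN entries (e.g. from padded sets) are skipped by nanmean.
main_loss = torch.nanmean(per_set_main_losses)  # tensor(0.5500)
```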
@@ -641,25 +645,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
             total_loss = total_loss + self.decoder_loss_weight * decoder_loss

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = total_loss.detach().clone().unsqueeze(0) * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("train/confidence_loss", confidence_loss)
-            self.log("train/actual_loss", loss_target.mean())
+            self.log("train/actual_loss", confidence_targets.mean())

-            # Add to total loss with weighting
-            confidence_weight = 0.1  # You can make this configurable
-            total_loss = total_loss + confidence_weight * confidence_loss
-
-            # Add to total loss
             total_loss = total_loss + confidence_loss

         if self.regularization > 0.0:
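
The rewritten block regresses the confidence head onto the detached per-set main losses rather than a broadcast copy of the already-aggregated total loss, and the scale and weight now come from the constructor instead of inline constants. A self-contained sketch of the data flow, with made-up tensors:

```python
import torch

confidence_pred = torch.rand(4, 1)       # e.g. [B, 1] output of the confidence head
per_set_main_losses = torch.rand(4)      # one main-loss value per set
confidence_target_scale = 10.0           # constructor defaults from the diff above
confidence_weight = 0.01
confidence_loss_fn = torch.nn.MSELoss()

pred_vals = confidence_pred.squeeze(-1)                            # [B, 1] -> [B]
targets = per_set_main_losses.detach() * confidence_target_scale  # no grad into backbone
targets = targets.to(pred_vals.device)

confidence_loss = confidence_weight * confidence_loss_fn(pred_vals, targets)
```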
@@ -688,7 +685,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
         target = batch["pert_cell_emb"]
         target = target.reshape(-1, self.cell_sentence_len, self.output_dim)

-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("val_loss", loss)

         # Log individual loss components if using combined loss
@@ -722,19 +720,17 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
             loss = loss + self.decoder_loss_weight * decoder_loss

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("val/confidence_loss", confidence_loss)
-            self.log("val/actual_loss", loss_target.mean())
+            self.log("val/actual_loss", confidence_targets.mean())

         return {"loss": loss, "predictions": pred}

@@ -747,21 +743,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         target = batch["pert_cell_emb"]
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("test_loss", loss)

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10.0
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("test/confidence_loss", confidence_loss)

     def predict_step(self, batch, batch_idx, padded=True, **kwargs):
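
The same confidence block now appears verbatim in training_step, validation_step, and test_step. If that duplication grows stale, a shared helper along these lines (hypothetical, not part of this commit) would keep the three call sites in sync:

```python
import torch

def confidence_terms(confidence_pred, per_set_losses, target_scale, weight, loss_fn):
    """Hypothetical helper mirroring the repeated block in the three steps."""
    pred_vals = confidence_pred
    if pred_vals.dim() > 1:              # e.g. [B, 1] -> [B]
        pred_vals = pred_vals.squeeze(-1)
    targets = per_set_losses.detach()    # targets must not backprop into the model
    if target_scale is not None:
        targets = targets * target_scale
    targets = targets.to(pred_vals.device)
    return weight * loss_fn(pred_vals, targets), targets
```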
