
Commit 2d71db1

Commit message: updated cell-eval to 0.6
1 parent e8519a6 commit 2d71db1


3 files changed: +44 −50 lines


pyproject.toml

Lines changed: 1 addition & 2 deletions
@@ -27,13 +27,12 @@ dependencies = [
     "geomloss>=0.2.6",
     "transformers>=4.52.3",
     "peft>=0.11.0",
-    "cell-eval>=0.5.46",
+    "cell-eval>=0.6.0",
     "ipykernel>=6.30.1",
 ]

 [tool.uv.sources]
 cell-load = {path = "/home/aadduri/cell-load"}
-cell-eval = {path = "/home/aadduri/cell-eval"}

 [project.optional-dependencies]
 vectordb = [
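Note: alongside the version bump, the local [tool.uv.sources] path override for cell-eval is dropped, so resolution now falls through to the published package. A minimal sanity-check sketch (not part of this repo; assumes the `packaging` library is available) to confirm an environment actually picked up the new floor after re-syncing dependencies:

# Hypothetical helper, not repo code: verify the installed cell-eval
# distribution satisfies the new >=0.6.0 constraint from pyproject.toml.
from importlib.metadata import version

from packaging.version import Version

installed = Version(version("cell-eval"))
assert installed >= Version("0.6.0"), f"cell-eval {installed} < 0.6.0; re-run dependency sync"
print(f"cell-eval {installed} satisfies the >=0.6.0 constraint")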

src/state/configs/model/state.yaml

Lines changed: 1 addition & 1 deletion
@@ -7,7 +7,7 @@ kwargs:
   blur: 0.05
   hidden_dim: 768 # hidden dimension going into the transformer backbone
   loss: energy
-  confidence_head: False
+  confidence_token: False
   n_encoder_layers: 1
   n_decoder_layers: 1
   predict_residual: True
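The key rename matters because the model only reads `confidence_token` (see `kwargs.get("confidence_token", False)` in state_transition.py below); a config still using the old `confidence_head` name would be silently ignored. A hedged sketch (assumes PyYAML; the check itself is illustrative, not repo code) for flagging stale configs:

# Illustrative guard, not repo code: fail loudly on the pre-rename key.
import yaml

with open("src/state/configs/model/state.yaml") as f:
    cfg = yaml.safe_load(f)

kwargs = cfg.get("kwargs", {})
if "confidence_head" in kwargs:
    raise ValueError("`confidence_head` was renamed to `confidence_token`; update the config")
print("confidence_token =", kwargs.get("confidence_token", False))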

src/state/tx/models/state_transition.py

Lines changed: 42 additions & 47 deletions
@@ -20,9 +20,7 @@


 class CombinedLoss(nn.Module):
-    """
-    Combined Sinkhorn + Energy loss
-    """
+    """Combined Sinkhorn + Energy loss."""

     def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
         super().__init__()
@@ -172,7 +170,7 @@ def __init__(
         elif loss_name == "mse":
             self.loss_fn = nn.MSELoss()
         elif loss_name == "se":
-            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01) # 1/100 = 0.01
+            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)
             energy_weight = kwargs.get("energy_weight", 1.0)
             self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur)
         elif loss_name == "sinkhorn":
@@ -246,6 +244,11 @@ def __init__(
         if kwargs.get("confidence_token", False):
             self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout)
             self.confidence_loss_fn = nn.MSELoss()
+            self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))
+            self.confidence_weight = float(kwargs.get("confidence_weight", 0.01))
+        else:
+            self.confidence_target_scale = None
+            self.confidence_weight = 0.0

         # Backward-compat: accept legacy key `freeze_pert`
         self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False))
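This hunk turns two previously hard-coded values into constructor kwargs: the confidence target scale (formerly a literal `* 10`) and the confidence loss weight (formerly a local `confidence_weight = 0.1`), with new defaults of 10.0 and 0.01. The `else` branch guarantees both attributes exist even when the confidence token is disabled, so later `self.confidence_*` reads cannot raise AttributeError. A standalone sketch of the same pattern (the demo class is illustrative, not the repo's module):

# Illustrative demo of the new kwargs gating; mirrors the diff's names.
import torch.nn as nn

class ConfidenceConfigDemo:
    def __init__(self, **kwargs):
        if kwargs.get("confidence_token", False):
            self.confidence_loss_fn = nn.MSELoss()
            # Both knobs are now read from kwargs instead of being hard-coded.
            self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))
            self.confidence_weight = float(kwargs.get("confidence_weight", 0.01))
        else:
            # Defining the attributes in both branches keeps later
            # self.confidence_* reads safe when the token is disabled.
            self.confidence_target_scale = None
            self.confidence_weight = 0.0

demo = ConfidenceConfigDemo(confidence_token=True, confidence_weight=0.05)
print(demo.confidence_target_scale, demo.confidence_weight)  # 10.0 0.05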
@@ -482,7 +485,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=True):
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)

-        main_loss = self.loss_fn(pred, target).nanmean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        main_loss = torch.nanmean(per_set_main_losses)
         self.log("train_loss", main_loss)

         # Log individual loss components if using combined loss
@@ -554,25 +558,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=True):
             total_loss = total_loss + self.decoder_loss_weight * decoder_loss

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = total_loss.detach().clone().unsqueeze(0) * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("train/confidence_loss", confidence_loss)
-            self.log("train/actual_loss", loss_target.mean())
+            self.log("train/actual_loss", confidence_targets.mean())

-            # Add to total loss with weighting
-            confidence_weight = 0.1  # You can make this configurable
-            total_loss = total_loss + confidence_weight * confidence_loss
-
-            # Add to total loss
             total_loss = total_loss + confidence_loss

         if self.regularization > 0.0:
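Two behaviors change here: the confidence head now regresses against the detached per-set losses (one target per cell set) instead of a single batch-wide scalar broadcast to every prediction, and the confidence loss is added to `total_loss` exactly once, weighted by the configurable `self.confidence_weight`, where the old code both applied a local 0.1 weight and then added the unweighted loss again. The same computation is mirrored in `validation_step` and `test_step` below. A self-contained numerical sketch (all tensor values are made up):

# Illustrative shapes and values only; names mirror the diff.
import torch
import torch.nn as nn

confidence_loss_fn = nn.MSELoss()
confidence_target_scale = 10.0
confidence_weight = 0.01

# One loss value per cell set, as returned by the non-reduced loss_fn.
# torch.nanmean would skip NaN set-level losses that a plain .mean()
# would propagate into the batch loss.
per_set_main_losses = torch.tensor([0.42, 0.55, 0.37])
main_loss = torch.nanmean(per_set_main_losses)

confidence_pred = torch.tensor([[4.1], [3.9], [3.6]])  # [B, 1] from the head
confidence_pred_vals = confidence_pred.squeeze(-1)     # -> [B]

# Targets are the detached, scaled per-set losses, so the head learns to
# predict how hard each set was rather than one batch-wide scalar.
confidence_targets = per_set_main_losses.detach() * confidence_target_scale
confidence_loss = confidence_weight * confidence_loss_fn(confidence_pred_vals, confidence_targets)
print(main_loss.item(), confidence_loss.item())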
@@ -601,7 +598,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         target = batch["pert_cell_emb"]
         target = target.reshape(-1, self.cell_sentence_len, self.output_dim)

-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("val_loss", loss)

         # Log individual loss components if using combined loss
@@ -653,19 +651,17 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
             loss = loss + self.decoder_loss_weight * decoder_loss

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("val/confidence_loss", confidence_loss)
-            self.log("val/actual_loss", loss_target.mean())
+            self.log("val/actual_loss", confidence_targets.mean())

         return {"loss": loss, "predictions": pred}
@@ -678,21 +674,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         target = batch["pert_cell_emb"]
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("test_loss", loss)

         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10.0
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("test/confidence_loss", confidence_loss)

     def predict_step(self, batch, batch_idx, padded=True, **kwargs):
