@@ -20,9 +20,7 @@
 
 
 class CombinedLoss(nn.Module):
-    """
-    Combined Sinkhorn + Energy loss
-    """
+    """Combined Sinkhorn + Energy loss."""
 
     def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
         super().__init__()
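For orientation, here is a minimal sketch of what `CombinedLoss` plausibly looks like internally, assuming both terms come from `geomloss.SamplesLoss` (an assumption; only the signature and docstring appear in this diff):

```python
# Illustrative reconstruction only -- the real forward() is not shown in this diff.
import torch.nn as nn
from geomloss import SamplesLoss  # assumption: geomloss supplies both terms

class CombinedLossSketch(nn.Module):
    """Combined Sinkhorn + Energy loss (sketch)."""

    def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
        super().__init__()
        self.sinkhorn_weight = sinkhorn_weight
        self.energy_weight = energy_weight
        self.sinkhorn = SamplesLoss(loss="sinkhorn", p=2, blur=blur)
        self.energy = SamplesLoss(loss="energy")

    def forward(self, pred, target):
        # For batched point clouds [B, N, D], SamplesLoss returns one value per
        # set (shape [B]); the training loop below relies on that per-set shape.
        return (self.sinkhorn_weight * self.sinkhorn(pred, target)
                + self.energy_weight * self.energy(pred, target))
```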
@@ -172,7 +170,7 @@ def __init__(
         elif loss_name == "mse":
             self.loss_fn = nn.MSELoss()
         elif loss_name == "se":
-            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)  # 1/100 = 0.01
+            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)
             energy_weight = kwargs.get("energy_weight", 1.0)
             self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur)
         elif loss_name == "sinkhorn":
@@ -246,6 +244,11 @@ def __init__(
         if kwargs.get("confidence_token", False):
             self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout)
             self.confidence_loss_fn = nn.MSELoss()
+            self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))
+            self.confidence_weight = float(kwargs.get("confidence_weight", 0.01))
+        else:
+            self.confidence_target_scale = None
+            self.confidence_weight = 0.0
 
         # Backward-compat: accept legacy key `freeze_pert`
         self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False))
@@ -482,7 +485,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
 
-        main_loss = self.loss_fn(pred, target).nanmean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        main_loss = torch.nanmean(per_set_main_losses)
         self.log("train_loss", main_loss)
 
         # Log individual loss components if using combined loss
@@ -554,25 +558,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
             total_loss = total_loss + self.decoder_loss_weight * decoder_loss
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = total_loss.detach().clone().unsqueeze(0) * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("train/confidence_loss", confidence_loss)
-            self.log("train/actual_loss", loss_target.mean())
+            self.log("train/actual_loss", confidence_targets.mean())
 
-            # Add to total loss with weighting
-            confidence_weight = 0.1  # You can make this configurable
-            total_loss = total_loss + confidence_weight * confidence_loss
-
-            # Add to total loss
             total_loss = total_loss + confidence_loss
 
         if self.regularization > 0.0:
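Pulled out of the Lightning module, the refactored confidence objective amounts to the following (a sketch: `F.mse_loss` stands in for `self.confidence_loss_fn`, and the shapes assume one main-loss value per set, which is why `main_loss` is now computed via `torch.nanmean` over `per_set_main_losses`):

```python
import torch
import torch.nn.functional as F

def confidence_loss_sketch(per_set_main_losses: torch.Tensor,
                           confidence_pred: torch.Tensor,
                           target_scale: float = 10.0,
                           weight: float = 0.01) -> torch.Tensor:
    # confidence_pred arrives as [B] or [B, 1]; flatten any trailing dim.
    preds = confidence_pred.squeeze(-1) if confidence_pred.dim() > 1 else confidence_pred
    # Detach so the confidence head cannot push gradients into the main loss.
    targets = per_set_main_losses.detach()
    if target_scale is not None:
        targets = targets * target_scale  # rescale targets into a friendlier range
    targets = targets.to(preds.device)
    return weight * F.mse_loss(preds, targets)
```

The two behavioral fixes over the deleted code: each set now regresses against its own (detached) loss rather than one broadcast copy of `total_loss`, and the confidence term is added to `total_loss` exactly once, with the hard-coded `0.1` weight replaced by the configurable `confidence_weight`.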
@@ -601,7 +598,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
         target = batch["pert_cell_emb"]
         target = target.reshape(-1, self.cell_sentence_len, self.output_dim)
 
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("val_loss", loss)
 
         # Log individual loss components if using combined loss
@@ -653,19 +651,17 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
             loss = loss + self.decoder_loss_weight * decoder_loss
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("val/confidence_loss", confidence_loss)
-            self.log("val/actual_loss", loss_target.mean())
+            self.log("val/actual_loss", confidence_targets.mean())
 
         return {"loss": loss, "predictions": pred}
 
@@ -678,21 +674,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         target = batch["pert_cell_emb"]
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("test_loss", loss)
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10.0
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("test/confidence_loss", confidence_loss)
 
     def predict_step(self, batch, batch_idx, padded=True, **kwargs):
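Finally, a hypothetical construction call showing how the new kwargs thread through `__init__` (the class name and the `loss_name` routing are placeholders; only the keyword names and their defaults appear in this diff):

```python
# Hypothetical usage -- the class name is a placeholder, not from this diff.
model = PerturbationSetModel(
    loss_name="se",                  # selects CombinedLoss (Sinkhorn + energy)
    sinkhorn_weight=0.01,            # default if omitted
    energy_weight=1.0,
    confidence_token=True,           # enables ConfidenceToken + confidence loss
    confidence_target_scale=10.0,    # multiplier on detached per-set loss targets
    confidence_weight=0.01,          # weight of the confidence MSE term
)
```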