 
 
 class CombinedLoss(nn.Module):
-    """
-    Combined Sinkhorn + Energy loss
-    """
+    """Combined Sinkhorn + Energy loss."""
 
     def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
         super().__init__()
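The forward pass of `CombinedLoss` is outside this hunk. For readers of the diff, here is a minimal sketch of how the two terms could be combined, assuming `geomloss.SamplesLoss` backs both terms (an assumption; only the constructor signature is visible here):

```python
import torch.nn as nn
from geomloss import SamplesLoss

class CombinedLoss(nn.Module):
    """Combined Sinkhorn + Energy loss (sketch; forward body is assumed)."""

    def __init__(self, sinkhorn_weight=0.001, energy_weight=1.0, blur=0.05):
        super().__init__()
        self.sinkhorn_weight = sinkhorn_weight
        self.energy_weight = energy_weight
        self.sinkhorn = SamplesLoss(loss="sinkhorn", blur=blur)
        self.energy = SamplesLoss(loss="energy")

    def forward(self, pred, target):
        # Each SamplesLoss call returns one loss value per set in the batch,
        # which is what lets the callers below keep per-set losses around.
        return (self.sinkhorn_weight * self.sinkhorn(pred, target)
                + self.energy_weight * self.energy(pred, target))
```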
@@ -173,7 +171,7 @@ def __init__(
         elif loss_name == "mse":
             self.loss_fn = nn.MSELoss()
         elif loss_name == "se":
-            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)  # 1/100 = 0.01
+            sinkhorn_weight = kwargs.get("sinkhorn_weight", 0.01)
             energy_weight = kwargs.get("energy_weight", 1.0)
             self.loss_fn = CombinedLoss(sinkhorn_weight=sinkhorn_weight, energy_weight=energy_weight, blur=blur)
         elif loss_name == "sinkhorn":
@@ -288,6 +286,11 @@ def __init__(
         if kwargs.get("confidence_token", False):
             self.confidence_token = ConfidenceToken(hidden_dim=self.hidden_dim, dropout=self.dropout)
             self.confidence_loss_fn = nn.MSELoss()
+            self.confidence_target_scale = float(kwargs.get("confidence_target_scale", 10.0))
+            self.confidence_weight = float(kwargs.get("confidence_weight", 0.01))
+        else:
+            self.confidence_target_scale = None
+            self.confidence_weight = 0.0
 
         # Backward-compat: accept legacy key `freeze_pert`
         self.freeze_pert_backbone = kwargs.get("freeze_pert_backbone", kwargs.get("freeze_pert", False))
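These new keys replace two values that were previously hard-coded in the step methods: the `* 10` target scaling and the inline `0.1` confidence weight (whose default is now 0.01). A hypothetical construction showing the new kwargs; the surrounding model constructor is not part of this diff:

```python
# Hypothetical kwargs dict for the model constructor (names from this diff).
kwargs = {
    "confidence_token": True,          # enables the ConfidenceToken head
    "confidence_target_scale": 10.0,   # replaces the hard-coded * 10
    "confidence_weight": 0.01,         # replaces the inline 0.1 weight
}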
@@ -544,7 +547,8 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
 
-        main_loss = self.loss_fn(pred, target).nanmean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        main_loss = torch.nanmean(per_set_main_losses)
         self.log("train_loss", main_loss)
 
         # Log individual loss components if using combined loss
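Keeping the unreduced tensor matters below: `self.loss_fn` returns one loss per set, and those per-set values become the confidence targets. A tiny illustration, assuming a `[B]`-shaped return where NaN marks a padded set:

```python
import torch

# One loss per set; NaN marking a padded/empty set is an assumed convention.
per_set_main_losses = torch.tensor([0.12, 0.08, float("nan")])
main_loss = torch.nanmean(per_set_main_losses)  # tensor(0.1000), NaN ignored
```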
@@ -641,25 +645,18 @@ def training_step(self, batch: Dict[str, torch.Tensor], batch_idx: int, padded=T
             total_loss = total_loss + self.decoder_loss_weight * decoder_loss
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = total_loss.detach().clone().unsqueeze(0) * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("train/confidence_loss", confidence_loss)
-            self.log("train/actual_loss", loss_target.mean())
+            self.log("train/actual_loss", confidence_targets.mean())
 
-            # Add to total loss with weighting
-            confidence_weight = 0.1  # You can make this configurable
-            total_loss = total_loss + confidence_weight * confidence_loss
-
-            # Add to total loss
             total_loss = total_loss + confidence_loss
 
         if self.regularization > 0.0:
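The net behavioral change: confidence targets are now the detached per-set main losses (optionally scaled), rather than the scalar `total_loss` broadcast to every set, and the configurable weight is applied before adding to `total_loss`. A self-contained sketch of the new target construction, with `F.mse_loss` standing in for `self.confidence_loss_fn`:

```python
import torch
import torch.nn.functional as F

def confidence_loss_sketch(confidence_pred, per_set_main_losses,
                           target_scale=10.0, weight=0.01):
    # Accept [B] or [B, 1] predictions.
    pred_vals = confidence_pred
    if pred_vals.dim() > 1:
        pred_vals = pred_vals.squeeze(-1)
    # Targets are the detached per-set losses, optionally scaled,
    # moved to the prediction's device.
    targets = per_set_main_losses.detach()
    if target_scale is not None:
        targets = targets * target_scale
    targets = targets.to(pred_vals.device)
    return weight * F.mse_loss(pred_vals, targets)

# Example: three sets in the batch, predictions shaped [3, 1].
loss = confidence_loss_sketch(torch.rand(3, 1), torch.rand(3))
```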
@@ -688,7 +685,8 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
         target = batch["pert_cell_emb"]
         target = target.reshape(-1, self.cell_sentence_len, self.output_dim)
 
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("val_loss", loss)
 
         # Log individual loss components if using combined loss
@@ -722,19 +720,17 @@ def validation_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> Non
             loss = loss + self.decoder_loss_weight * decoder_loss
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("val/confidence_loss", confidence_loss)
-            self.log("val/actual_loss", loss_target.mean())
+            self.log("val/actual_loss", confidence_targets.mean())
 
         return {"loss": loss, "predictions": pred}
 
@@ -747,21 +743,20 @@ def test_step(self, batch: Dict[str, torch.Tensor], batch_idx: int) -> None:
         target = batch["pert_cell_emb"]
         pred = pred.reshape(1, -1, self.output_dim)
         target = target.reshape(1, -1, self.output_dim)
-        loss = self.loss_fn(pred, target).mean()
+        per_set_main_losses = self.loss_fn(pred, target)
+        loss = torch.nanmean(per_set_main_losses)
         self.log("test_loss", loss)
 
         if confidence_pred is not None:
-            # Detach main loss to prevent gradients flowing through it
-            loss_target = loss.detach().clone() * 10.0
-
-            # Ensure proper shapes for confidence loss computation
-            if confidence_pred.dim() == 2:  # [B, 1]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0), 1)
-            else:  # confidence_pred is [B,]
-                loss_target = loss_target.unsqueeze(0).expand(confidence_pred.size(0))
-
-            # Compute confidence loss
-            confidence_loss = self.confidence_loss_fn(confidence_pred.squeeze(), loss_target.squeeze())
+            confidence_pred_vals = confidence_pred
+            if confidence_pred_vals.dim() > 1:
+                confidence_pred_vals = confidence_pred_vals.squeeze(-1)
+            confidence_targets = per_set_main_losses.detach()
+            if self.confidence_target_scale is not None:
+                confidence_targets = confidence_targets * self.confidence_target_scale
+            confidence_targets = confidence_targets.to(confidence_pred_vals.device)
+
+            confidence_loss = self.confidence_weight * self.confidence_loss_fn(confidence_pred_vals, confidence_targets)
             self.log("test/confidence_loss", confidence_loss)
 
     def predict_step(self, batch, batch_idx, padded=True, **kwargs):