diff --git a/sentence_transformers/losses/CachedGISTEmbedLoss.py b/sentence_transformers/losses/CachedGISTEmbedLoss.py index e8e131582..51bc1c02b 100644 --- a/sentence_transformers/losses/CachedGISTEmbedLoss.py +++ b/sentence_transformers/losses/CachedGISTEmbedLoss.py @@ -265,7 +265,7 @@ def calculate_loss_and_cache_gradients(self, reps: list[list[Tensor]], reps_guid # Apply thresholds based on guided model similarities ap_sim[guided_ap_sim > guided_sim] = -torch.inf aa_sim[guided_aa_sim > guided_sim] = -torch.inf - pp_sim[guided_pp_sim > guided_sim] = -torch.inf + pp_sim[guided_pp_sim >= guided_sim] = -torch.inf # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) @@ -328,7 +328,7 @@ def calculate_loss(self, reps: list[list[Tensor]], reps_guided: list[list[Tensor # Apply thresholds based on guided model similarities ap_sim[guided_ap_sim > guided_sim] = -torch.inf aa_sim[guided_aa_sim > guided_sim] = -torch.inf - pp_sim[guided_pp_sim > guided_sim] = -torch.inf + pp_sim[guided_pp_sim >= guided_sim] = -torch.inf # Concatenate the similarity matrices for anchor-positive, anchor-anchor, and positive-positive scores = torch.cat([ap_sim, aa_sim, pp_sim], dim=1) diff --git a/sentence_transformers/losses/GISTEmbedLoss.py b/sentence_transformers/losses/GISTEmbedLoss.py index 5111b2b0c..4b79b39f6 100644 --- a/sentence_transformers/losses/GISTEmbedLoss.py +++ b/sentence_transformers/losses/GISTEmbedLoss.py @@ -150,7 +150,7 @@ def forward(self, sentence_features: Iterable[dict[str, Tensor]], labels: Tensor # the loss. ap_sim[guided_ap_sim > guided_sim] = -torch.inf aa_sim[guided_aa_sim > guided_sim] = -torch.inf - pp_sim[guided_pp_sim > guided_sim] = -torch.inf + pp_sim[guided_pp_sim >= guided_sim] = -torch.inf scores = [ap_sim, aa_sim, pp_sim]