Skip to content

Commit

Permalink
Merge pull request #102 from MrPrajwalB/develop
Browse files Browse the repository at this point in the history
Reduced focal loss implementation is incorrect for binary case
  • Loading branch information
BloodAxe authored Oct 11, 2024
2 parents c783ce8 + 654e1fd commit 61ad685
Showing 1 changed file with 1 addition and 1 deletion.
2 changes: 1 addition & 1 deletion pytorch_toolbelt/losses/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -70,7 +70,7 @@ def focal_loss_with_logits(
if reduced_threshold is None:
focal_term = (1.0 - pt).pow(gamma)
else:
focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma)
focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides .
focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1)

loss = focal_term * ce_loss
Expand Down

0 comments on commit 61ad685

Please sign in to comment.