From 2c783052ca470dca053f40096e6fd4cfebbb519b Mon Sep 17 00:00:00 2001 From: Prajwal Bende <34344073+MrPrajwalB@users.noreply.github.com> Date: Thu, 10 Oct 2024 01:38:21 +0530 Subject: [PATCH 1/2] Update functional.py Fixed Binary focal loss when reduced threshold is not None. Focal term should be continuous and equal to 1 at pt equal to reduced threshold --- pytorch_toolbelt/losses/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 1ef62f055..34ccdd55e 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -70,7 +70,7 @@ def focal_loss_with_logits( if reduced_threshold is None: focal_term = (1.0 - pt).pow(gamma) else: - focal_term = ((1.0 - pt) / reduced_threshold).pow(gamma) + focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold < 0.5. At pt == reduced_threshold, focal term should be 1 from both sides. focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) loss = focal_term * ce_loss From 654e1fd0153922733374d3c86159944522472787 Mon Sep 17 00:00:00 2001 From: Prajwal Bende <34344073+MrPrajwalB@users.noreply.github.com> Date: Thu, 10 Oct 2024 01:41:36 +0530 Subject: [PATCH 2/2] Update functional.py --- pytorch_toolbelt/losses/functional.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/pytorch_toolbelt/losses/functional.py b/pytorch_toolbelt/losses/functional.py index 34ccdd55e..daae4a052 100644 --- a/pytorch_toolbelt/losses/functional.py +++ b/pytorch_toolbelt/losses/functional.py @@ -70,7 +70,7 @@ def focal_loss_with_logits( if reduced_threshold is None: focal_term = (1.0 - pt).pow(gamma) else: - focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold < 0.5. At pt == reduced_threshold, focal term should be 1 from both sides. + focal_term = ((1.0 - pt) / (1 - reduced_threshold)).pow(gamma) #the focal term continuity breaks when reduced_threshold not equal to 0.5. At pt equal to reduced_threshold, the value of piecewise function of focal term should be 1 from both sides . focal_term = torch.masked_fill(focal_term, pt < reduced_threshold, 1) loss = focal_term * ce_loss