Dice Loss resulting in unexpected logit outputs #900

Open
@trchudley

Hi, thanks for such a great package!

I'm currently training a simple U-net on a binary image classification problem (in this case, identifying water in satellite imagery). I am exploring different segmentation_models_pytorch loss functions to train the model. I have been using a focal loss function to initially build and test the model:

loss_function = smp.losses.FocalLoss(mode="binary", alpha=alpha, gamma=gamma)

Once trained, this model produces relatively sensible output, both as raw logit values and as probabilities after a sigmoid activation is applied:

[Image: focalloss]

(NB: this is just a test run, the final accuracy doesn't really matter at this point)

I would now like to explore switching to Dice loss. My understanding is that, using the from_logits argument, I could simply swap the FocalLoss class for DiceLoss as follows:

loss_function = smp.losses.DiceLoss(mode="binary", from_logits=True)
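For reference, this is how I understand a soft Dice loss on logits to work in the binary case. It is a minimal sketch in plain Python, not the library's actual code; the helper names `sigmoid` and `soft_dice_loss` and the `eps` smoothing term are my own assumptions for illustration.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


def soft_dice_loss(logits, targets, eps=1e-7):
    """Soft Dice loss for one flattened binary mask.

    logits  : raw model outputs, one per pixel
    targets : ground-truth labels (0 or 1), one per pixel
    eps     : small constant to avoid division by zero (assumed value)
    """
    probs = [sigmoid(z) for z in logits]
    intersection = sum(p * y for p, y in zip(probs, targets))
    cardinality = sum(probs) + sum(targets)
    dice = (2.0 * intersection) / (cardinality + eps)
    return 1.0 - dice


# Confident, correctly signed logits give a loss close to zero.
print(soft_dice_loss([20.0, 20.0, -20.0, -20.0], [1, 1, 0, 0]))
```

In this formulation the loss sees the logits only through the sigmoid, so it is computed entirely from values squashed into the range zero to one.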

Training the model using this DiceLoss class results in the following when applied to the same image:

[Image: diceloss]

Looking at the logit output, this is great: the new Dice-loss-trained model appears to perform qualitatively even better than the focal-loss-trained model! However, the raw outputs are no longer scaled around zero. Instead, they are all positive, ranging from roughly 400 to 9000 (depending on which image the model is applied to). As a result, applying a sigmoid activation does not produce a usable probability map between zero and one; because every logit is large and positive, the apparent probabilities are all 1.
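To illustrate the saturation, here is a minimal sketch in plain Python. The logit values are illustrative ones taken from the range I describe above, and `sigmoid` is a hand-written helper rather than anything from the library.

```python
import math


def sigmoid(x: float) -> float:
    """Numerically stable logistic function."""
    if x >= 0:
        return 1.0 / (1.0 + math.exp(-x))
    e = math.exp(x)
    return e / (1.0 + e)


# Illustrative logits spanning the range seen in the Dice-trained output.
example_logits = [400.0, 2500.0, 9000.0]
probabilities = [sigmoid(z) for z in example_logits]

# In double precision every one of these rounds to exactly 1.0,
# so the relative ordering of the logits is lost after the sigmoid.
print(probabilities)
```

The logits still differ by orders of magnitude, which is why the raw output looks sensible, but that contrast disappears once the sigmoid is applied.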

I've examined the source code and I can't see anything that would cause such a difference. Am I missing something that prevents DiceLoss from being a drop-in replacement for FocalLoss when I want probabilistic model predictions?
