
Correct behavior of ignore_index for JaccardLoss #1898

Merged
merged 11 commits into main from ignore_index_fix on Feb 25, 2024

Conversation

@calebrob6 (Member) commented Feb 22, 2024

Fixes #1896
Fixes #1891

@github-actions bot added the trainers (PyTorch Lightning trainers) label on Feb 22, 2024
@DimitrisMantas (Contributor) commented Feb 22, 2024

Hi, I'm the author of #1891.

Thanks for taking the time to fix the issue, and, most importantly, not roast the horrendous one-line fix I originally proposed!

Perhaps we should update the trainer docstring and initialization to reflect the changes?

@github-actions bot added the testing (Continuous integration testing) label on Feb 22, 2024
@calebrob6 (Member Author)

Haha no roasting from me (don't throw stones in a glass house and all that 😉)! We appreciate you taking the time to open an issue.

Good point -- we need to remove the UserWarning that we were previously throwing.

@DimitrisMantas (Contributor)

Also, I've noticed that smp converts classes to a Tensor like so:

if classes is not None:
    assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary" 
    classes = to_tensor(classes, dtype=torch.long)

where:

import numpy as np
import torch

def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x

See https://github.com/qubvel/segmentation_models.pytorch/blob/6db76a1106426ac5b55f39fba68168f3bccae7f8/segmentation_models_pytorch/losses/jaccard.py#L45 and https://github.com/qubvel/segmentation_models.pytorch/blob/6db76a1106426ac5b55f39fba68168f3bccae7f8/segmentation_models_pytorch/losses/_functional.py#L18.

Since a set is not one of the checked types, x does not get converted to the format smp expects.

I think it's OK in the end since classes ends up being used to index a loss array, but maybe we should still be good citizens and convert the set to a tuple?

@isaaccorley previously approved these changes Feb 22, 2024
@@ -111,8 +104,12 @@ def configure_losses(self) -> None:
                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
        elif loss == "jaccard":
            class_set = set(range(self.hparams["num_classes"]))
@robmarkcole (Contributor) commented Feb 22, 2024

This will just include all classes? I've suggested a more flexible approach at #1896 (comment)

I think there is some confusion, as class_set is an index of the out_channels/classes, whilst ignore_index is a specific int to ignore (e.g. 0 or 255).

@DimitrisMantas (Contributor) commented Feb 22, 2024

But according to the smp docs, classes should be a list of the integer class labels to include in loss calculations.

In this context, I think using this approach is fine; the ignored index is removed from the set if it is not null, as required.

@robmarkcole (Contributor) commented Feb 22, 2024

The docs state: classes – List of classes that contribute in loss computation. By default, all channels are included.
And if we inspect the source code, loss = loss[self.classes] accesses the channels defined by self.classes, if I am not mistaken.

Noting isaaccorley/torchseg#7

@calebrob6 (Member Author) commented Feb 22, 2024

loss in the source code is a C-dimensional tensor that gets indexed by whatever you pass into classes. If you have C=4 classes and pass [0, 1, 2, 3] (or None), then this will use all classes; however, if you pass [0, 2, 3], then it will behave as ignore_index=1 would in other losses. See the last three cells of https://gist.github.com/calebrob6/658edaa59c68f0c0a510f8d9d7a41458
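
For illustration, a minimal sketch of the behavior described above (assuming segmentation_models_pytorch is installed; shapes and values are arbitrary):

import torch
import segmentation_models_pytorch as smp

y_pred = torch.randn(2, 4, 8, 8)          # logits for C=4 classes
y_true = torch.randint(0, 4, (2, 8, 8))   # targets in {0, ..., 3}

loss_all = smp.losses.JaccardLoss(mode="multiclass", classes=[0, 1, 2, 3])
loss_no_1 = smp.losses.JaccardLoss(mode="multiclass", classes=[0, 2, 3])

print(loss_all(y_pred, y_true))   # averages over all four class channels
print(loss_no_1(y_pred, y_true))  # averages over channels 0, 2, and 3 only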

Collaborator

Sets scare me (unsorted, non-deterministic) and the type hints require a list. Let's use:

Suggested change:
- class_set = set(range(self.hparams["num_classes"]))
+ classes = list(range(self.hparams["num_classes"]))

@robmarkcole (Contributor)

As commented, I think this approach is incorrect. What is required is for JaccardLoss to correctly implement ignore_index; see how it is done in https://smp.readthedocs.io/en/latest/_modules/segmentation_models_pytorch/losses/focal.html#FocalLoss

@@ -111,8 +104,12 @@ def configure_losses(self) -> None:
                ignore_index=ignore_value, weight=self.hparams["class_weights"]
            )
        elif loss == "jaccard":
            class_set = set(range(self.hparams["num_classes"]))
            if ignore_index is not None:
                class_set.remove(ignore_index)
@calebrob6 (Member Author)

@robmarkcole -- if the user sets ignore_index then it is removed from the class_set here and the remaining classes are used in JaccardLoss.

@calebrob6 (Member Author)

Am I missing something?

Contributor

But ignore_index is a pixel value in the mask (e.g. 255) to ignore, unless I am mistaken?

@calebrob6 (Member Author)

Ah, I think I understand your concern now -- you're saying that the mask won't necessarily have contiguous class values (e.g. NLCD).

Contributor

Correct, I've seen nodata as -9999, 0, 255, etc.

@calebrob6 (Member Author)

I'm back -- JaccardLoss assumes contiguous class values:

  • It sets num_classes to the size of the predicted mask's class dimension -- https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/jaccard.py#L69
  • It then converts y_true to a one-hot encoding with this number of classes -- https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/jaccard.py#L80

I confirmed this will throw an error if you have a mask with values like 255 but predictions with only 4 channels.
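
A small repro sketch of that failure mode (assuming segmentation_models_pytorch; shapes are arbitrary):

import torch
import segmentation_models_pytorch as smp

y_pred = torch.randn(1, 4, 8, 8)           # predictions with only 4 channels
y_true = torch.randint(0, 4, (1, 8, 8))
y_true[0, 0, 0] = 255                      # out-of-range "nodata" value in the mask

criterion = smp.losses.JaccardLoss(mode="multiclass")
criterion(y_pred, y_true)                  # errors: the one-hot target is built with num_classes=4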

Collaborator

Seems like ignore_index is the way to go, so you would mark some values in the integer mask that you want to ignore in the loss. Too bad JaccardLoss doesn't have this as an input, but I'm currently modifying it in torchseg.

This PR is probably okay as a workaround for now, though.

@calebrob6 (Member Author)

@DimitrisMantas, yep, totally understand. We're now thinking about the case where you have a mask where the values aren't {0, ..., C-1}. This happens often in practice (for example, NLCD land cover data have values 11, 12, 21, ... for whatever reason -- https://www.mrlc.gov/data/legends/national-land-cover-database-class-legend-and-description). This is probably out of scope for this particular PR, as this case isn't supported by smp's JaccardLoss anyway.

Contributor

If ignore_index is a large value like 255, calling class_set.remove(ignore_index) will raise a KeyError. Overall, I think assuming contiguous values will cause some issues, and it would be better to wait for correct ignore_index handling in the loss. In my own code for now:

if ignore_index:
    raise Exception(
        "JaccardLoss does not observe ignore_index and one has been configured."
    )
return smp.losses.JaccardLoss(mode="multiclass")

Contributor

We could use discard instead of remove to handle the case where ignore_index is not included in range(self.hparams["num_classes"]). But I agree that this does not address the underlying issue of non-contiguous class labels...

Another way would be to go with the smp.FocalLoss approach:

loss = 0

# Filter anchors with -1 label from loss computation
if self.ignore_index is not None:
    not_ignored = y_true != self.ignore_index

for cls in range(num_classes):
    cls_y_true = (y_true == cls).long()
    cls_y_pred = y_pred[:, cls, ...]

    if self.ignore_index is not None:
        cls_y_true = cls_y_true[not_ignored]
        cls_y_pred = cls_y_pred[not_ignored]

    loss += self.focal_loss_fn(cls_y_pred, cls_y_true)

However, there is a slight difference between this and the proposed approach. Namely, I believe class_set is effectively an index to the output channels, whereas not_ignored specifies certain target classes to ignore. Here is a relevant test.
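
To make the channel-index versus pixel-mask distinction concrete, here is a rough, self-contained sketch of an ignore_index-style soft Jaccard computation. This is not smp's implementation; jaccard_with_ignore and everything inside it are hypothetical, for illustration only:

import torch

def jaccard_with_ignore(y_pred, y_true, num_classes, ignore_index=None, eps=1e-7):
    # Soft Jaccard loss that drops pixels equal to ignore_index (illustration only).
    probs = torch.softmax(y_pred, dim=1)
    valid = torch.ones_like(y_true, dtype=torch.bool)
    if ignore_index is not None:
        valid = y_true != ignore_index            # pixel-level mask, like not_ignored above
    losses = []
    for cls in range(num_classes):
        p = probs[:, cls, ...][valid]             # keep only valid pixels of this channel
        t = (y_true == cls).float()[valid]
        intersection = (p * t).sum()
        union = p.sum() + t.sum() - intersection
        losses.append(1 - (intersection + eps) / (union + eps))
    return torch.stack(losses).mean()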

@isaaccorley (Collaborator)

I think it's just missing a test when ignore_index is not None

@calebrob6 (Member Author)

Okay, this is working up to the point that JaccardLoss supports it.

With CrossEntropyLoss you can do something like this:

import torch
import torch.nn as nn

y_pred = torch.rand(1, 4, 256, 256)
y_pred = nn.functional.softmax(y_pred, dim=1)

y_true = torch.randint(0, 4, size=(1, 256, 256))
y_true[0, 0, 0] = 255

loss = nn.CrossEntropyLoss(ignore_index=255)
loss(y_pred, y_true)

and it will work as expected (although it will error if you don't set ignore_index here, and it will also error if you have values like 254 and 255, since you can only ignore one of them).

This is not supported by JaccardLoss as it assumes the class values range from 0 to C-1 where C is the number of classes.

I would recommend we merge this, then open a new issue that combines some of the points from above and from Robin's issue:

  • Users might want to ignore multiple classes
  • Users might want to ignore a single non-contiguous class value with JaccardLoss
  • Users might want to use masks that have non-contiguous class values (although I think just remapping to [0, C-1] should be on the users -- see the sketch below)
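
For reference, a rough sketch of the kind of remapping we'd expect users to do themselves (the class values here are made up for illustration):

import torch

class_values = [11, 12, 21, 255]                      # non-contiguous values in the raw mask
lookup = {v: i for i, v in enumerate(class_values)}   # {11: 0, 12: 1, 21: 2, 255: 3}

mask = torch.tensor([[11, 21], [255, 12]])
remapped = mask.clone()
for v, i in lookup.items():
    remapped[mask == v] = i                           # values now lie in [0, C-1]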

@adamjstewart (Collaborator)

Users might want to ignore multiple classes

Why not use classes instead of ignore_index for all loss functions? Seems like they all support it. This would offer maximum flexibility.

Users might want to ignore a single non-contiguous class value with JaccardLoss
Users might want to use masks that have non-contiguous class values

This isn't really how softmax works. If you want to be able to predict a class like 255, you have to have 256+ output classes.

(although I think just remapping to [0, C-1] should be on the users)

I think this is the answer to the contiguous problem. ML simply doesn't support non-contiguous output classes. Our NLCD/CDL datasets already map to contiguous values.

@calebrob6 (Member Author)

To do this with CrossEntropyLoss, you would have to set the class weights of the ignored classes to 0 (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). This seems excessive.
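
For illustration, a quick sketch of that workaround (the classes chosen to ignore here are arbitrary):

import torch
import torch.nn as nn

num_classes = 4
ignored = [1, 3]                          # hypothetical classes to drop from the loss
weight = torch.ones(num_classes)
weight[ignored] = 0.0                     # zero weight removes their contribution

criterion = nn.CrossEntropyLoss(weight=weight)
y_pred = torch.randn(1, num_classes, 8, 8)
y_true = torch.randint(0, num_classes, (1, 8, 8))
criterion(y_pred, y_true)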

@adamjstewart (Collaborator) commented Feb 25, 2024

Unless we use https://smp.readthedocs.io/en/latest/losses.html#softcrossentropyloss instead.

EDIT: Nvm, that doesn't support classes or ignore_index...

@DimitrisMantas (Contributor) commented Feb 25, 2024

Would it perhaps be an idea to create our own versions of these losses in the future using torchmetrics, which is actively maintained and tested?

@robmarkcole (Contributor)

Note that losses need to be differentiable, and not all torchmetrics are. When I asked the maintainers about this specifically, they did not recommend it.

@adamjstewart (Collaborator) commented Feb 25, 2024

Wait, why aren't we using torchmetrics? Do they not support semantic segmentation?

EDIT: Ah, because these are losses, not metrics.

@isaaccorley (Collaborator) commented Feb 25, 2024

I'll be updating the losses in torchseg to support both ignore_index and classes. @adamjstewart You may want to take a look at the library so we can potentially transition over to it sooner rather than later.

@adamjstewart (Collaborator) left a comment

Wondering if we should backport this to 0.5.2. Although we were passing an int instead of list[int], it seems like it was still working somehow? So it's just a new feature, not a bug fix?

@calebrob6 merged commit c828a7d into main on Feb 25, 2024
24 checks passed
@calebrob6 deleted the ignore_index_fix branch on February 25, 2024 at 22:14
@adamjstewart added this to the 0.6.0 milestone on Feb 29, 2024
Labels: testing (Continuous integration testing), trainers (PyTorch Lightning trainers)
Projects: None yet
5 participants