Correct behavior of ignore_index for JaccardLoss #1898
Conversation
Hi, I'm the author of #1891. Thanks for taking the time to fix the issue and, most importantly, for not roasting the horrendous one-line fix I originally proposed! Perhaps we should update the trainer docstring and initialization to reflect the changes?
Haha no roasting from me (don't throw stones in a glass house and all that 😉)! We appreciate you taking the time to open an issue. Good point -- we need to remove the UserWarning that we were previously throwing.
Also, I've noticed how smp converts classes:

```python
if classes is not None:
    assert mode != BINARY_MODE, "Masking classes is not supported with mode=binary"
    classes = to_tensor(classes, dtype=torch.long)
```

where:

```python
def to_tensor(x, dtype=None) -> torch.Tensor:
    if isinstance(x, torch.Tensor):
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, np.ndarray):
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
    if isinstance(x, (list, tuple)):
        x = np.array(x)
        x = torch.from_numpy(x)
        if dtype is not None:
            x = x.type(dtype)
        return x
```

See https://github.com/qubvel/segmentation_models.pytorch/blob/6db76a1106426ac5b55f39fba68168f3bccae7f8/segmentation_models_pytorch/losses/jaccard.py#L45 and https://github.com/qubvel/segmentation_models.pytorch/blob/6db76a1106426ac5b55f39fba68168f3bccae7f8/segmentation_models_pytorch/losses/_functional.py#L18. Since set is not one of the checked instance types, a set passed as classes is never actually converted to a tensor, but I think it's OK in the end.
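To make that concrete, here is a tiny check relying on the to_tensor definition quoted above (the input values are arbitrary):

```python
import numpy as np
import torch

# A set matches none of the isinstance checks (Tensor, ndarray, list, tuple),
# so to_tensor falls off the end and implicitly returns None.
print(to_tensor([0, 2, 3], dtype=torch.long))  # tensor([0, 2, 3])
print(to_tensor({0, 2, 3}, dtype=torch.long))  # None
```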
torchgeo/trainers/segmentation.py
Outdated

```python
@@ -111,8 +104,12 @@ def configure_losses(self) -> None:
            ignore_index=ignore_value, weight=self.hparams["class_weights"]
        )
    elif loss == "jaccard":
        class_set = set(range(self.hparams["num_classes"]))
```
This will just include all classes? I've suggested a more flexible approach at #1896 (comment). I think there is some confusion: class_set is an index of the out_channels/classes, whilst ignore_index is a specific int to ignore (e.g. 0 or 255).
But according to the smp docs, classes should be a list of the integer class labels to include in loss calculations.
In this context, I think using this approach is fine; the ignored index is removed from the set if it is not null, as required.
The docs state: "classes – List of classes that contribute in loss computation. By default, all channels are included." And if we inspect the source code, loss = loss[self.classes] accesses the channels defined by self.classes, if I am not mistaken.

Noting isaaccorley/torchseg#7
loss in the source code is a C-dimensional tensor that gets indexed by whatever you pass into classes. If you have C=4 classes and pass [0, 1, 2, 3] (or None), then this will use all classes; however, if you pass [0, 2, 3], then it will behave as ignore_index=1 would behave in other losses. See the last three cells of https://gist.github.com/calebrob6/658edaa59c68f0c0a510f8d9d7a41458
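A minimal sketch of that behavior (shapes and values invented for illustration):

```python
import torch
import segmentation_models_pytorch as smp

y_pred = torch.randn(2, 4, 8, 8)         # logits for C=4 classes
y_true = torch.randint(0, 4, (2, 8, 8))  # integer class labels

all_classes = smp.losses.JaccardLoss(mode="multiclass", classes=[0, 1, 2, 3])
skip_class1 = smp.losses.JaccardLoss(mode="multiclass", classes=[0, 2, 3])

print(all_classes(y_pred, y_true))  # averages over all 4 class channels
print(skip_class1(y_pred, y_true))  # channel 1 excluded, akin to ignore_index=1
```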
Sets scare me (unsorted, non-deterministic) and the type hints require a list. Let's use:

```python
- class_set = set(range(self.hparams["num_classes"]))
+ classes = list(range(self.hparams["num_classes"]))
```
As commented, I think this approach is incorrect. What is required is for JaccardLoss to correctly implement ignore_index; see how it is done in https://smp.readthedocs.io/en/latest/_modules/segmentation_models_pytorch/losses/focal.html#FocalLoss
torchgeo/trainers/segmentation.py
Outdated

```python
@@ -111,8 +104,12 @@ def configure_losses(self) -> None:
            ignore_index=ignore_value, weight=self.hparams["class_weights"]
        )
    elif loss == "jaccard":
        class_set = set(range(self.hparams["num_classes"]))
        if ignore_index is not None:
            class_set.remove(ignore_index)
```
@robmarkcole -- if the user sets ignore_index, then it is removed from the class_set here and the remaining classes are used in JaccardLoss.
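In isolation, the change amounts to this (toy values, not the trainer's actual variables):

```python
num_classes, ignore_index = 4, 1

class_set = set(range(num_classes))  # {0, 1, 2, 3}
if ignore_index is not None:
    class_set.remove(ignore_index)   # {0, 2, 3} is then passed to JaccardLoss
print(class_set)
```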
Am I missing something?
But ignore_index is a pixel value in the mask (e.g. 255) to ignore, unless I am mistaken?
Ah I think I understand your concern now -- you're saying that the mask won't necessarily have contiguous class values (e.g. NLCD)
Correct, I've seen nodata as -9999, 0, 255 etc
I'm back -- JaccardLoss assumes contiguous class values:

- It sets num_classes to the size of the predicted mask's class dimension -- https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/jaccard.py#L69
- Then it converts y_true to a one-hot encoding with this number of classes -- https://github.com/qubvel/segmentation_models.pytorch/blob/master/segmentation_models_pytorch/losses/jaccard.py#L80

I confirmed this will throw an error if you have a mask with values like 255 but predictions with only 4 channels.
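Roughly the following, assuming smp's current one-hot behavior (the tensor shapes are made up):

```python
import torch
import segmentation_models_pytorch as smp

loss = smp.losses.JaccardLoss(mode="multiclass")

y_pred = torch.randn(1, 4, 8, 8)         # only 4 output channels
y_true = torch.randint(0, 4, (1, 8, 8))
y_true[0, 0, 0] = 255                    # nodata pixel outside [0, 4)

loss(y_pred, y_true)  # raises: class values must be smaller than num_classes
```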
Seems like ignore_index is the way to go, so you would transform the integer mask so that the values you want ignored in the loss map to the ignore value. Too bad JaccardLoss doesn't have this as an input, but I'm currently modifying it in torchseg.

This PR is probably okay as a workaround for now though.
@DimitrisMantas, yep, totally understand. We're now thinking about the case where you have a mask where the values aren't {0, ..., C-1}. This happens often in practice (for example, NLCD land cover data have values 11, 12, 21, ... for whatever reason -- https://www.mrlc.gov/data/legends/national-land-cover-database-class-legend-and-description). This is probably out of scope for this particular PR, as this case isn't supported by smp's JaccardLoss anyway.
If ignore_index is a large value like 255, calling class_set.remove(ignore_index) will raise a KeyError. Overall I think assuming contiguous values will cause some issues, and it would be better to wait for correct ignore_index handling in the loss. In my own code for now:

```python
if ignore_index is not None:  # explicit None check, so an ignore value of 0 is not treated as falsy
    raise Exception("JaccardLoss does not observe ignore_index and one has been configured.")
return smp.losses.JaccardLoss(mode="multiclass")
```
We could use discard instead of remove to handle the case where ignore_index is not included in range(self.hparams["num_classes"]). But I agree that this does not address the underlying issue of non-contiguous class labels...

Another way would be to go with the smp.FocalLoss approach:

```python
loss = 0

# Filter anchors with -1 label from loss computation
if self.ignore_index is not None:
    not_ignored = y_true != self.ignore_index

for cls in range(num_classes):
    cls_y_true = (y_true == cls).long()
    cls_y_pred = y_pred[:, cls, ...]

    if self.ignore_index is not None:
        cls_y_true = cls_y_true[not_ignored]
        cls_y_pred = cls_y_pred[not_ignored]

    loss += self.focal_loss_fn(cls_y_pred, cls_y_true)
```

However, there is a slight difference between this and the proposed approach. Namely, I believe class_set is effectively an index into the output channels, whereas not_ignored masks out the target pixels whose value equals ignore_index. Here is a relevant test.
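For comparison, a hypothetical sketch of that masking pattern applied to a soft multiclass Jaccard loss (the function name and structure are invented here; this is not smp's implementation):

```python
import torch

def jaccard_loss_with_ignore(y_pred, y_true, ignore_index=None, eps=1e-7):
    # y_pred: (N, C, H, W) logits; y_true: (N, H, W) integer labels.
    num_classes = y_pred.shape[1]
    y_pred = y_pred.softmax(dim=1)
    if ignore_index is not None:
        not_ignored = y_true != ignore_index
    loss = 0.0
    for cls in range(num_classes):
        cls_y_true = (y_true == cls).float()
        cls_y_pred = y_pred[:, cls, ...]
        if ignore_index is not None:
            # Drop ignored pixels instead of dropping a whole channel.
            cls_y_true = cls_y_true[not_ignored]
            cls_y_pred = cls_y_pred[not_ignored]
        intersection = (cls_y_pred * cls_y_true).sum()
        union = cls_y_pred.sum() + cls_y_true.sum() - intersection
        loss += 1.0 - (intersection + eps) / (union + eps)
    return loss / num_classes
```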
I think it's just missing a test for when ignore_index is not None.
Okay, this is working up to the point that JaccardLoss supports. With CrossEntropyLoss you can do something like the sketch below, and it will work as expected (although it will error if you don't ignore the index here, and it will error if you have values like 254 and 255, as you can only ignore one of them). This is not supported by JaccardLoss. I would recommend we merge this, then open a new issue that combines some of the points from above and from Robin's issue.
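The referenced snippet did not survive the page export; a minimal sketch of the idea, assuming a nodata value of 255:

```python
import torch
from torch import nn

# CrossEntropyLoss natively supports ignore_index, unlike smp's JaccardLoss.
criterion = nn.CrossEntropyLoss(ignore_index=255)

y_pred = torch.randn(2, 4, 8, 8)         # logits for C=4 classes
y_true = torch.randint(0, 4, (2, 8, 8))
y_true[:, 0, :] = 255                    # nodata pixels are skipped in the loss

print(criterion(y_pred, y_true))
```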
Why not use [...]

This isn't really how softmax works. If you want to be able to predict a class like 255, you have to have 256+ output classes.

I think this is the answer to the contiguous problem. ML simply doesn't support non-contiguous output classes. Our NLCD/CDL datasets already map to contiguous values.
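For instance, a hypothetical remap from non-contiguous NLCD-style labels to contiguous ones (the class subset is illustrative only):

```python
import torch

nlcd_to_contiguous = {11: 0, 12: 1, 21: 2, 22: 3}  # illustrative subset

mask = torch.tensor([[11, 21], [12, 22]])
remapped = torch.full_like(mask, -1)                # -1 marks unmapped values
for raw, idx in nlcd_to_contiguous.items():
    remapped[mask == raw] = idx
print(remapped)  # tensor([[0, 2], [1, 3]])
```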
To do this with CrossEntropyLoss you would have to set the class weights of the ignored classes to 0 (https://pytorch.org/docs/stable/generated/torch.nn.CrossEntropyLoss.html). This seems excessive.
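A sketch of that workaround (the class index to ignore is invented):

```python
import torch
from torch import nn

num_classes = 4
weight = torch.ones(num_classes)
weight[1] = 0.0  # zero weight effectively excludes class 1 from the loss

# With reduction="mean", the loss is normalized by the summed weights of the
# targets, so class-1 pixels contribute nothing to numerator or denominator.
criterion = nn.CrossEntropyLoss(weight=weight)
```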
Unless we use https://smp.readthedocs.io/en/latest/losses.html#softcrossentropyloss instead. EDIT: Nvm, that doesn't support class weights.
Would it perhaps be an idea to create our own versions of these losses in the future using torchmetrics?
Note that losses need to be differentiable, and not all torchmetrics are. When I asked the maintainers about this specifically, they did not recommend it.
Wait, why aren't we using torchmetrics? Do they not support semantic segmentation? EDIT: Ah, because these are losses, not metrics.
I'll be updating the losses in torchseg to support both ignore_index and classes. @adamjstewart You may want to take a look at the library so we can potentially transition over to it sooner rather than later.
Wondering if we should backport this to 0.5.2. Although we were passing an int instead of list[int], it seems like it was still working somehow? So it's just a new feature, not a bug fix?
Fixes #1896
Fixes #1891