
Errors in mask2former caused by class_weight #3811

Open
zhuxuanya opened this issue Nov 11, 2024 · 0 comments
Description

While training Mask2Former on a 3-class custom dataset, I changed the original class weight

        loss_cls=dict(
            type='mmdet.CrossEntropyLoss',
            use_sigmoid=False,
            loss_weight=0.5,
            reduction='mean',
            class_weight=[1.0] * num_classes + [0.1]),

to an explicit value, class_weight=[0.1, 1.0, 1.0], to deal with the class imbalance. Training then failed with the following error.

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [1,0,0], thread: [96,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
(similar report)
......
Traceback (most recent call last):
  File "tools/train.py", line 107, in <module>
    main()
  File "tools/train.py", line 103, in main
    runner.train()
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmengine/runner/runner.py", line 1777, in train
    model = self.train_loop.run()  # type: ignore
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmengine/runner/loops.py", line 287, in run
    self.run_iter(data_batch)
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmengine/runner/loops.py", line 311, in run_iter
    outputs = self.runner.model.train_step(
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 114, in train_step
    losses = self._run_forward(data, mode='loss')  # type: ignore
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmengine/model/base_model/base_model.py", line 361, in _run_forward
    results = self(**data, mode=mode)
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/mmsegmentation/mmseg/models/segmentors/base.py", line 94, in forward
    return self.loss(inputs, data_samples)
  File "/home/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 178, in loss
    loss_decode = self._decode_head_forward_train(x, data_samples)
  File "/home/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 139, in _decode_head_forward_train
    loss_decode = self.decode_head.loss(inputs, data_samples,
  File "/home/mmsegmentation/mmseg/models/decode_heads/mask2former_head.py", line 126, in loss
    losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/dense_heads/maskformer_head.py", line 348, in loss_by_feat
    losses_cls, losses_mask, losses_dice = multi_apply(
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/utils/misc.py", line 219, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/dense_heads/mask2former_head.py", line 296, in _loss_by_feat_single
    loss_cls = self.loss_cls(
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1501, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/losses/cross_entropy_loss.py", line 288, in forward
    class_weight = cls_score.new_tensor(
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Error traceback

After re-running with export CUDA_LAUNCH_BLOCKING=1, I got a more precise traceback.

  File "/home/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 178, in loss
    loss_decode = self._decode_head_forward_train(x, data_samples)
  File "/home/mmsegmentation/mmseg/models/segmentors/encoder_decoder.py", line 139, in _decode_head_forward_train
    loss_decode = self.decode_head.loss(inputs, data_samples,
  File "/home/mmsegmentation/mmseg/models/decode_heads/mask2former_head.py", line 126, in loss
    losses = self.loss_by_feat(all_cls_scores, all_mask_preds,
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/dense_heads/maskformer_head.py", line 348, in loss_by_feat
    losses_cls, losses_mask, losses_dice = multi_apply(
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/utils/misc.py", line 219, in multi_apply
    return tuple(map(list, zip(*map_results)))
  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/dense_heads/mask2former_head.py", line 300, in _loss_by_feat_single
    avg_factor=class_weight[labels].sum())
RuntimeError: CUDA error: device-side assert triggered
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.

Then I inspected this line.

  File "/home/conda/envs/mmseg/lib/python3.8/site-packages/mmdet/models/dense_heads/mask2former_head.py", line 300, in _loss_by_feat_single
    avg_factor=class_weight[labels].sum())

I added some print statements around the call to inspect the values.

        # classification loss
        # shape (batch_size * num_queries, )
        cls_scores = cls_scores.flatten(0, 1)
        labels = labels.flatten(0, 1)
        label_weights = label_weights.flatten(0, 1)

        class_weight = cls_scores.new_tensor(self.class_weight)
        print(f"labels: {labels}")
        max_label = labels.max()
        print(f"max_label: {max_label}")
        print(f"class_weight: {class_weight}")
        print(f"len(class_weight): {len(class_weight)}")
        loss_cls = self.loss_cls(
            cls_scores,
            labels,
            label_weights,
            avg_factor=class_weight[labels].sum())

        num_total_masks = reduce_mean(cls_scores.new_tensor([avg_factor]))
        num_total_masks = max(num_total_masks, 1)

The output was as follows.

labels: tensor([3, 3, 3,  ..., 3, 3, 3], device='cuda:0')
max_label: 3
class_weight: tensor([0.1000, 1.0000, 1.0000], device='cuda:0')
len(class_weight): 3

The class_weight tensor has length 3, but labels contains the value 3, so class_weight[labels] tries to access class_weight[3], which causes an index out-of-bounds error.
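The same out-of-bounds lookup can be reproduced without CUDA; here is a minimal sketch in plain Python (my own toy example, not mmdet code) of what the indexing does:

```python
# Toy reproduction of the indexing bug (plain Python, not mmdet code).
# With 3 classes, unmatched queries get label 3, but a 3-element
# class_weight list only has valid indices 0..2.
class_weight = [0.1, 1.0, 1.0]   # len == num_classes, missing the extra entry
labels = [3, 3, 3]               # labels observed in the print output above

try:
    avg_factor = sum(class_weight[l] for l in labels)
except IndexError:
    print("IndexError: class_weight[3] is out of bounds")
```

On CPU this surfaces as a plain IndexError; on CUDA the same condition trips the device-side assert shown in the traceback.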

Possible reasons

Thinking about it in terms of the number of classes, I noticed that some datasets, such as ./mmseg/datasets/cityscapes.py, do not list background as a category. When a dataset's categories don't contain background, the default config implicitly appends a background weight of 0.1 at the end of the list (the trailing [0.1] in [1.0] * num_classes + [0.1]), and the loss code makes no distinction between datasets that include background and those that don't. So even though my dataset's categories already contain background, the list still needs one extra item.
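For comparison, the stock config expression shown at the top already produces that extra entry; evaluating it for a 3-class dataset shows the length my hand-written list was missing:

```python
# The default config builds the weight list with one trailing entry beyond
# num_classes; that trailing 0.1 is what my hand-written list lacked.
num_classes = 3
class_weight = [1.0] * num_classes + [0.1]
print(len(class_weight))  # 4, i.e. num_classes + 1
```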

Bug fix

If you just want to get training running despite the class_weight issue, make sure the list has one more item than num_classes. I haven't looked deeply into which pixel this extra item corresponds to; intuitively it could be the ignore label 255. Of course, my guess about the cause of this error may simply be wrong. Feel free to correct me.
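Concretely, the workaround looks like this in the config (hedged: the weight values here are my choice; only the list length matters for avoiding the assert):

```python
# Assumed fix: append one extra weight so len(class_weight) == num_classes + 1.
# The last entry plays the role the default config gives the implicit
# background / trailing 0.1 weight.
num_classes = 3
class_weight = [0.1, 1.0, 1.0] + [0.1]   # per-class weights + one extra entry
assert len(class_weight) == num_classes + 1

loss_cls = dict(
    type='mmdet.CrossEntropyLoss',
    use_sigmoid=False,
    loss_weight=0.5,
    reduction='mean',
    class_weight=class_weight)
```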
