Target tensor check for int in multilabel classification does not agree with torch #1367

Closed
samhavens opened this issue Dec 1, 2022 · 5 comments
Labels: bug / fix (Something isn't working), help wanted (Extra attention is needed), v0.9.x

Comments

@samhavens

🐛 Bug

I am computing Accuracy and F1Score for a multilabel classification problem. If my target tensors are ints, I get an error from torch:

  File "/path/venv/lib/python3.9/site-packages/torch/nn/functional.py", line 3150, in binary_cross_entropy_with_logits
    return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: result type Float can't be cast to the desired output type Long

If my target tensors are floats, I get a torchmetrics error:

File "/path/venv/lib/python3.9/site-packages/torchmetrics/utilities/checks.py", line 47, in _basic_input_validation
    raise ValueError("The `target` has to be an integer tensor.")
ValueError: The `target` has to be an integer tensor.

I think this is a bug, and torchmetrics should not check that targets are ints.

To Reproduce

You can use a transformers sequence classification model set up for multilabel classification:

    from transformers import AutoConfig, AutoModelForSequenceClassification

    hf_config = AutoConfig.from_pretrained(
        model_name,
        num_labels=[fill in],
        problem_type="multi_label_classification",
    )
    model = AutoModelForSequenceClassification.from_pretrained(
        model_name,
        config=hf_config,
    )

roberta-large, for example, uses BCEWithLogitsLoss as the loss function for this problem type, so it will surface this.
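The same mismatch can be reproduced without transformers. A minimal sketch (the default F1Score construction here is illustrative for the 0.9.x API; the failing calls are left commented out so the snippet runs):

    import torch
    import torch.nn.functional as F
    from torchmetrics import F1Score

    logits = torch.randn(4, 3)                # (batch, num_labels)
    target = torch.randint(0, 2, (4, 3))      # integer {0, 1} multilabel targets

    # torch expects float targets; passing the Long tensor raises
    # "RuntimeError: result type Float can't be cast to the desired output type Long"
    # F.binary_cross_entropy_with_logits(logits, target)
    F.binary_cross_entropy_with_logits(logits, target.float())  # ok

    # torchmetrics 0.9.x expects integer targets; passing a float tensor raises
    # "ValueError: The `target` has to be an integer tensor."
    metric = F1Score()  # micro-averaged over all label positions by default
    # metric(torch.sigmoid(logits), target.float())
    metric(torch.sigmoid(logits), target)  # ok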

Expected behavior

torchmetrics and torch should agree on the expected dtype of the target tensor.

Environment

  • TorchMetrics version (and how you installed TM, e.g. conda, pip, build from source): 0.9.3, pip installed
  • Python & PyTorch Version (e.g., 1.0): torch==1.12.1+cu116, Python 3.9.15
  • Any other relevant information such as OS (e.g., Linux): Linux
@samhavens added the bug / fix (Something isn't working) and help wanted (Extra attention is needed) labels Dec 1, 2022
github-actions bot commented Dec 1, 2022

Hi! thanks for your contribution!, great first issue!

@Borda (Member) commented Dec 1, 2022

@stancld mind having a look at it? 🦦

@SkafteNicki (Member)

Hi @samhavens, thanks for reporting this issue.
In this case I would not really call it a bug but more an design decision. The core difference between torch.binary_cross_entropy_with_logits and a metric like F1Score is that the underlying calculation in torch.binary_cross_entropy_with_logits supports values in the [0,1] domain for target and F1Score needs target to be {0,1}, else the calculation will lead to errors. Essentially, this discussion then boils down to if the statement

target = target.long()

should be something the user does or something that we should do internally. I think there are pros and cons to both.
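For anyone hitting this in the meantime, a possible user-side workaround is to keep the targets as integers for the metric and cast to float only for the loss. A sketch, not an official recommendation (the F1Score construction is again illustrative):

    import torch
    import torch.nn.functional as F
    from torchmetrics import F1Score

    metric = F1Score()                        # illustrative construction

    logits = torch.randn(8, 5)                # model outputs, (batch, num_labels)
    target = torch.randint(0, 2, (8, 5))      # integer multilabel targets

    loss = F.binary_cross_entropy_with_logits(logits, target.float())  # loss wants float
    score = metric(torch.sigmoid(logits), target.long())               # metric wants int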

@samhavens (Author)

Totally reasonable! Might be worth calling out in the documentation though?

@SkafteNicki (Member)

It is already in the documentation, for example:

[screenshot of a documentation example omitted]

but maybe we can make it more explicit. I will add it as a todo for the classification metrics in the ongoing update of the documentation (issue #1365).
Closing this issue then.
