I am computing Accuracy and F1Score for a multilabel classification problem. If my target tensors are ints, I get an error from torch:
File "/path/venv/lib/python3.9/site-packages/torch/nn/functional.py", line 3150, in binary_cross_entropy_with_logits
return torch.binary_cross_entropy_with_logits(input, target, weight, pos_weight, reduction_enum)
RuntimeError: result type Float can't be cast to the desired output type Long
If my target tensors are floats, I get a torchmetrics error:
File "/path/venv/lib/python3.9/site-packages/torchmetrics/utilities/checks.py", line 47, in _basic_input_validation
raise ValueError("The `target` has to be an integer tensor.")
ValueError: The `target` has to be an integer tensor.
I think this is a bug, and torchmetrics should not check that targets are ints.
To Reproduce
You can use a `transformers` sequence classification model set for multilabel classification; `roberta-large`, for example, uses `BCEWithLogitsLoss` as the loss function, so it will surface this.
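A minimal sketch of the clash using plain torch (the tensor shapes and values here are illustrative, not taken from the original report):

```python
import torch
import torch.nn.functional as F

# Illustrative multilabel batch: 2 samples, 3 labels.
logits = torch.randn(2, 3)
int_targets = torch.tensor([[1, 0, 1], [0, 1, 0]])  # dtype torch.int64 (Long)

# binary_cross_entropy_with_logits needs float targets, so integer
# targets raise the RuntimeError quoted above.
try:
    F.binary_cross_entropy_with_logits(logits, int_targets)
    loss_failed = False
except RuntimeError:
    loss_failed = True

# Casting to float satisfies the loss...
loss = F.binary_cross_entropy_with_logits(logits, int_targets.float())

# ...but torchmetrics' input validation then rejects the float tensor,
# since it requires integer {0, 1} targets, producing the ValueError above.
print(loss_failed)      # True
print(loss.item() > 0)  # True
```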
Hi @samhavens, thanks for reporting this issue.
In this case I would not really call it a bug but rather a design decision. The core difference between torch.binary_cross_entropy_with_logits and a metric like F1Score is that the underlying calculation in torch.binary_cross_entropy_with_logits supports target values anywhere in the [0, 1] interval, whereas F1Score needs target to be in {0, 1}; otherwise the calculation will produce errors. Essentially, this discussion boils down to whether the statement
target = target.long()
should be something the user does or something that we should do internally. I think there are pros and cons to both approaches.
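In user code today, the cast amounts to a two-dtype pattern: float targets feed the loss and a Long view feeds the metric. A sketch with plain torch (the metric update itself is omitted; values are illustrative):

```python
import torch
import torch.nn.functional as F

logits = torch.randn(2, 3)
targets = torch.tensor([[1.0, 0.0, 1.0], [0.0, 1.0, 0.0]])  # float for the loss

# Float targets in {0, 1} are valid for BCE-with-logits.
loss = F.binary_cross_entropy_with_logits(logits, targets)

# A Long view of the same tensor is what torchmetrics' validation
# expects for classification metrics such as F1Score.
metric_targets = targets.long()
print(metric_targets.dtype)  # torch.int64
```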
It is already in the documentation (see the examples there), but maybe we can make it more explicit. I will add it as a todo for the classification metrics in the ongoing update of the documentation (issue #1365).
Closing this issue then.
Expected behavior

torchmetrics and torch agree on types.

Environment

- TorchMetrics version (installed via `conda`, `pip`, or built from source): 0.9.3, pip installed