Large changes to classifications #1248
-
TorchMetrics v0.10 is now out, significantly changing the whole classification package. This blog post goes over the reasons why the classification package needed to be refactored, what it means for our end users, and finally, what benefits it brings. A guide on how to upgrade your code to the recent changes can be found near the bottom.

Why the classification metrics need to change

We have known for a long time that there were some underlying problems with how we initially structured the classification package. Essentially, classification tasks can be divided into binary, multiclass, or multilabel, and determining which task a user is trying to run a given metric on is hard based on the input alone. The reason a package such as sklearn can do this is that it only supports input in very specific formats (no multi-dimensional arrays and no support for both integer and probability/logit formats). This meant that some metrics, especially for binary tasks, could have been calculating something different than expected if the user provided input of a shape other than the expected one. This goes against a core value of TorchMetrics: our users should be able to trust that the metric they are evaluating returns the expected result. Additionally, the classification metrics lacked consistency in their arguments and defaults.

The solution

The solution we went with was to split every classification metric into three separate metrics, one for each of the binary, multiclass, and multilabel tasks, as sketched below.
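As a minimal sketch of what the split looks like in practice (assuming torchmetrics >= 0.10; the class names below are from that release), each task gets its own metric class and is never inferred from the input shape:

```python
import torch
from torchmetrics.classification import BinaryAccuracy, MulticlassAccuracy, MultilabelAccuracy

# Binary: preds are probabilities/logits for the positive class
binary_acc = BinaryAccuracy()
binary_acc(torch.rand(10), torch.randint(2, (10,)))

# Multiclass: the number of classes is always stated explicitly
multiclass_acc = MulticlassAccuracy(num_classes=5)
multiclass_acc(torch.randn(10, 5), torch.randint(5, (10,)))

# Multilabel: the number of labels is always stated explicitly
multilabel_acc = MultilabelAccuracy(num_labels=3)
multilabel_acc(torch.rand(10, 3), torch.randint(2, (10, 3)))
```

The functional versions follow the same naming pattern, e.g. binary_accuracy, multiclass_accuracy, and multilabel_accuracy.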
Standardized arguments

The input arguments for the classification package are now much more standardized. A few examples are sketched below.
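For illustration, here is a small sketch (assuming torchmetrics >= 0.10; the argument values are arbitrary) of how metrics belonging to the same task now accept the same keyword arguments:

```python
from torchmetrics.classification import MulticlassPrecision, MulticlassRecall, MulticlassF1Score

# The same keyword arguments are accepted across metrics of the same task
common_args = dict(num_classes=5, average="macro", ignore_index=-1, validate_args=True)

precision = MulticlassPrecision(**common_args)
recall = MulticlassRecall(**common_args)
f1 = MulticlassF1Score(**common_args)
```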
Constant-memory implementations

Some of the most useful metrics for evaluating classification problems are ROC, AUROC, AveragePrecision, etc., because they evaluate your model not at a single threshold but over a whole range of thresholds, essentially letting you see the trade-off between Type I and Type II errors. However, a big problem with the standard formulation of these metrics (which we had been using) is that they require access to all data for their calculation, so our implementations were extremely memory-intensive for these kinds of metrics. In v0.10 of TorchMetrics, all these metrics now have an argument called thresholds. By default it is None, which computes the exact, non-binned result at the cost of memory that grows with the amount of data. Setting thresholds to an integer (or an explicit list of threshold values) instead evaluates the metric on that fixed grid of thresholds, giving an approximate result with constant memory usage; a sketch follows this section. This also means that the dedicated Binned* versions of these metrics are superseded by this argument.

All metrics are faster (ish)

By splitting each metric into three separate metrics, we reduce the number of calculations needed. We therefore expected out of the box that our new implementations would be faster. Benchmarking different metrics with the old and new implementations (with and without input validation) shows that the new implementations are generally faster.
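The following sketch illustrates the thresholds argument discussed above (assuming torchmetrics >= 0.10; exact defaults may differ slightly between versions):

```python
import torch
from torchmetrics.classification import BinaryAUROC

preds = torch.rand(1_000)
target = torch.randint(2, (1_000,))

# thresholds=None (the default): exact, non-binned computation;
# all predictions and targets are kept in memory until compute() is called.
exact_auroc = BinaryAUROC(thresholds=None)(preds, target)

# thresholds=100: the metric is evaluated on 100 evenly spaced thresholds,
# so only a fixed-size state (a confusion matrix per threshold) is stored.
binned_auroc = BinaryAUROC(thresholds=100)(preds, target)
```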
[0.10.0] - 2022-10-04

Added
Changed
Fixed
Contributors

@Borda, @bryant1410, @geoffrey-g-delhomme, @justusschock, @lucadiliello, @nicolas-dufour, @Queuecumber, @SkafteNicki, @stancld

If we forgot someone due to not matching the commit email with a GitHub account, let us know :]

This discussion was created from the release Large changes to classifications.
-
Hi! Quick question: does multilabel mean multiple binary labels? Is it implemented as independent binary metrics? So binary implies a single binary metric? Thanks.
-
Thanks for this significant update! Is there a link to the upgrade guide?
-
Hello, may I ask why the default in MultiClassAccuracy is set to 'macro' instead of 'micro'? Sklearn uses 'micro' statistics too: https://scikit-learn.org/stable/modules/generated/sklearn.metrics.accuracy_score.html
-
Precision values are still different from sklearn.
The problem is the input order.
In sklearn it should be
metric(truth, preds)
while in torchmetrics it should be
metric(preds, truth)
(we use this order because it is consistent with the loss functions in torch). Switching the order for the sklearn calculations in your example fixes it.
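For example, a short sketch of the difference (assuming torchmetrics >= 0.10 and scikit-learn; binary precision is used here purely for illustration):

```python
import torch
from sklearn.metrics import precision_score
from torchmetrics.functional.classification import binary_precision

preds = torch.randint(2, (100,))   # predicted labels
target = torch.randint(2, (100,))  # ground-truth labels

tm_value = binary_precision(preds, target)                 # torchmetrics: (preds, target)
sk_value = precision_score(target.numpy(), preds.numpy())  # sklearn: (y_true, y_pred)

# With the correct argument order in each library, the two values agree
assert torch.isclose(tm_value, torch.tensor(float(sk_value)))
```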