Skip to content

Commit

Permalink
fixes
Browse files Browse the repository at this point in the history
  • Loading branch information
SkafteNicki committed Sep 18, 2023
1 parent 8da841c commit e3cc3c4
Showing 1 changed file with 44 additions and 0 deletions.
44 changes: 44 additions & 0 deletions docs/source/pages/overview.rst
Original file line number Diff line number Diff line change
Expand Up @@ -361,6 +361,50 @@ information on this topic.
.. autoclass:: torchmetrics.MetricCollection
:exclude-members: update, compute, forward

***************
Metric wrappers
***************

In some cases it is beneficial to transform the output of one metric in some way or add additional logic. For this we
have implemented a few *Wrapper* metrics. Wrapper metrics always take another :class:`~torchmetrics.Metric` or (
:class:`~torchmetrics.MetricCollection`) as input and wraps it in some way. A good example of this is the
:class:`~torchmetrics.wrappers.ClasswiseWrapper` that allows for easy altering the output of certain classification
metrics to also include label information.

.. testcode::
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper
base_metric = MulticlassAccuracy(num_classes=3, average=None)
wrapped_metric = ClasswiseWrapper(base_metric, labels=["cat", "dog", "fish"])
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
print(base_metric(preds, target)) # this returns a simple tensor without label info
print(wrapped_metric(preds, target)) # this returns a dict with label info

.. testoutput::
:options: +NORMALIZE_WHITESPACE

torch.tensor([0.0000, 0.0000, 0.3333]
{'multiclassaccuracy_cat': tensor(0.),
'multiclassaccuracy_dog': tensor(0.),
'multiclassaccuracy_fish': tensor(0.3333)}

Another good example of wrappers is the :class:`~torchmetrics.wrappers.BootStrapper` that allows for easy bootstrapping
of metrics e.g. computation of confidence intervals by resampling of input data.

.. testcode::
from torchmetrics.classification import MulticlassAccuracy
from torchmetrics.wrappers import ClasswiseWrapper
wrapped_metric = BootStrapper(MulticlassAccuracy(num_classes=3))
target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])

.. testoutput::
:options: +NORMALIZE_WHITESPACE

{'mean': tensor(0.1219), 'std': tensor(0.0217)}

You can see all implemented wrappers under the wrapper section of the API docs.

****************************
Module vs Functional Metrics
Expand Down

0 comments on commit e3cc3c4

Please sign in to comment.