diff --git a/docs/source/pages/overview.rst b/docs/source/pages/overview.rst
index 6fa6aeb256b..649f755b366 100644
--- a/docs/source/pages/overview.rst
+++ b/docs/source/pages/overview.rst
@@ -361,6 +361,57 @@ information on this topic.
 .. autoclass:: torchmetrics.MetricCollection
     :exclude-members: update, compute, forward
 
+***************
+Metric wrappers
+***************
+
+In some cases it is beneficial to transform the output of one metric or to add additional logic on top of it. For
+this we have implemented a few *Wrapper* metrics. Wrapper metrics always take another :class:`~torchmetrics.Metric`
+(or :class:`~torchmetrics.MetricCollection`) as input and wrap it in some way. A good example of this is the
+:class:`~torchmetrics.wrappers.ClasswiseWrapper`, which makes it easy to alter the output of certain classification
+metrics to also include label information.
+
+.. testcode::
+
+    import torch
+    from torchmetrics.classification import MulticlassAccuracy
+    from torchmetrics.wrappers import ClasswiseWrapper
+
+    base_metric = MulticlassAccuracy(num_classes=3, average=None)
+    wrapped_metric = ClasswiseWrapper(base_metric, labels=["cat", "dog", "fish"])
+    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
+    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
+    print(base_metric(preds, target))  # this returns a simple tensor without label info
+    print(wrapped_metric(preds, target))  # this returns a dict with label info
+
+.. testoutput::
+    :options: +NORMALIZE_WHITESPACE
+
+    tensor([0.0000, 0.0000, 0.3333])
+    {'multiclassaccuracy_cat': tensor(0.),
+     'multiclassaccuracy_dog': tensor(0.),
+     'multiclassaccuracy_fish': tensor(0.3333)}
+
+Another good example of a wrapper is the :class:`~torchmetrics.wrappers.BootStrapper`, which allows for easy
+bootstrapping of metrics, e.g. computing confidence intervals by resampling the input data.
+
+.. testcode::
+
+    import torch
+    from torchmetrics.classification import MulticlassAccuracy
+    from torchmetrics.wrappers import BootStrapper
+
+    wrapped_metric = BootStrapper(MulticlassAccuracy(num_classes=3))
+    target = torch.tensor([0, 2, 0, 2, 0, 1, 0, 2])
+    preds = torch.tensor([2, 1, 2, 0, 1, 2, 2, 2])
+    print(wrapped_metric(preds, target))
+
+.. testoutput::
+    :options: +NORMALIZE_WHITESPACE
+
+    {'mean': tensor(0.1219), 'std': tensor(0.0217)}
+
+You can see all implemented wrappers under the wrapper section of the API docs.
 
 ****************************
 Module vs Functional Metrics