Skip to content

Commit

Permalink
Improvements to docs on custom implementations (Lightning-AI#2061)
Browse files Browse the repository at this point in the history
Co-authored-by: Jirka Borovec <[email protected]>
Co-authored-by: mergify[bot] <37929162+mergify[bot]@users.noreply.github.com>
  • Loading branch information
3 people authored Sep 15, 2023
1 parent 6eb874c commit b7b118a
Show file tree
Hide file tree
Showing 4 changed files with 191 additions and 39 deletions.
160 changes: 127 additions & 33 deletions docs/source/pages/implement.rst
Original file line number Diff line number Diff line change
Expand Up @@ -5,9 +5,19 @@
from typing import Optional, Sequence, Union
from torch import Tensor

*********************
#####################
Implementing a Metric
*********************
#####################

While we strive to include as many metrics as possible in ``torchmetrics``, we cannot include them all. Therefore, we
have made it easy to implement your own metric and possible contribute it to ``torchmetrics``. This page will guide
you through the process. If you afterwards are interested in contributing your metric to ``torchmetrics``, please
read the `contribution guidelines <https://torchmetrics.readthedocs.io/en/latest/generated/CONTRIBUTING.html>`_ and
see this :ref:`section <contributing metric>`.

**************
Base interface
**************

To implement your own custom metric, subclass the base :class:`~torchmetrics.Metric` class and implement the following
methods:
Expand All @@ -17,36 +27,110 @@ methods:
- ``compute()``: Computes a final value from the state of the metric.

We provide the remaining interface, such as ``reset()`` that will make sure to correctly reset all metric
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself.
Additionally, adding metric states with ``add_state`` will make sure that states are correctly synchronized
in distributed settings (DDP). To see how metric states are synchronized across distributed processes,
refer to ``add_state()`` docs from the base ``Metric`` class.
states that have been added using ``add_state``. You should therefore not implement ``reset()`` yourself, only in rare
cases where not all the state variables should be reset to their default value. Adding metric states with ``add_state``
will make sure that states are correctly synchronized in distributed settings (DDP). To see how metric states are
synchronized across distributed processes, refer to :meth:`~torchmetrics.Metric.add_state()` docs from the base
:class:`~torchmetrics.Metric` class.

Example implementation:
Below is a basic implementation of a custom accuracy metric. In the ``__init__`` method we add the metric states
``correct`` and ``total``, which will be used to accumulate the number of correct predictions and the total number
of predictions, respectively. In the ``update`` method we update the metric states based on the inputs to the metric.
Finally, in the ``compute`` method we compute the final metric value based on the metric states.

.. testcode::

from torchmetrics import Metric

class MyAccuracy(Metric):
def __init__(self):
super().__init__()
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("correct", default=torch.tensor(0), dist_reduce_fx="sum")
self.add_state("total", default=torch.tensor(0), dist_reduce_fx="sum")
def update(self, preds: Tensor, target: Tensor):
def update(self, preds: Tensor, target: Tensor) -> None:
preds, target = self._input_format(preds, target)
assert preds.shape == target.shape
if preds.shape != target.shape:
raise ValueError("preds and target must have the same shape")

self.correct += torch.sum(preds == target)
self.total += target.numel()

def compute(self):
def compute(self) -> Tensor:
return self.correct.float() / self.total

A few important things to note:

* The ``dist_reduce_fx`` argument to ``add_state`` is used to specify how the metric states should be reduced between
batches in distributed settings. In this case we use ``"sum"`` to sum the metric states across batches. A couple of
build in options are available: ``"sum"``, ``"mean"``, ``"cat"``, ``"min"`` or ``"max"``, but a custom reduction is
also supported.

* In ``update`` we do not return anything but instead update the metric states in-place.

* In ``compute`` when running in distributed mode, the states would have been synced before the compute method is
called. Thus ``self.correct`` and ``self.total`` will contain the sum of the metric states across all processes.

************************
Working with list states
************************

When initializing metric states with ``add_state``, the ``default`` argument can either be a single tensor (as in the
example above) or an empty list. Most metric will only require a single tensor to accumulate the metric states, but
for some metrics that need access to the individual batch states, it can be useful to use a list of tensors. In the
following example we show how to implement Spearman correlation, which requires access to the individual batch states
because we need to calculate the rank of the predictions and targets.

.. testcode::

from torchmetrics import Metric
from torchmetrics.utilities import dim_zero_cat

Additionally you may want to set the class properties: `is_differentiable`, `higher_is_better` and
`full_state_update`. Note that none of them are strictly required for the metric to work.
class MySpearmanCorrCoef(Metric):
def __init__(self, **kwargs):
super().__init__(**kwargs)
self.add_state("preds", default=[], dist_reduce_fx="cat")
self.add_state("target", default=[], dist_reduce_fx="cat")
def update(self, preds: Tensor, target: Tensor) -> None:
self.preds.append(preds)
self.target.append(target)

def compute(self):
# parse inputs
preds = dim_zero_cat(preds)
target = dim_zero_cat(target)
# some intermediate computation...
r_preds, r_target = _rank_data(preds), _rank_dat(target)
preds_diff = r_preds - r_preds.mean(0)
target_diff = r_target - r_target.mean(0)
cov = (preds_diff * target_diff).mean(0)
preds_std = torch.sqrt((preds_diff * preds_diff).mean(0))
target_std = torch.sqrt((target_diff * target_diff).mean(0))
# finalize the computations
corrcoef = cov / (preds_std * target_std + eps)
return torch.clamp(corrcoef, -1.0, 1.0)

A few important things to note for this example:

* When working with list states, the ``dist_reduce_fx`` argument to ``add_state`` should be set to ``"cat"`` to
concatenate the list of tensors across batches.

* When working with list states, The ``update(...)`` method should append the batch states to the list.

* In the the ``compute`` method the list states behave a bit differently dependeding on weather you are running in
distributed mode or not. In non-distributed mode the list states will be a list of tensors, while in distributed mode
the list have already been concatenated into a single tensor. For this reason, we recommend always using the
``dim_zero_cat`` helper function which will standardize the list states to be a single concatenate tensor regardless
of the mode.

*****************
Metric attributes
*****************

When done implementing your own metric, there are a few properties and attributes that you may want to set to add
additional functionality. The three attributes to consider are: ``is_differentiable``, ``higher_is_better`` and
``full_state_update``. Note that none of them are strictly required to be set for the metric to work.

.. testcode::

Expand All @@ -65,8 +149,12 @@ Additionally you may want to set the class properties: `is_differentiable`, `hig
# batch states are independent and we will optimize the runtime of 'forward'
full_state_update: bool = True

Finally, from torchmetrics v1.0.0 onwards, we also support plotting of metrics through the `.plot` method. By default
this method will raise `NotImplementedError` but can be implemented by the user to provide a custom plot for the metric.
**************
Plot interface
**************

From torchmetrics v1.0.0 onwards, we also support plotting of metrics through the ``.plot()`` method. By default this method
will raise `NotImplementedError` but can be implemented by the user to provide a custom plot for the metric.
For any metrics that returns a simple scalar tensor, or a dict of scalar tensors the internal `._plot` method can be
used, that provides the common plotting functionality for most metrics in torchmetrics.

Expand All @@ -76,18 +164,22 @@ used, that provides the common plotting functionality for most metrics in torchm
from torchmetrics.utilities.plot import _AX_TYPE, _PLOT_OUT_TYPE

class MyMetric(Metric):
...
# set these attributes if you want to use the internal ._plot method
# bounds are automatically added to the generated plot
plot_lower_bound: Optional[float] = None
plot_upper_bound: Optional[float] = None

def plot(
self, val: Optional[Union[Tensor, Sequence[Tensor]]] = None, ax: Optional[_AX_TYPE] = None
) -> _PLOT_OUT_TYPE:
return self._plot(val, ax)

If the metric returns a more complex output, a custom implementation of the `plot` method is required. For more details
on the plotting API, see the this :ref:`page <plotting>` .
on the plotting API, see the this :ref:`page <plotting>` . In addti

*******************************
Internal implementation details
-------------------------------
*******************************

This section briefly describes how metrics work internally. We encourage looking at the source code for more info.
Internally, TorchMetrics wraps the user defined ``update()`` and ``compute()`` method. We do this to automatically
Expand Down Expand Up @@ -123,8 +215,8 @@ can behave in two ways:
5. Calls ``compute()`` to calculate metric for current batch.
6. Restores the global state.

2. If ``full_state_update`` is ``False`` (default) the metric state of one batch is completly independent of the state of
other batches, which means that we only need to call ``update`` once.
2. If ``full_state_update`` is ``False`` (default) the metric state of one batch is completly independent of the state
of other batches, which means that we only need to call ``update`` once.

1. Caches the global state.
2. Calls ``reset`` the metric to its default state
Expand All @@ -133,25 +225,27 @@ can behave in two ways:
5. Reduce the global state and batch state into a single state that becomes the new global state

If implementing your own metric, we recommend trying out the metric with ``full_state_update`` class property set to
both ``True`` and ``False``. If the results are equal, then setting it to ``False`` will usually give the best performance.

---------
both ``True`` and ``False``. If the results are equal, then setting it to ``False`` will usually give the best
performance.

.. autoclass:: torchmetrics.Metric
:noindex:
:members:

.. _contributing metric:

****************************************
Contributing your metric to TorchMetrics
----------------------------------------
****************************************

Wanting to contribute the metric you have implemented? Great, we are always open to adding more metrics to ``torchmetrics``
as long as they serve a general purpose. However, to keep all our metrics consistent we request that the implementation
and tests gets formatted in the following way:

1. Start by reading our `contribution guidelines <https://torchmetrics.readthedocs.io/en/latest/generated/CONTRIBUTING.html>`_.
2. First implement the functional backend. This takes cares of all the logic that goes into the metric. The code should
be put into a single file placed under ``torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of
metric (classification, regression, nlp etc) and ``new_metric`` is the name of the metric. In this file, there should be the
be put into a single file placed under ``src/torchmetrics/functional/"domain"/"new_metric".py`` where ``domain`` is the type of
metric (classification, regression, text etc.) and ``new_metric`` is the name of the metric. In this file, there should be the
following three functions:

1. ``_new_metric_update(...)``: everything that has to do with type/shape checking and all logic required before distributed syncing need to go here.
Expand All @@ -160,10 +254,10 @@ and tests gets formatted in the following way:
makes up the functional interface for the metric.

.. note::
The `functional accuracy <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/classification/accuracy.py>`_
The `functional mean squared error <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/functional/regression/mse.py>`_
metric is a great example of this division of logic.

3. In a corresponding file placed in ``torchmetrics/"domain"/"new_metric".py`` create the module interface:
3. In a corresponding file placed in ``src/torchmetrics/"domain"/"new_metric".py`` create the module interface:

1. Create a new module metric by subclassing ``torchmetrics.Metric``.
2. In the ``__init__`` of the module call ``self.add_state`` for as many metric states are needed for the metric to
Expand All @@ -173,15 +267,15 @@ and tests gets formatted in the following way:
We do this to not have duplicate code to maintain.

.. note::
The module `Accuracy <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/classification/accuracy.py>`_
The module `MeanSquaredError <https://github.com/Lightning-AI/torchmetrics/blob/master/src/torchmetrics/regression/mse.py>`_
metric that corresponds to the above functional example showcases these steps.

4. Remember to add binding to the different relevant ``__init__`` files.

5. Testing is key to keeping ``torchmetrics`` trustworthy. This is why we have a very rigid testing protocol. This means
that we in most cases require the metric to be tested against some other common framework (``sklearn``, ``scipy`` etc).

1. Create a testing file in ``unittests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test
1. Create a testing file in ``tests/unittests/"domain"/test_"new_metric".py``. Only one file is needed as it is intended to test
both the functional and module interface.
2. In that file, start by defining a number of test inputs that your metric should be evaluated on.
3. Create a testclass ``class NewMetric(MetricTester)`` that inherits from ``tests.helpers.testers.MetricTester``.
Expand All @@ -194,8 +288,8 @@ and tests gets formatted in the following way:
5. (optional) If your metric raises any exception, please add tests that showcase this.

.. note::
The `test file for accuracy <https://github.com/Lightning-AI/torchmetrics/blob/master/tests/unittests/classification/test_accuracy.py>`_ metric
shows how to implement such tests.
The `test file for MSE <https://github.com/Lightning-AI/torchmetrics/blob/master/tests/unittests/regression/test_mean_error.py>`_
metric shows how to implement such tests.

If you only can figure out part of the steps, do not fear to send a PR. We will much rather receive working
metrics that are not formatted exactly like our codebase, than not receiving any. Formatting can always be applied.
Expand Down
55 changes: 51 additions & 4 deletions docs/source/references/utilities.rst
Original file line number Diff line number Diff line change
@@ -1,9 +1,19 @@
.. role:: hidden
:class: hidden-section

###########################
######################
torchmetrics.utilities
######################

In the following is listed public utility functions that may be beneficial to use in your own code. These functions are
not part of the public API and may change at any time.

***************************
torchmetrics.utilities.data
###########################
***************************

The `data` utilities are used to help with data manipulation, such as converting labels in classification from one
format to another.

select_topk
~~~~~~~~~~~
Expand All @@ -20,9 +30,46 @@ to_onehot

.. autofunction:: torchmetrics.utilities.data.to_onehot

#################################
dim_zero_cat
~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.dim_zero_cat

dim_zero_max
~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.dim_zero_max

dim_zero_mean
~~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.dim_zero_mean

dim_zero_min
~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.dim_zero_min

dim_zero_sum
~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.data.dim_zero_sum

**********************************
torchmetrics.utilities.distributed
**********************************

The `distributed` utilities are used to help with syncronization of metrics across multiple processes.

gather_all_tensors
~~~~~~~~~~~~~~~~~~

.. autofunction:: torchmetrics.utilities.distributed.gather_all_tensors
:noindex:

*********************************
torchmetrics.utilities.exceptions
#################################
*********************************

TorchMetricsUserError
~~~~~~~~~~~~~~~~~~~~~
Expand Down
12 changes: 12 additions & 0 deletions src/torchmetrics/utilities/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,13 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from torchmetrics.utilities.checks import check_forward_full_state_property
from torchmetrics.utilities.data import (
dim_zero_cat,
dim_zero_max,
dim_zero_mean,
dim_zero_min,
dim_zero_sum,
)
from torchmetrics.utilities.distributed import class_reduce, reduce
from torchmetrics.utilities.prints import rank_zero_debug, rank_zero_info, rank_zero_warn

Expand All @@ -22,4 +29,9 @@
"rank_zero_debug",
"rank_zero_info",
"rank_zero_warn",
"dim_zero_cat",
"dim_zero_max",
"dim_zero_mean",
"dim_zero_min",
"dim_zero_sum",
]
3 changes: 1 addition & 2 deletions src/torchmetrics/utilities/distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -105,8 +105,7 @@ def gather_all_tensors(result: Tensor, group: Optional[Any] = None) -> List[Tens
group: the process group to gather results from. Defaults to all processes (world)
Return:
gathered_result: list with size equal to the process group where
``gathered_result[i]`` corresponds to result tensor from process ``i``
list with size equal to the process group where element i corresponds to result tensor from process i
"""
if group is None:
Expand Down

0 comments on commit b7b118a

Please sign in to comment.