Make metrics serializable #3001

Merged
8 changes: 5 additions & 3 deletions docs/source/engine.rst
@@ -69,7 +69,7 @@ Resuming the training
It is possible to resume the training from a checkpoint and approximately reproduce the original run's behaviour.
Using Ignite, this can be easily done using :class:`~ignite.handlers.checkpoint.Checkpoint` handler. Engine provides two methods
to serialize and deserialize its internal state :meth:`~ignite.engine.engine.Engine.state_dict` and
:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler etc user can
:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing the model, optimizer, lr scheduler, metrics, etc., the user can
store the trainer and then resume the training. For example:

.. code-block:: python
@@ -82,8 +82,9 @@ store the trainer and then resume the training. For example:
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...

to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
vfdev-5 (Collaborator) commented on Aug 10, 2023:
Ideally, we would like to get that in a more automatic way. Also, if we want to save multiple metrics, adding them one by one to save/load can be a pain.
Let's propose something in a follow-up PR.

The author (Collaborator) replied:

Could you please clarify what you mean? If we have multiple metrics, is it cumbersome to do:

to_save = {..., 'metric1': metric1, 'metric2': metric2, 'metric3': metric3}

?

vfdev-5 (Collaborator) replied:

Yes, it is too verbose; let's try to find a way to simplify that.

handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
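
A possible shape for the simplification discussed in the thread above (a sketch only, not part of this PR; `metrics` is a hypothetical name-to-metric mapping, and `accuracy`/`bleu` are hypothetical metric instances):

metrics = {'accuracy': accuracy, 'bleu': bleu}  # hypothetical metric instances
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
# Merge every metric under a prefixed key instead of listing each one by hand
to_save.update({f'metric_{name}': m for name, m in metrics.items()})
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))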
@@ -104,8 +105,9 @@ We can then restore the training from the last checkpoint.
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...

to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
@@ -710,7 +710,7 @@ def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replace internal state of the class with provided state dict data.
"""Method replaces internal state of the class with provided state dict data.

Args:
state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
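
For illustration, a round trip through these two methods might look like this (a sketch; `handler` and `new_handler` are hypothetical Checkpoint instances, with made-up priorities and filenames):

state = handler.state_dict()        # e.g. OrderedDict([("saved", [(1, "model_1.pt"), (2, "model_2.pt")])])
new_handler.load_state_dict(state)  # new_handler now tracks the same (priority, filename) pairs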
1 change: 1 addition & 0 deletions ignite/metrics/accumulation.py
@@ -38,6 +38,7 @@ class VariableAccumulation(Metric):
"""

required_output_keys = None
_state_dict_all_req_keys = ("accumulator", "num_examples")

def __init__(
self,
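The same pattern recurs throughout the files below: each metric enumerates its state-carrying attributes in `_state_dict_all_req_keys`, and the base class (see `ignite/metrics/metric.py` further down) serializes and restores them. A minimal custom metric opting in might look like this (a sketch with a hypothetical class name, assuming plain scalar/tensor state):

import torch
from ignite.metrics import Metric

class SumOfSquares(Metric):
    # Attributes listed here are picked up by Metric.state_dict()
    _state_dict_all_req_keys = ("_sum", "_num_examples")

    def reset(self) -> None:
        self._sum = torch.tensor(0.0)
        self._num_examples = 0

    def update(self, output) -> None:
        # output is assumed to be a tensor of values
        self._sum += (output ** 2).sum()
        self._num_examples += output.numel()

    def compute(self) -> float:
        return self._sum.item() / max(self._num_examples, 1)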
2 changes: 2 additions & 0 deletions ignite/metrics/accuracy.py
@@ -208,6 +208,8 @@ def thresholded_output_transform(output):
0.6666...
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")

def __init__(
self,
output_transform: Callable = lambda x: x,
2 changes: 2 additions & 0 deletions ignite/metrics/confusion_matrix.py
@@ -99,6 +99,8 @@ def binary_one_hot_output_transform(output):
[1, 1]])
"""

_state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

def __init__(
self,
num_classes: int,
2 changes: 2 additions & 0 deletions ignite/metrics/epoch_metric.py
@@ -67,6 +67,8 @@ def mse_fn(y_preds, y_targets):
To disable the warning, set ``check_compute_fn=False``.
"""

_state_dict_all_req_keys = ("_predictions", "_targets")

def __init__(
self,
compute_fn: Callable[[torch.Tensor, torch.Tensor], float],
2 changes: 2 additions & 0 deletions ignite/metrics/gan/fid.py
@@ -163,6 +163,8 @@ def forward(self, x):
.. versionadded:: 0.4.6
"""

_state_dict_all_req_keys = ("_num_examples", "_train_total", "_test_total", "_train_sigma", "_test_sigma")

def __init__(
self,
num_features: Optional[int] = None,
2 changes: 2 additions & 0 deletions ignite/metrics/gan/inception_score.py
@@ -77,6 +77,8 @@ class InceptionScore(_BaseInceptionMetric):
.. versionadded:: 0.4.6
"""

_state_dict_all_req_keys = ("_num_examples", "_prob_total", "_total_kl_d")

def __init__(
self,
num_features: Optional[int] = None,
1 change: 1 addition & 0 deletions ignite/metrics/loss.py
@@ -65,6 +65,7 @@ class Loss(Metric):
"""

required_output_keys = ("y_pred", "y", "criterion_kwargs")
_state_dict_all_req_keys = ("_sum", "_num_examples")

def __init__(
self,
2 changes: 2 additions & 0 deletions ignite/metrics/mean_absolute_error.py
@@ -59,6 +59,8 @@ class MeanAbsoluteError(Metric):
2.9375
"""

_state_dict_all_req_keys = ("_sum_of_absolute_errors", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_absolute_errors = torch.tensor(0.0, device=self._device)
2 changes: 2 additions & 0 deletions ignite/metrics/mean_pairwise_distance.py
@@ -59,6 +59,8 @@ class MeanPairwiseDistance(Metric):
1.5955...
"""

_state_dict_all_req_keys = ("_sum_of_distances", "_num_examples")

def __init__(
self,
p: int = 2,
2 changes: 2 additions & 0 deletions ignite/metrics/mean_squared_error.py
@@ -59,6 +59,8 @@ class MeanSquaredError(Metric):
3.828125
"""

_state_dict_all_req_keys = ("_sum_of_squared_errors", "_num_examples")

@reinit__is_reduced
def reset(self) -> None:
self._sum_of_squared_errors = torch.tensor(0.0, device=self._device)
54 changes: 52 additions & 2 deletions ignite/metrics/metric.py
@@ -1,12 +1,15 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Mapping
from functools import wraps
from numbers import Number
from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch

import ignite.distributed as idist

from ignite.base.mixins import Serializable
from ignite.engine import CallableEventWithFilter, Engine, Events

if TYPE_CHECKING:
@@ -216,7 +219,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)


class Metric(metaclass=ABCMeta):
class Metric(Serializable, metaclass=ABCMeta):
"""
Base class for all Metrics.

@@ -546,6 +549,53 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
usage = self._check_usage(usage)
return engine.has_event_handler(self.completed, usage.COMPLETED)

def state_dict(self) -> OrderedDict:
"""Method returns state dict with attributes of the metric specified in its
`_state_dict_all_req_keys` attribute. Can be used to save internal state of the class.

If there's an active distributed configuration, some collective operations is done and
the list of values across ranks is saved under each attribute's name in the dict.
"""
state = OrderedDict()
for attr_name in self._state_dict_all_req_keys:
if attr_name not in self.__dict__:
raise ValueError(
f"Found a value in _state_dict_all_req_keys that is not among metric attributes: {attr_name}"
)
attr = getattr(self, attr_name)
if not isinstance(attr, (int, float, torch.Tensor)):
raise TypeError(
"Currently, only numeric or tensor-typed attributes of the metric"
" could be added to its state_dict."
)
if idist.get_world_size() == 1:
state[attr_name] = [attr]
else:
if isinstance(attr, (int, float)):
attr_type = type(attr)
attr = float(attr)
gathered_attr = cast(List[Any], idist.all_gather(attr))
if isinstance(attr, float):
gathered_attr = [attr_type(process_attr) for process_attr in gathered_attr]
state[attr_name] = gathered_attr

return state

def load_state_dict(self, state_dict: Mapping) -> None:
"""Method replaces internal state of the class with provided state dict data.

If there's an active distributed configuration, the process uses its rank to pick the proper value from
the list of values saved under each attribute's name in the dict.

Args:
state_dict: a dict containing attributes of the metric specified in its `_state_dict_all_req_keys`
attribute.
"""
super().load_state_dict(state_dict)
rank = idist.get_rank()
for attr in self._state_dict_all_req_keys:
setattr(self, attr, state_dict[attr][rank])

def __add__(self, other: Any) -> "MetricsLambda":
from ignite.metrics.metrics_lambda import MetricsLambda

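Taken together, these two methods let a metric's state round-trip through a checkpoint; a minimal single-process usage sketch (the attribute names follow the `Accuracy` keys added above, and the inputs are made up):

import torch
from ignite.metrics import Accuracy

metric = Accuracy()
y_pred = torch.tensor([[0.2, 0.8], [0.9, 0.1]])  # hypothetical predictions
y = torch.tensor([1, 0])                         # hypothetical targets
metric.update((y_pred, y))

state = metric.state_dict()    # OrderedDict: "_num_correct"/"_num_examples" -> per-rank value lists
metric.reset()
metric.load_state_dict(state)  # restores this rank's counters
assert metric.compute() == 1.0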
2 changes: 2 additions & 0 deletions ignite/metrics/multilabel_confusion_matrix.py
@@ -81,6 +81,8 @@ class MultiLabelConfusionMatrix(Metric):

"""

_state_dict_all_req_keys = ("confusion_matrix", "_num_examples")

def __init__(
self,
num_classes: int,
5 changes: 5 additions & 0 deletions ignite/metrics/nlp/bleu.py
@@ -147,6 +147,11 @@ def __init__(
raise ValueError(f'Average must be either "macro" or "micro" (got: {average})')
self.average = average

if average == "micro":
self._state_dict_all_req_keys = ("p_numerators", "p_denominators", "hyp_length_sum", "ref_length_sum")
else:
self._state_dict_all_req_keys = ("_sum_of_bleu", "_num_sentences")

super(Bleu, self).__init__(output_transform=output_transform, device=device)

def _n_gram_counter(
2 changes: 2 additions & 0 deletions ignite/metrics/nlp/rouge.py
@@ -119,6 +119,8 @@ class _BaseRouge(Metric):
Rouge interface for Rouge-L and Rouge-N
"""

_state_dict_all_req_keys = ("_recall", "_precision", "_fmeasure", "_num_examples")

def __init__(
self,
multiref: str = "average",
2 changes: 2 additions & 0 deletions ignite/metrics/precision.py
@@ -13,6 +13,8 @@


class _BasePrecisionRecall(_BaseClassification):
_state_dict_all_req_keys = ("_numerator", "_denominator", "_weight")

def __init__(
self,
output_transform: Callable = lambda x: x,
2 changes: 2 additions & 0 deletions ignite/metrics/psnr.py
@@ -81,6 +81,8 @@ def get_y_channel(output):
.. versionadded:: 0.4.3
"""

_state_dict_all_req_keys = ("_sum_of_batchwise_psnr", "_num_examples")

def __init__(
self,
data_range: Union[int, float],
3 changes: 3 additions & 0 deletions ignite/metrics/running_average.py
@@ -87,6 +87,9 @@ def log_running_avg_metrics():
"""

required_output_keys = None
# TODO Shall we put `src` here? Then we should add a new branch for metric-typed attributes in `state_dict`
# and `load_state_dict`. Examples: this class, and `Rouge`, which has a `List[_BaseRouge]`.
_state_dict_all_req_keys = ("_value",)
The author (Collaborator) asked:
@vfdev-5, what are your thoughts on this?

vfdev-5 (Collaborator) replied:
Let's add `src` and, in `state_dict`/`load_state_dict`, put an explicit check for `isinstance(attr, Metric)`, then call `state[attr_name].update(attr.state_dict())` or something like that.
Do you have any other options?
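
A hedged sketch of that proposal (an assumption about the follow-up, not implemented in this PR): serialization could branch on `Metric`-typed attributes and delegate to their own state dicts. `_attr_to_state` is a hypothetical helper name.

from typing import Any

def _attr_to_state(metric: "Metric", attr_name: str) -> Any:
    """Sketch: serialize one attribute, delegating when it is itself a Metric."""
    attr = getattr(metric, attr_name)
    if isinstance(attr, Metric):
        # A nested metric (e.g. RunningAverage's src) serializes itself recursively
        return attr.state_dict()
    # Otherwise fall through to the numeric/tensor handling in metric.py above
    return attr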


def __init__(
self,
2 changes: 2 additions & 0 deletions ignite/metrics/ssim.py
@@ -60,6 +60,8 @@ class SSIM(Metric):
.. versionadded:: 0.4.2
"""

_state_dict_all_req_keys = ("_sum_of_ssim", "_num_examples", "_kernel")

def __init__(
self,
data_range: Union[int, float],
2 changes: 2 additions & 0 deletions ignite/metrics/top_k_categorical_accuracy.py
@@ -73,6 +73,8 @@ def one_hot_to_binary_output_transform(output):
0.75
"""

_state_dict_all_req_keys = ("_num_correct", "_num_examples")

def __init__(
self,
k: int = 5,
74 changes: 74 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -713,6 +713,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.distributed
Expand All @@ -722,6 +723,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.distributed
@@ -744,6 +746,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
_test_distrib_sync_all_reduce_decorator(device)
_test_invalid_sync_all_reduce(device)
_test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
_test_distrib_state_dict(device)


@pytest.mark.multinode_distributed
@@ -1125,3 +1128,74 @@ def update(self, output):

with pytest.raises(ValueError, match=r"Output should have 2 items of the same length"):
engine.run([0] * 10)


class DummyMetric4(Metric):
_state_dict_all_req_keys = ("dnumber", "fnumber", "tensor")

def __init__(self, value: int):
super().reset()
self.dnumber = value
self.fnumber = float(value + 1)
self.tensor = torch.tensor([value + 2])

def reset(self):
self.dnumber = -1
self.fnumber = -2.0
self.tensor = torch.tensor([-3])

def update(self, output):
pass

def compute(self):
pass


def test_wrong_state_dict():
class WrongMetric(Metric):
_state_dict_all_req_keys = ("object",)

def __init__(self, value):
super().__init__()
self.object = {"a": [value]}

def reset(self):
pass

def update(self, output):
pass

def compute(self):
pass

metric = WrongMetric(2)
with pytest.raises(TypeError, match="Currently, only numeric or tensor-typed attributes of the metric"):
metric.state_dict()

delattr(metric, "object")
with pytest.raises(ValueError, match="Found a value in _state_dict_all_req_keys that is not among"):
metric.state_dict()


def test_state_dict():
metric = DummyMetric4(1)
state = metric.state_dict()
assert state.keys() == {"dnumber", "fnumber", "tensor"}
metric.reset()
metric.load_state_dict(state)
assert metric.dnumber == 1
assert metric.fnumber == 2
assert metric.tensor == torch.tensor([3])


def _test_distrib_state_dict(device):
rank = idist.get_local_rank()
metric = DummyMetric4(rank)
state = metric.state_dict()
assert isinstance(state["dnumber"][rank], int)
assert isinstance(state["fnumber"][rank], float)
metric.reset()
metric.load_state_dict(state)
assert metric.dnumber == rank and isinstance(metric.dnumber, int)
assert metric.fnumber == rank + 1 and isinstance(metric.fnumber, float)
assert metric.tensor == torch.tensor([rank + 2])