Make metrics serializable #3001

Merged
8 changes: 5 additions & 3 deletions docs/source/engine.rst
@@ -69,7 +69,7 @@ Resuming the training
It is possible to resume the training from a checkpoint and approximately reproduce the original run's behaviour.
Using Ignite, this can easily be done with the :class:`~ignite.handlers.checkpoint.Checkpoint` handler. Engine provides two methods
to serialize and deserialize its internal state, :meth:`~ignite.engine.engine.Engine.state_dict` and
:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler etc user can
:meth:`~ignite.engine.engine.Engine.load_state_dict`. In addition to serializing model, optimizer, lr scheduler, metrics, etc., user can
store the trainer and then resume the training. For example:

.. code-block:: python
@@ -82,8 +82,9 @@ store the trainer and then resume the training. For example:
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...

to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
vfdev-5 (Collaborator) commented on Aug 10, 2023:

Ideally, we would like to get that in a more automatic way. Also, if we want to save multiple metrics, adding them one by one to save/load can be a pain.
Let's propose something in a follow-up PR.

Collaborator (PR author) replied:

Could you please reiterate what you mean? If we have multiple metrics, is it cumbersome to do the following?

to_save = {..., 'metric1': metric1, 'metric2': metric2, 'metric3': metric3}

Collaborator replied:

Yes, it is too verbose; let's try to find a way to simplify that. (One possible direction is sketched after the code example below.)

handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
trainer.add_event_handler(Events.EPOCH_COMPLETED, handler)
trainer.run(data_loader, max_epochs=100)
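
One possible direction for the follow-up discussed in the review thread above; this is only a sketch, and the metrics dict and the key prefix below are illustrative, not part of this PR:

metrics = {'accuracy': accuracy_metric, 'loss': loss_metric}
to_save = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
# merge all metrics into to_save at once instead of listing them one by one
to_save.update({f'metric_{name}': metric for name, metric in metrics.items()})
handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))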
@@ -104,8 +105,9 @@ We can then restore the training from the last checkpoint.
optimizer = ...
lr_scheduler = ...
data_loader = ...
metric = ...

to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler}
to_load = {'trainer': trainer, 'model': model, 'optimizer': optimizer, 'lr_scheduler': lr_scheduler, 'metric': metric}
checkpoint = torch.load(checkpoint_file)
Checkpoint.load_objects(to_load=to_load, checkpoint=checkpoint)

2 changes: 1 addition & 1 deletion ignite/handlers/checkpoint.py
@@ -710,7 +710,7 @@ def state_dict(self) -> "OrderedDict[str, List[Tuple[int, str]]]":
        return OrderedDict([("saved", [(p, f) for p, f in self._saved])])

    def load_state_dict(self, state_dict: Mapping) -> None:
        """Method replace internal state of the class with provided state dict data.
        """Method replaces internal state of the class with provided state dict data.

        Args:
            state_dict: a dict with "saved" key and list of ``(priority, filename)`` pairs as values.
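
For reference, a minimal sketch of the round trip implied by this method, assuming a populated to_save dict as in the docs above; the variable names are illustrative:

handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
# ... after a few checkpoints have been written ...
sd = handler.state_dict()        # OrderedDict([("saved", [(priority, filename), ...])])
new_handler = Checkpoint(to_save, DiskSaver('/tmp/training', create_dir=True))
new_handler.load_state_dict(sd)  # replaces the internal list of already-saved checkpoints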
1 change: 1 addition & 0 deletions ignite/metrics/accumulation.py
@@ -38,6 +38,7 @@ class VariableAccumulation(Metric):
"""

required_output_keys = None
_state_dict_all_req_keys = ("accumulator", "num_examples")

def __init__(
self,
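
For context, a minimal sketch of how this key set is exercised through the Metric.state_dict/load_state_dict API added below, assuming ignite.metrics.Average (a VariableAccumulation subclass); the sample values are illustrative:

import torch
from ignite.metrics import Average

metric = Average()
metric.update(torch.tensor([1.0, 2.0, 3.0]))  # accumulate a few values
state = metric.state_dict()                   # captures "accumulator" and "num_examples"
metric.reset()
metric.load_state_dict(state)                 # restores the accumulated state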
43 changes: 41 additions & 2 deletions ignite/metrics/metric.py
@@ -1,12 +1,16 @@
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from collections.abc import Mapping
from functools import wraps
from numbers import Number
from typing import Any, Callable, cast, Dict, Optional, Sequence, Tuple, TYPE_CHECKING, Union
from typing import Any, Callable, cast, Dict, List, Optional, Sequence, Tuple, TYPE_CHECKING, Union

import torch
from torch import distributed as torch_dist

import ignite.distributed as idist

from ignite.base.mixins import Serializable
from ignite.engine import CallableEventWithFilter, Engine, Events

if TYPE_CHECKING:
@@ -216,7 +220,7 @@ def __init__(self, *args: Any, **kwargs: Any) -> None:
)


class Metric(metaclass=ABCMeta):
class Metric(Serializable, metaclass=ABCMeta):
"""
Base class for all Metrics.

@@ -546,6 +550,41 @@ def is_attached(self, engine: Engine, usage: Union[str, MetricUsage] = EpochWise
        usage = self._check_usage(usage)
        return engine.has_event_handler(self.completed, usage.COMPLETED)

    def state_dict(self) -> OrderedDict:
        """Method returns a state dict with the attributes of the metric specified in its
        `_state_dict_all_req_keys` attribute. Can be used to save the internal state of the class.

        If there's an active distributed configuration, some collective operations are performed and
        the list of values across ranks is saved under each attribute's name in the dict.
        """
        state = OrderedDict()
        for attr_name in self._state_dict_all_req_keys:
            attr = getattr(self, attr_name)
            if idist.get_world_size() == 1:
                state[attr_name] = [attr]
            else:
                if isinstance(attr, (float, torch.Tensor)):
                    state[attr_name] = cast(List[Any], idist.all_gather(attr))
                else:
                    state[attr_name] = [None] * idist.get_world_size()
                    torch_dist.all_gather_object(state[attr_name], attr)
        return state

    def load_state_dict(self, state_dict: Mapping) -> None:
        """Method replaces internal state of the class with provided state dict data.

        If there's an active distributed configuration, the process uses its rank to pick the proper value from
        the list of values saved under each attribute's name in the dict.

        Args:
            state_dict: a dict containing attributes of the metric specified in its `_state_dict_all_req_keys`
                attribute.
        """
        super().load_state_dict(state_dict)
        rank = idist.get_local_rank()
        for attr in self._state_dict_all_req_keys:
            setattr(self, attr, state_dict[attr][rank])

    def __add__(self, other: Any) -> "MetricsLambda":
        from ignite.metrics.metrics_lambda import MetricsLambda

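
To illustrate the distributed layout described in the docstrings above, a sketch assuming a world size of 2 and a metric attribute listed in _state_dict_all_req_keys:

state = metric.state_dict()
# each entry holds one value per rank: state[attr_name] == [value_on_rank_0, value_on_rank_1]
metric.load_state_dict(state)  # each process restores the entry corresponding to its own rank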
46 changes: 46 additions & 0 deletions tests/ignite/metrics/test_metric.py
@@ -713,6 +713,7 @@ def test_distrib_nccl_gpu(distributed_context_single_node_nccl):
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
    _test_distrib_state_dict(device)


@pytest.mark.distributed
@@ -722,6 +723,7 @@ def test_distrib_gloo_cpu_or_gpu(distributed_context_single_node_gloo):
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
    _test_distrib_state_dict(device)


@pytest.mark.distributed
@@ -744,6 +746,7 @@ def test_multinode_distrib_gloo_cpu_or_gpu(distributed_context_multi_node_gloo):
    _test_distrib_sync_all_reduce_decorator(device)
    _test_invalid_sync_all_reduce(device)
    _test_compute_with_sync_all_reduce_doesnt_change_attributes(device)
    _test_distrib_state_dict(device)


@pytest.mark.multinode_distributed
@@ -1125,3 +1128,46 @@ def update(self, output):

    with pytest.raises(ValueError, match=r"Output should have 2 items of the same length"):
        engine.run([0] * 10)


class DummyMetric4(Metric):
    _state_dict_all_req_keys = ("number", "tensor", "object")

    def __init__(self, value):
        super().reset()
        self.number = value
        self.tensor = torch.tensor([value + 1])
        self.object = {"a": [value + 2]}

    def reset(self):
        self.number = -1
        self.tensor = torch.tensor([-2])
        self.object = {"a": [-3]}

    def update(self, output):
        pass

    def compute(self):
        pass


def test_state_dict():
    metric = DummyMetric4(1)
    state = metric.state_dict()
    assert state.keys() == {"number", "tensor", "object"}
    metric.reset()
    metric.load_state_dict(state)
    assert metric.number == 1
    assert metric.tensor == torch.tensor([2])
    assert metric.object == {"a": [3]}


def _test_distrib_state_dict(device):
    rank = idist.get_local_rank()
    metric = DummyMetric4(rank)
    state = metric.state_dict()
    metric.reset()
    metric.load_state_dict(state)
    assert metric.number == rank
    assert metric.tensor == torch.tensor([rank + 1])
    assert metric.object == {"a": [rank + 2]}