From e16b3e88609834442b7692dd1d56f03de5e65953 Mon Sep 17 00:00:00 2001 From: Laura Gustafson Date: Mon, 15 Mar 2021 13:32:10 -0700 Subject: [PATCH] Exponential Parameter Scheduler (#715) Summary: Pull Request resolved: https://github.com/facebookresearch/ClassyVision/pull/715 Exponential parameter scheduler for fvcore/ClassyVision. This type of scheduler is useful as a LR scheduler when using EMA. Differential Revision: D27058657 fbshipit-source-id: 176ee5bcd074d5d517afa660a7dbf86f83468979 --- fvcore/common/param_scheduler.py | 28 ++++++++++++++++++ .../test_scheduler_exponential.py | 29 +++++++++++++++++++ 2 files changed, 57 insertions(+) create mode 100644 tests/param_scheduler/test_scheduler_exponential.py diff --git a/fvcore/common/param_scheduler.py b/fvcore/common/param_scheduler.py index 5176538..6b82f0b 100644 --- a/fvcore/common/param_scheduler.py +++ b/fvcore/common/param_scheduler.py @@ -7,6 +7,7 @@ "ParamScheduler", "ConstantParamScheduler", "CosineParamScheduler", + "ExponentialParamScheduler", "LinearParamScheduler", "CompositeParamScheduler", "MultiStepParamScheduler", @@ -85,6 +86,33 @@ def __call__(self, where: float) -> float: ) +class ExponentialParamScheduler(ParamScheduler): + """ + Exponetial schedule based on start value and decay, where value is caluated + for timestep t out of T total stepsm by + param_t = start_value * (decay ** t/T).The schedule is updated after every train + step by default based on the fraction of samples seen. + + Example: + + .. code-block:: python + ExponentialParamScheduler(start_value=2.0, decay=0.02) + + Corresponds to a decreasing schedule with values in [2.0, 0.04). + """ + + def __init__( + self, + start_value: float, + decay: float, + ) -> None: + self._start_value = start_value + self._decay = decay + + def __call__(self, where: float) -> float: + return self._start_value * (self._decay ** where) + + class LinearParamScheduler(ParamScheduler): """ Linearly interpolates parameter between ``start_value`` and ``end_value``. diff --git a/tests/param_scheduler/test_scheduler_exponential.py b/tests/param_scheduler/test_scheduler_exponential.py new file mode 100644 index 0000000..4c94b97 --- /dev/null +++ b/tests/param_scheduler/test_scheduler_exponential.py @@ -0,0 +1,29 @@ +# Copyright (c) Facebook, Inc. and its affiliates. + +import unittest + +from fvcore.common.param_scheduler import ExponentialParamScheduler + + +class TestExponentialScheduler(unittest.TestCase): + _num_epochs = 10 + + def _get_valid_config(self): + return {"start_value": 2.0, "decay": 0.1} + + def _get_valid_intermediate_values(self): + return [1.5887, 1.2619, 1.0024, 0.7962, 0.6325, 0.5024, 0.3991, 0.3170, 0.2518] + + def test_scheduler(self): + config = self._get_valid_config() + + scheduler = ExponentialParamScheduler(**config) + schedule = [ + round(scheduler(epoch_num / self._num_epochs), 4) + for epoch_num in range(self._num_epochs) + ] + expected_schedule = [ + config["start_value"] + ] + self._get_valid_intermediate_values() + + self.assertEqual(schedule, expected_schedule)