Skip to content

Commit

Permalink
Exponential Parameter Scheduler (#715)
Browse files Browse the repository at this point in the history
Summary:
Pull Request resolved: facebookresearch/ClassyVision#715

Exponential parameter scheduler for fvcore/ClassyVision. This type of scheduler is useful as a LR scheduler when using EMA.

Differential Revision: D27058657

fbshipit-source-id: 176ee5bcd074d5d517afa660a7dbf86f83468979
  • Loading branch information
lauragustafson authored and facebook-github-bot committed Mar 15, 2021
1 parent 1f43d07 commit e16b3e8
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 0 deletions.
28 changes: 28 additions & 0 deletions fvcore/common/param_scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
"ParamScheduler",
"ConstantParamScheduler",
"CosineParamScheduler",
"ExponentialParamScheduler",
"LinearParamScheduler",
"CompositeParamScheduler",
"MultiStepParamScheduler",
Expand Down Expand Up @@ -85,6 +86,33 @@ def __call__(self, where: float) -> float:
)


class ExponentialParamScheduler(ParamScheduler):
"""
Exponetial schedule based on start value and decay, where value is caluated
for timestep t out of T total stepsm by
param_t = start_value * (decay ** t/T).The schedule is updated after every train
step by default based on the fraction of samples seen.
Example:
.. code-block:: python
ExponentialParamScheduler(start_value=2.0, decay=0.02)
Corresponds to a decreasing schedule with values in [2.0, 0.04).
"""

def __init__(
self,
start_value: float,
decay: float,
) -> None:
self._start_value = start_value
self._decay = decay

def __call__(self, where: float) -> float:
return self._start_value * (self._decay ** where)


class LinearParamScheduler(ParamScheduler):
"""
Linearly interpolates parameter between ``start_value`` and ``end_value``.
Expand Down
29 changes: 29 additions & 0 deletions tests/param_scheduler/test_scheduler_exponential.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
# Copyright (c) Facebook, Inc. and its affiliates.

import unittest

from fvcore.common.param_scheduler import ExponentialParamScheduler


class TestExponentialScheduler(unittest.TestCase):
_num_epochs = 10

def _get_valid_config(self):
return {"start_value": 2.0, "decay": 0.1}

def _get_valid_intermediate_values(self):
return [1.5887, 1.2619, 1.0024, 0.7962, 0.6325, 0.5024, 0.3991, 0.3170, 0.2518]

def test_scheduler(self):
config = self._get_valid_config()

scheduler = ExponentialParamScheduler(**config)
schedule = [
round(scheduler(epoch_num / self._num_epochs), 4)
for epoch_num in range(self._num_epochs)
]
expected_schedule = [
config["start_value"]
] + self._get_valid_intermediate_values()

self.assertEqual(schedule, expected_schedule)

0 comments on commit e16b3e8

Please sign in to comment.