-
Notifications
You must be signed in to change notification settings - Fork 2
/
warmup_lr.py
35 lines (26 loc) · 961 Bytes
/
warmup_lr.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import tensorflow as tf
class WarmUpExtension:
    """Mixin that prepends a linear learning-rate warm-up to a schedule.

    Meant to be listed *before* a learning-rate schedule class in the bases of
    a subclass (as ``extend_with_warmup_lr`` does), so that ``super()`` resolves
    to that schedule. The base class must accept the remaining constructor
    arguments and provide ``__call__(step)`` and ``get_config()``; presumably a
    ``tf.keras.optimizers.schedules.LearningRateSchedule`` subclass -- confirm
    against callers.

    For ``step < warmup_steps`` the rate ramps linearly from 0 towards the base
    schedule's value at step 0. From then on the base schedule is evaluated at
    ``step - warmup_steps``, i.e. it starts from its own step 0 once warm-up
    has finished.
    """

    def __init__(self, warmup_steps, *args, **kwargs):
        """Store the warm-up length and forward everything else to the base.

        Args:
            warmup_steps: Length of the linear warm-up phase, in steps. With
                non-negative steps, a value <= 0 means no warm-up is applied.
            *args, **kwargs: Passed unchanged to the next ``__init__`` in the
                MRO (the wrapped schedule).
        """
        # ``warmup_steps`` is a named parameter, so Python never leaves it in
        # ``kwargs``; nothing needs to be popped before forwarding.
        super().__init__(*args, **kwargs)
        self.warmup_steps = warmup_steps

    def __call__(self, step):
        """Return the learning rate for ``step``.

        Args:
            step: Current optimizer step -- anything ``tf.cast`` can turn into a
                float32 value (Python number or tensor).

        Returns:
            The learning rate in the base schedule's dtype (presumably a scalar
            tensor, as returned by the base ``__call__``).
        """
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        # Bind the base schedule's __call__ here: zero-argument ``super()``
        # cannot be used inside the nested branch functions below.
        base_call = super().__call__

        def _warmup_lr():
            lr = base_call(0)
            # Do the ramp arithmetic in the base schedule's dtype so both
            # tf.cond branches return the same dtype even for non-float32
            # schedules.
            return lr * tf.cast(step, lr.dtype) / tf.cast(warmup_steps, lr.dtype)

        def _scheduled_lr():
            return base_call(step - warmup_steps)

        # An explicit tf.cond (instead of a Python ``if`` under @tf.function)
        # works both eagerly and inside a traced graph without depending on
        # AutoGraph, and avoids retracing for every distinct Python-int step.
        return tf.cond(step < warmup_steps, _warmup_lr, _scheduled_lr)

    def get_config(self):
        """Return the base schedule's config with ``warmup_steps`` added."""
        config = super().get_config()
        config.update({"warmup_steps": self.warmup_steps})
        return config
def extend_with_warmup_lr(base_scheduler):
    """Create a subclass of ``base_scheduler`` that adds a linear warm-up.

    The returned class derives from ``WarmUpExtension`` first and
    ``base_scheduler`` second, so the warm-up logic wraps the base schedule's
    ``__call__`` and ``get_config``. A new class object is built on every call.

    Args:
        base_scheduler: Schedule class to wrap. It must accept the constructor
            arguments that follow ``warmup_steps`` and provide
            ``__call__(step)`` and ``get_config()``.

    Returns:
        A class whose constructor signature is
        ``(warmup_steps, *args, **kwargs)``; ``args``/``kwargs`` go to
        ``base_scheduler``.
    """

    class WarmupLrScheduler(WarmUpExtension, base_scheduler):
        def __init__(self, warmup_steps, *args, **kwargs):
            # Spell out the constructor so ``warmup_steps`` is a visible first
            # parameter; WarmUpExtension stores it and forwards the rest.
            super(WarmupLrScheduler, self).__init__(warmup_steps, *args, **kwargs)

    return WarmupLrScheduler