-
Notifications
You must be signed in to change notification settings - Fork 54
/
Copy pathpid.py
100 lines (90 loc) · 4.38 KB
/
pid.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
import torch
from torch.optim.optimizer import Optimizer, required
class PIDOptimizer(Optimizer):
    r"""Gradient-descent optimizer with proportional, integral and derivative terms.

    For each parameter ``p`` with gradient ``g`` (after the optional weight
    decay ``g = g + weight_decay * p``) the step is::

        p = p - lr * update

    When ``momentum == 0`` the update is just ``g`` and no state is kept
    (``I`` and ``D`` have no effect in that case).

    When ``momentum != 0`` three per-parameter buffers are maintained in
    ``self.state[p]``:

    * ``'I_buffer'``: accumulated gradient. It is ``g`` on the first step
      and ``momentum * I_buffer + (1 - dampening) * g`` afterwards.
    * ``'D_buffer'``: smoothed gradient change. It is zero on the first
      step and ``momentum * D_buffer + (1 - momentum) * (g - g_prev)``
      afterwards.
    * ``'grad_buffer'``: a copy of ``g`` from this step, used as ``g_prev``
      on the next one.

    and the update is ``g + I * I_buffer + D * D_buffer``.

    Args:
        params (iterable): iterable of parameters to optimize or dicts
            defining parameter groups
        lr (float): learning rate
        momentum (float, optional): decay factor applied to the integral and
            derivative buffers; ``0`` disables both terms (default: 0)
        dampening (float, optional): dampening applied to the gradient when
            it is accumulated into the integral buffer (default: 0)
        weight_decay (float, optional): weight decay (L2 penalty) (default: 0)
        nesterov (bool, optional): accepted and validated for signature
            compatibility with SGD, but not used by :meth:`step`
            (default: False)
        I (float, optional): gain of the integral term (default: 5.0)
        D (float, optional): gain of the derivative term (default: 10.0)

    Raises:
        ValueError: if ``nesterov`` is true while ``momentum <= 0`` or
            ``dampening != 0``.

    Example:
        >>> optimizer = PIDOptimizer(model.parameters(), lr=0.01,
        ...                          momentum=0.9, I=5., D=10.)
        >>> optimizer.zero_grad()
        >>> loss_fn(model(input), target).backward()
        >>> optimizer.step()
    """

    def __init__(self, params, lr=required, momentum=0, dampening=0,
                 weight_decay=0, nesterov=False, I=5., D=10.):
        if nesterov and (momentum <= 0 or dampening != 0):
            raise ValueError("Nesterov momentum requires a momentum and zero dampening")
        defaults = dict(lr=lr, momentum=momentum, dampening=dampening,
                        weight_decay=weight_decay, nesterov=nesterov, I=I, D=D)
        super(PIDOptimizer, self).__init__(params, defaults)

    def __setstate__(self, state):
        """Restore pickled state, back-filling options missing from old groups."""
        super(PIDOptimizer, self).__setstate__(state)
        for group in self.param_groups:
            group.setdefault('nesterov', False)
            # step() reads these unconditionally, so make sure they exist.
            group.setdefault('I', 5.)
            group.setdefault('D', 10.)

    def step(self, closure=None):
        """Performs a single optimization step.

        Arguments:
            closure (callable, optional): A closure that reevaluates the model
                and returns the loss.

        Returns:
            The value returned by ``closure``, or ``None`` if no closure was
            given.
        """
        loss = None
        if closure is not None:
            loss = closure()

        for group in self.param_groups:
            lr = group['lr']
            weight_decay = group['weight_decay']
            momentum = group['momentum']
            dampening = group['dampening']
            I = group['I']
            D = group['D']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad.data
                if weight_decay != 0:
                    # Out-of-place so the caller's p.grad is left untouched.
                    grad = grad.add(p.data, alpha=weight_decay)

                update = grad
                if momentum != 0:
                    param_state = self.state[p]

                    # Integral term: running accumulation of gradients.
                    if 'I_buffer' not in param_state:
                        i_buf = param_state['I_buffer'] = grad.clone()
                    else:
                        i_buf = param_state['I_buffer']
                        i_buf.mul_(momentum).add_(grad, alpha=1 - dampening)

                    # Derivative term: smoothed difference between this
                    # gradient and the previous one. There is no previous
                    # gradient on the first step, so the buffer starts at zero.
                    if 'grad_buffer' not in param_state:
                        d_buf = param_state['D_buffer'] = torch.zeros_like(p.data)
                    else:
                        d_buf = param_state['D_buffer']
                        prev_grad = param_state['grad_buffer']
                        d_buf.mul_(momentum).add_(grad - prev_grad,
                                                  alpha=1 - momentum)

                    # Keep a private copy: p.grad may later be zeroed or
                    # overwritten in place by the training loop.
                    param_state['grad_buffer'] = grad.clone()

                    # First add is out-of-place so `grad` / p.grad is not
                    # modified; the second may reuse the fresh tensor.
                    update = grad.add(i_buf, alpha=I).add_(d_buf, alpha=D)

                p.data.add_(update, alpha=-lr)

        return loss