classic_sgd.py

import torch
from torch import Tensor
from typing import List, Optional


class ClassicSGD(torch.optim.SGD):
    """
    SGD implementation based directly on the formula proposed by Sutskever et al.:
    the velocity accumulates the learning-rate-scaled gradient
    (v = momentum * v + lr * g) and the parameter update is p = p - v,
    instead of applying the learning rate at the parameter update as
    torch.optim.SGD does.
    """
    def __init__(
        self,
        params,
        lr,
        momentum=0,
        dampening=0,
        weight_decay=0,
        nesterov=False,
    ):
        super().__init__(
            params,
            lr=lr,
            momentum=momentum,
            dampening=dampening,
            weight_decay=weight_decay,
            nesterov=nesterov,
        )

    @torch.no_grad()
    def step(self, closure=None):
        """Override step to change the velocity and parameter update formula.

        Args:
            closure (callable, optional): Re-evaluates the model and returns the loss. Defaults to None.

        Returns:
            loss: loss returned by the closure, or None if no closure is given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()
        for group in self.param_groups:
            params_with_grad = []
            d_p_list = []
            momentum_buffer_list = []
            weight_decay = group["weight_decay"]
            momentum = group["momentum"]
            dampening = group["dampening"]
            nesterov = group["nesterov"]
            lr = group["lr"]
            # "maximize" is not present in older torch versions; default to False if missing
            maximize = group.get("maximize", False)

            for p in group["params"]:
                if p.grad is not None:
                    params_with_grad.append(p)
                    d_p_list.append(p.grad)

                    state = self.state[p]
                    if "momentum_buffer" not in state:
                        momentum_buffer_list.append(None)
                    else:
                        momentum_buffer_list.append(state["momentum_buffer"])
            sgd(
                params_with_grad,
                d_p_list,
                momentum_buffer_list,
                weight_decay=weight_decay,
                momentum=momentum,
                lr=lr,
                dampening=dampening,
                nesterov=nesterov,
                maximize=maximize,
            )

            # update momentum_buffers in state
            for p, momentum_buffer in zip(params_with_grad, momentum_buffer_list):
                state = self.state[p]
                state["momentum_buffer"] = momentum_buffer

        return loss
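

# Worked example of the update rule implemented below (illustrative numbers):
# with lr = 0.1, momentum = 0.9, dampening = 0, nesterov = False and a constant
# gradient g = 2.0 on a scalar parameter,
#   step 1: v = lr * g = 0.2                   -> p moves by -0.2
#   step 2: v = 0.9 * 0.2 + 0.1 * 2.0 = 0.38   -> p moves by -0.38
# The learning rate is baked into the velocity, so past updates keep the lr at
# which they were accumulated; this only behaves differently from stock
# torch.optim.SGD when the learning rate changes between steps.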


def sgd(
    params: List[Tensor],
    d_p_list: List[Tensor],
    momentum_buffer_list: List[Optional[Tensor]],
    *,
    weight_decay: float,
    momentum: float,
    lr: float,
    dampening: float,
    nesterov: bool,
    maximize: bool,
):
    """
    Functional API extracted from the torch library:
    https://github.com/pytorch/pytorch/blob/5fdcc20d8d96a6b42387f57c2ce331516ad94228/torch/optim/_functional.py#L156
    Modified to handle the learning rate and momentum differently: the learning
    rate is folded into the gradient (and hence into the velocity) before the
    parameter update.
    """
    for i, param in enumerate(params):
        d_p = d_p_list[i]
        if weight_decay != 0:
            d_p = d_p.add(param, alpha=weight_decay)

        # scale the gradient by the learning rate: g * lr
        # (out of place, so p.grad itself is not modified)
        d_p = d_p.mul(lr)

        if momentum != 0:
            buf = momentum_buffer_list[i]

            if buf is None:
                buf = torch.clone(d_p).detach()
                momentum_buffer_list[i] = buf
            else:
                # original: buf = buf * momentum + (1 - dampening) * g
                # new:      buf = buf * momentum + (1 - dampening) * g * lr
                buf.mul_(momentum).add_(d_p, alpha=1 - dampening)

            if nesterov:
                d_p = d_p.add(buf, alpha=momentum)
            else:
                d_p = buf

        # original: p = p - lr * v
        # new:      p = p - v   (lr is already folded into v)
        alpha = 1 if maximize else -1
        param.add_(d_p, alpha=alpha)
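

if __name__ == "__main__":
    # Minimal usage sketch (the model, data, and hyperparameter values below
    # are assumptions chosen purely for illustration): ClassicSGD drops into a
    # standard training step the same way torch.optim.SGD does.
    import torch.nn as nn

    model = nn.Linear(10, 1)
    criterion = nn.MSELoss()
    optimizer = ClassicSGD(model.parameters(), lr=0.01, momentum=0.9)

    inputs = torch.randn(32, 10)
    targets = torch.randn(32, 1)

    optimizer.zero_grad()
    loss = criterion(model(inputs), targets)
    loss.backward()
    optimizer.step()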