
Commit e9ba73e

Bookkeeping for velocity, rms_gradient, etc.; optimizer tests now pass
1 parent dc55df0 commit e9ba73e

File tree

1 file changed: +76 -12 lines changed


src/nf/nf_optimizers.f90

Lines changed: 76 additions & 12 deletions
@@ -44,6 +44,7 @@ end subroutine minimize
     real :: momentum = 0
     logical :: nesterov = .false.
     real, allocatable, private :: velocity(:)
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_sgd
     procedure :: minimize => minimize_sgd
@@ -59,6 +60,7 @@ end subroutine minimize
     real :: decay_rate = 0.9
     real :: epsilon = 1e-8
     real, allocatable, private :: rms_gradient(:)
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_rmsprop
     procedure :: minimize => minimize_rmsprop
@@ -82,6 +84,7 @@ end subroutine minimize
     real :: weight_decay_decoupled = 0 ! decoupled weight decay regularization (AdamW)
     real, allocatable, private :: m(:), v(:)
     integer, private :: t = 0
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_adam
     procedure :: minimize => minimize_adam
@@ -99,6 +102,7 @@ end subroutine minimize
     real :: learning_rate_decay = 0
     real, allocatable, private :: sum_squared_gradient(:)
     integer, private :: t = 0
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_adagrad
     procedure :: minimize => minimize_adagrad
@@ -121,19 +125,38 @@ pure subroutine minimize_sgd(self, param, gradient)
     !! update rule.
     class(sgd), intent(inout) :: self
     real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(in) :: gradient(:) ! Always the same size as param
+    integer :: end_index
 
     if (self % momentum > 0) then
+
+      ! end_index is part of the bookkeeping for updating velocity because each
+      ! batch update makes two calls to minimize, one for the weights and one for
+      ! the biases.
+      ! We use start_index and end_index to update the appropriate sections
+      ! of the velocity array.
+      end_index = self % start_index + size(param) - 1
+
       ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
+      self % velocity(self % start_index:end_index) = &
+        self % momentum * self % velocity(self % start_index:end_index) &
        - self % learning_rate * gradient
       if (self % nesterov) then
         ! Apply Nesterov update
-        param = param + self % momentum * self % velocity &
+        param = param + self % momentum * self % velocity(self % start_index:end_index) &
          - self % learning_rate * gradient
       else
-        param = param + self % velocity
+        param = param + self % velocity(self % start_index:end_index)
+      end if
+
+      if (self % start_index == 1) then
+        ! We updated the weights part, now we shift forward for the biases part
+        self % start_index = end_index + 1
+      else
+        ! We updated the biases part, now we shift back to start for the next batch
+        self % start_index = 1
       end if
+
     else
       ! Apply regular update
       param = param - self % learning_rate * gradient
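
To make the new bookkeeping concrete, here is a minimal, self-contained sketch of the same slicing logic. It is not code from this commit: the program, the 6-weight/2-bias sizes, and the momentum 0.9 / learning rate 0.01 values are made up for illustration.

program sgd_bookkeeping_demo
  ! Standalone illustration of the start_index/end_index bookkeeping above,
  ! for a hypothetical layer whose state array holds [weights, biases].
  implicit none
  real :: velocity(8)                    ! 6 weight entries + 2 bias entries
  real :: weights(6), biases(2), dw(6), db(2)
  integer :: start_index, end_index

  velocity = 0
  weights = 0.1
  biases = 0.0
  dw = 0.5
  db = 0.25
  start_index = 1

  ! First minimize call of the batch: the weights use velocity(1:6)
  end_index = start_index + size(weights) - 1   ! = 6
  velocity(start_index:end_index) = 0.9 * velocity(start_index:end_index) - 0.01 * dw
  weights = weights + velocity(start_index:end_index)
  start_index = end_index + 1                   ! shift forward for the biases part

  ! Second minimize call of the batch: the biases use velocity(7:8)
  end_index = start_index + size(biases) - 1    ! = 8
  velocity(start_index:end_index) = 0.9 * velocity(start_index:end_index) - 0.01 * db
  biases = biases + velocity(start_index:end_index)
  start_index = 1                               ! reset for the next batch

  print *, 'updated weights:', weights
  print *, 'updated biases: ', biases
end program sgd_bookkeeping_demo

The rmsprop, adam, and adagrad hunks below apply the same weights/biases toggle to their own state arrays (rms_gradient, m and v, sum_squared_gradient).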
@@ -157,14 +180,27 @@ pure subroutine minimize_rmsprop(self, param, gradient)
     class(rmsprop), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
+    self % rms_gradient(self % start_index:end_index) = &
+      self % decay_rate * self % rms_gradient(self % start_index:end_index) &
      + (1 - self % decay_rate) * gradient**2
 
     ! Update the network parameters based on the new RMS of the gradient
     param = param - self % learning_rate &
-      / sqrt(self % rms_gradient + self % epsilon) * gradient
+      / sqrt(self % rms_gradient(self % start_index:end_index) + self % epsilon) &
+      * gradient
+
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
 
   end subroutine minimize_rmsprop
 
@@ -185,20 +221,27 @@ pure subroutine minimize_adam(self, param, gradient)
     class(adam), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     self % t = self % t + 1
 
     ! If weight_decay_l2 > 0, use L2 regularization;
     ! otherwise, default to regular Adam.
     associate(g => gradient + self % weight_decay_l2 * param)
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+      self % m(self % start_index:end_index) = &
+        self % beta1 * self % m(self % start_index:end_index) &
+        + (1 - self % beta1) * g
+      self % v(self % start_index:end_index) = &
+        self % beta2 * self % v(self % start_index:end_index) &
+        + (1 - self % beta2) * g**2
     end associate
 
     ! Compute bias-corrected first and second moment estimates.
     associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
+      m_hat => self % m(self % start_index:end_index) / (1 - self % beta1**self % t), &
+      v_hat => self % v(self % start_index:end_index) / (1 - self % beta2**self % t) &
     )
 
     ! Update parameters.
@@ -208,6 +251,14 @@ pure subroutine minimize_adam(self, param, gradient)
 
     end associate
 
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
+
   end subroutine minimize_adam
 
 
@@ -226,6 +277,9 @@ pure subroutine minimize_adagrad(self, param, gradient)
     class(adagrad), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     ! Update the current time step
     self % t = self % t + 1
@@ -239,13 +293,23 @@ pure subroutine minimize_adagrad(self, param, gradient)
         / (1 + (self % t - 1) * self % learning_rate_decay) &
       )
 
-      self % sum_squared_gradient = self % sum_squared_gradient + g**2
+      self % sum_squared_gradient(self % start_index:end_index) = &
+        self % sum_squared_gradient(self % start_index:end_index) + g**2
 
-      param = param - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+      param = param - learning_rate * g &
+        / (sqrt(self % sum_squared_gradient(self % start_index:end_index)) &
        + self % epsilon)
 
     end associate
 
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
+
   end subroutine minimize_adagrad
 
 end module nf_optimizers
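
One assumption implied by the slicing, though not visible in these hunks, is that each private state array (velocity, rms_gradient, m, v, sum_squared_gradient) is allocated by its init routine to span both parameter groups, weights followed by biases. A hypothetical sketch of that sizing, not taken from this commit:

program optimizer_state_sizing
  ! Illustrative only: shows the allocation that the start_index/end_index
  ! slices in nf_optimizers.f90 rely on.
  implicit none
  integer, parameter :: n_weights = 6, n_biases = 2
  real, allocatable :: velocity(:)

  allocate(velocity(n_weights + n_biases))
  velocity = 0

  ! The first minimize call of a batch slices velocity(1:n_weights);
  ! the second slices velocity(n_weights+1:n_weights+n_biases).
  print *, 'weights section: 1 ..', n_weights
  print *, 'biases section: ', n_weights + 1, '..', n_weights + n_biases
end program optimizer_state_sizing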
