@@ -44,6 +44,7 @@ end subroutine minimize
     real :: momentum = 0
     logical :: nesterov = .false.
     real, allocatable, private :: velocity(:)
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_sgd
     procedure :: minimize => minimize_sgd
@@ -59,6 +60,7 @@ end subroutine minimize
     real :: decay_rate = 0.9
     real :: epsilon = 1e-8
     real, allocatable, private :: rms_gradient(:)
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_rmsprop
     procedure :: minimize => minimize_rmsprop
@@ -82,6 +84,7 @@ end subroutine minimize
     real :: weight_decay_decoupled = 0 ! decoupled weight decay regularization (AdamW)
     real, allocatable, private :: m(:), v(:)
     integer, private :: t = 0
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_adam
     procedure :: minimize => minimize_adam
@@ -99,6 +102,7 @@ end subroutine minimize
     real :: learning_rate_decay = 0
     real, allocatable, private :: sum_squared_gradient(:)
     integer, private :: t = 0
+    integer, private :: start_index = 1
   contains
     procedure :: init => init_adagrad
     procedure :: minimize => minimize_adagrad
@@ -121,19 +125,38 @@ pure subroutine minimize_sgd(self, param, gradient)
     !! update rule.
     class(sgd), intent(inout) :: self
     real, intent(inout) :: param(:)
-    real, intent(in) :: gradient(:)
+    real, intent(in) :: gradient(:) ! Always the same size as param
+    integer :: end_index
 
     if (self % momentum > 0) then
+
+      ! end_index is part of the bookkeeping for updating velocity because each
+      ! batch update makes two calls to minimize, one for the weights and one for
+      ! the biases.
+      ! We use start_index and end_index to update the appropriate sections
+      ! of the velocity array.
+      end_index = self % start_index + size(param) - 1
+
       ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
+      self % velocity(self % start_index:end_index) = &
+        self % momentum * self % velocity(self % start_index:end_index) &
         - self % learning_rate * gradient
       if (self % nesterov) then
         ! Apply Nesterov update
-        param = param + self % momentum * self % velocity &
+        param = param + self % momentum * self % velocity(self % start_index:end_index) &
           - self % learning_rate * gradient
       else
-        param = param + self % velocity
+        param = param + self % velocity(self % start_index:end_index)
+      end if
+
+      if (self % start_index == 1) then
+        ! We updated the weights part, now we shift forward for the biases part
+        self % start_index = end_index + 1
+      else
+        ! We updated the biases part, now we shift back to start for the next batch
+        self % start_index = 1
       end if
+
     else
       ! Apply regular update
       param = param - self % learning_rate * gradient
@@ -157,14 +180,27 @@ pure subroutine minimize_rmsprop(self, param, gradient)
     class(rmsprop), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
+    self % rms_gradient(self % start_index:end_index) = &
+      self % decay_rate * self % rms_gradient(self % start_index:end_index) &
       + (1 - self % decay_rate) * gradient**2
 
     ! Update the network parameters based on the new RMS of the gradient
     param = param - self % learning_rate &
-      / sqrt(self % rms_gradient + self % epsilon) * gradient
+      / sqrt(self % rms_gradient(self % start_index:end_index) + self % epsilon) &
+      * gradient
+
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
 
   end subroutine minimize_rmsprop
 
@@ -185,20 +221,27 @@ pure subroutine minimize_adam(self, param, gradient)
     class(adam), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     self % t = self % t + 1
 
     ! If weight_decay_l2 > 0, use L2 regularization;
     ! otherwise, default to regular Adam.
     associate(g => gradient + self % weight_decay_l2 * param)
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
+      self % m(self % start_index:end_index) = &
+        self % beta1 * self % m(self % start_index:end_index) &
+        + (1 - self % beta1) * g
+      self % v(self % start_index:end_index) = &
+        self % beta2 * self % v(self % start_index:end_index) &
+        + (1 - self % beta2) * g**2
     end associate
 
     ! Compute bias-corrected first and second moment estimates.
     associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
+      m_hat => self % m(self % start_index:end_index) / (1 - self % beta1**self % t), &
+      v_hat => self % v(self % start_index:end_index) / (1 - self % beta2**self % t) &
     )
 
       ! Update parameters.
@@ -208,6 +251,14 @@ pure subroutine minimize_adam(self, param, gradient)
 
     end associate
 
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
+
   end subroutine minimize_adam
 
 
@@ -226,6 +277,9 @@ pure subroutine minimize_adagrad(self, param, gradient)
     class(adagrad), intent(inout) :: self
     real, intent(inout) :: param(:)
     real, intent(in) :: gradient(:)
+    integer :: end_index
+
+    end_index = self % start_index + size(param) - 1
 
     ! Update the current time step
     self % t = self % t + 1
@@ -239,13 +293,23 @@ pure subroutine minimize_adagrad(self, param, gradient)
         / (1 + (self % t - 1) * self % learning_rate_decay) &
       )
 
-      self % sum_squared_gradient = self % sum_squared_gradient + g**2
+      self % sum_squared_gradient(self % start_index:end_index) = &
+        self % sum_squared_gradient(self % start_index:end_index) + g**2
 
-      param = param - learning_rate * g / (sqrt(self % sum_squared_gradient) &
+      param = param - learning_rate * g &
+        / (sqrt(self % sum_squared_gradient(self % start_index:end_index)) &
         + self % epsilon)
 
     end associate
 
+    if (self % start_index == 1) then
+      ! We updated the weights part, now we shift forward for the biases part
+      self % start_index = end_index + 1
+    else
+      ! We updated the biases part, now we shift back to start for the next batch
+      self % start_index = 1
+    end if
+
   end subroutine minimize_adagrad
 
 end module nf_optimizers
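
The bookkeeping comments in the minimize_sgd hunk above note that each batch update makes two calls to minimize, one for a layer's weights and one for its biases, and that start_index/end_index select the matching slice of the optimizer's state array. The following is a minimal, self-contained sketch, not part of nf_optimizers; the program name, array sizes, and values are invented for illustration. It shows how the index toggles between the weights slice and the biases slice of a shared velocity array across two consecutive calls.

program sgd_slice_sketch
  implicit none

  ! Toy layer: 4 weights and 2 biases sharing one velocity array,
  ! as a single optimizer instance would see them.
  real :: weights(4) = 1.0
  real :: biases(2) = 0.5
  real :: grad_w(4) = 0.1
  real :: grad_b(2) = 0.2
  real :: velocity(6) = 0.0                 ! size(weights) + size(biases)
  real, parameter :: learning_rate = 0.01
  real, parameter :: momentum = 0.9
  integer :: start_index = 1

  ! One batch update = two calls: weights first, then biases.
  call update(weights, grad_w)              ! uses velocity(1:4), start_index -> 5
  call update(biases, grad_b)               ! uses velocity(5:6), start_index -> 1

  print *, 'weights after update:', weights
  print *, 'biases after update: ', biases

contains

  subroutine update(param, gradient)
    real, intent(inout) :: param(:)
    real, intent(in) :: gradient(:)
    integer :: end_index

    end_index = start_index + size(param) - 1

    ! Momentum update restricted to this call's slice of velocity
    velocity(start_index:end_index) = &
      momentum * velocity(start_index:end_index) &
      - learning_rate * gradient
    param = param + velocity(start_index:end_index)

    ! Toggle: forward to the biases slice, or back to 1 for the next batch
    if (start_index == 1) then
      start_index = end_index + 1
    else
      start_index = 1
    end if

  end subroutine update

end program sgd_slice_sketch

The same toggle is applied verbatim to rms_gradient, m, v, and sum_squared_gradient in the other optimizers, which is why each minimize_* routine in the diff ends with the identical start_index update.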