@@ -19,9 +19,7 @@ module nf_optimizers
     real :: learning_rate = 0.01
   contains
     procedure(init), deferred :: init
-    procedure(minimize_1d), deferred :: minimize_1d
-    procedure(minimize_2d), deferred :: minimize_2d
-    generic :: minimize => minimize_1d, minimize_2d
+    procedure(minimize), deferred :: minimize
   end type optimizer_base_type

   abstract interface
@@ -32,19 +30,12 @@ impure elemental subroutine init(self, num_params)
       integer, intent(in) :: num_params
     end subroutine init

-    pure subroutine minimize_1d(self, param, gradient)
+    pure subroutine minimize(self, param, gradient)
       import :: optimizer_base_type
       class(optimizer_base_type), intent(inout) :: self
       real, intent(inout) :: param(:)
       real, intent(in) :: gradient(:)
-    end subroutine minimize_1d
-
-    pure subroutine minimize_2d(self, param, gradient)
-      import :: optimizer_base_type
-      class(optimizer_base_type), intent(inout) :: self
-      real, intent(inout) :: param(:,:)
-      real, intent(in) :: gradient(:,:)
-    end subroutine minimize_2d
+    end subroutine minimize

   end interface
 
@@ -55,8 +46,7 @@ end subroutine minimize_2d
     real, allocatable, private :: velocity(:)
   contains
     procedure :: init => init_sgd
-    procedure :: minimize_1d => minimize_sgd_1d
-    procedure :: minimize_2d => minimize_sgd_2d
+    procedure :: minimize => minimize_sgd
   end type sgd

   type, extends(optimizer_base_type) :: rmsprop
@@ -71,8 +61,7 @@ end subroutine minimize_2d
     real, allocatable, private :: rms_gradient(:)
   contains
     procedure :: init => init_rmsprop
-    procedure :: minimize_1d => minimize_rmsprop_1d
-    procedure :: minimize_2d => minimize_rmsprop_2d
+    procedure :: minimize => minimize_rmsprop
   end type rmsprop

   type, extends(optimizer_base_type) :: adam
@@ -95,8 +84,7 @@ end subroutine minimize_2d
     integer, private :: t = 0
   contains
     procedure :: init => init_adam
-    procedure :: minimize_1d => minimize_adam_1d
-    procedure :: minimize_2d => minimize_adam_2d
+    procedure :: minimize => minimize_adam
   end type adam

   type, extends(optimizer_base_type) :: adagrad
@@ -113,8 +101,7 @@ end subroutine minimize_2d
     integer, private :: t = 0
   contains
     procedure :: init => init_adagrad
-    procedure :: minimize_1d => minimize_adagrad_1d
-    procedure :: minimize_2d => minimize_adagrad_2d
+    procedure :: minimize => minimize_adagrad
   end type adagrad

 contains
@@ -129,7 +116,7 @@ impure elemental subroutine init_sgd(self, num_params)
   end subroutine init_sgd


-  pure subroutine minimize_sgd_1d(self, param, gradient)
+  pure subroutine minimize_sgd(self, param, gradient)
     !! Concrete implementation of a stochastic gradient descent optimizer
     !! update rule.
     class(sgd), intent(inout) :: self
@@ -152,33 +139,7 @@ pure subroutine minimize_sgd_1d(self, param, gradient)
       param = param - self % learning_rate * gradient
     end if

-  end subroutine minimize_sgd_1d
-
-
-  pure subroutine minimize_sgd_2d(self, param, gradient)
-    !! Concrete implementation of a stochastic gradient descent optimizer
-    !! update rule for 2D arrays.
-    class(sgd), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    if (self % momentum > 0) then
-      ! Apply momentum update
-      self % velocity = self % momentum * self % velocity &
-        - self % learning_rate * reshape(gradient, [size(gradient)])
-      if (self % nesterov) then
-        ! Apply Nesterov update
-        param = param + reshape(self % momentum * self % velocity &
-          - self % learning_rate * reshape(gradient, [size(gradient)]), shape(param))
-      else
-        param = param + reshape(self % velocity, shape(param))
-      end if
-    else
-      ! Apply regular update
-      param = param - self % learning_rate * gradient
-    end if
-
-  end subroutine minimize_sgd_2d
+  end subroutine minimize_sgd


   impure elemental subroutine init_rmsprop(self, num_params)
@@ -191,7 +152,7 @@ impure elemental subroutine init_rmsprop(self, num_params)
   end subroutine init_rmsprop


-  pure subroutine minimize_rmsprop_1d(self, param, gradient)
+  pure subroutine minimize_rmsprop(self, param, gradient)
     !! Concrete implementation of a RMSProp optimizer update rule.
     class(rmsprop), intent(inout) :: self
     real, intent(inout) :: param(:)
@@ -205,24 +166,7 @@ pure subroutine minimize_rmsprop_1d(self, param, gradient)
     param = param - self % learning_rate &
       / sqrt(self % rms_gradient + self % epsilon) * gradient

-  end subroutine minimize_rmsprop_1d
-
-
-  pure subroutine minimize_rmsprop_2d(self, param, gradient)
-    !! Concrete implementation of a RMSProp optimizer update rule for 2D arrays.
-    class(rmsprop), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Compute the RMS of the gradient using the RMSProp rule
-    self % rms_gradient = self % decay_rate * self % rms_gradient &
-      + (1 - self % decay_rate) * reshape(gradient, [size(gradient)])**2
-
-    ! Update the network parameters based on the new RMS of the gradient
-    param = param - self % learning_rate &
-      / sqrt(reshape(self % rms_gradient, shape(param)) + self % epsilon) * gradient
-
-  end subroutine minimize_rmsprop_2d
+  end subroutine minimize_rmsprop


   impure elemental subroutine init_adam(self, num_params)
@@ -236,7 +180,7 @@ impure elemental subroutine init_adam(self, num_params)
   end subroutine init_adam


-  pure subroutine minimize_adam_1d(self, param, gradient)
+  pure subroutine minimize_adam(self, param, gradient)
     !! Concrete implementation of an Adam optimizer update rule.
     class(adam), intent(inout) :: self
     real, intent(inout) :: param(:)
@@ -264,38 +208,7 @@ pure subroutine minimize_adam_1d(self, param, gradient)
 
     end associate

-  end subroutine minimize_adam_1d
-
-
-  pure subroutine minimize_adam_2d(self, param, gradient)
-    !! Concrete implementation of an Adam optimizer update rule for 2D arrays.
-    class(adam), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    self % t = self % t + 1
-
-    ! If weight_decay_l2 > 0, use L2 regularization;
-    ! otherwise, default to regular Adam.
-    associate(g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]))
-      self % m = self % beta1 * self % m + (1 - self % beta1) * g
-      self % v = self % beta2 * self % v + (1 - self % beta2) * g**2
-    end associate
-
-    ! Compute bias-corrected first and second moment estimates.
-    associate( &
-      m_hat => self % m / (1 - self % beta1**self % t), &
-      v_hat => self % v / (1 - self % beta2**self % t) &
-    )
-
-      ! Update parameters.
-      param = param &
-        - self % learning_rate * reshape(m_hat / (sqrt(v_hat) + self % epsilon), shape(param)) &
-        - self % learning_rate * self % weight_decay_decoupled * param
-
-    end associate
-
-  end subroutine minimize_adam_2d
+  end subroutine minimize_adam


   impure elemental subroutine init_adagrad(self, num_params)
@@ -308,7 +221,7 @@ impure elemental subroutine init_adagrad(self, num_params)
   end subroutine init_adagrad


-  pure subroutine minimize_adagrad_1d(self, param, gradient)
+  pure subroutine minimize_adagrad(self, param, gradient)
     !! Concrete implementation of an Adagrad optimizer update rule.
     class(adagrad), intent(inout) :: self
     real, intent(inout) :: param(:)
@@ -333,34 +246,6 @@ pure subroutine minimize_adagrad_1d(self, param, gradient)
 
     end associate

-  end subroutine minimize_adagrad_1d
-
-
-  pure subroutine minimize_adagrad_2d(self, param, gradient)
-    !! Concrete implementation of an Adagrad optimizer update rule for 2D arrays.
-    class(adagrad), intent(inout) :: self
-    real, intent(inout) :: param(:,:)
-    real, intent(in) :: gradient(:,:)
-
-    ! Update the current time step
-    self % t = self % t + 1
-
-    associate( &
-      ! If weight_decay_l2 > 0, use L2 regularization;
-      ! otherwise, default to regular Adagrad.
-      g => reshape(gradient, [size(gradient)]) + self % weight_decay_l2 * reshape(param, [size(param)]), &
-      ! Amortize the learning rate as function of the current time step.
-      learning_rate => self % learning_rate &
-        / (1 + (self % t - 1) * self % learning_rate_decay) &
-    )
-
-      self % sum_squared_gradient = self % sum_squared_gradient + g**2
-
-      param = param - learning_rate * reshape(g / (sqrt(self % sum_squared_gradient) &
-        + self % epsilon), shape(param))
-
-    end associate
-
-  end subroutine minimize_adagrad_2d
+  end subroutine minimize_adagrad

 end module nf_optimizers
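
Below is a minimal, hypothetical call-site sketch, not part of this commit, showing how code that previously relied on the rank-2 minimize_2d bindings might use the single rank-1 minimize interface by flattening with reshape. The program name, the weights and dw arrays, and the chosen hyperparameters are illustrative assumptions; only the sgd constructor and the init/minimize bindings come from the module above.

program demo_flattened_minimize
  ! Hypothetical example: update a 2-d weight array through the rank-1
  ! minimize interface by flattening it with reshape and restoring its shape.
  use nf_optimizers, only: sgd
  implicit none

  type(sgd) :: opt
  real :: weights(3, 4), dw(3, 4)
  real, allocatable :: flat(:)

  weights = 0.5
  dw = 0.1

  ! Construct the optimizer and allocate its state for size(weights) parameters.
  opt = sgd(learning_rate=0.01, momentum=0.9)
  call opt % init(size(weights))

  ! Flatten, apply one update step, then restore the original shape.
  flat = reshape(weights, [size(weights)])
  call opt % minimize(flat, reshape(dw, [size(dw)]))
  weights = reshape(flat, shape(weights))

end program demo_flattened_minimize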