@@ -251,15 +251,11 @@ module function predict_1d(self, input) result(res)
251251 num_layers = size (self % layers)
252252
253253 ! predict is run in inference mode only;
254- ! set all dropout layers' training mode to false.
255- do n = 2 , num_layers
256- select type (this_layer => self % layers(n) % p)
257- type is (dropout_layer)
258- this_layer % training = .false.
259- end select
260- end do
261-
254+ ! set all dropout layers' training mode to false, and
255+ ! return to training mode after inference.
256+ call self % set_training_mode(.false. )
262257 call self % forward(input)
258+ call self % set_training_mode(.true. )
263259
264260 select type (output_layer => self % layers(num_layers) % p)
265261 type is (dense_layer)
@@ -269,7 +265,8 @@ module function predict_1d(self, input) result(res)
269265 type is (flatten_layer)
270266 res = output_layer % output
271267 class default
272- error stop ' network % output not implemented for this output layer'
268+ error stop ' network % output not implemented for ' // &
269+ trim (self % layers(num_layers) % name) // ' layer'
273270 end select
274271
275272 end function predict_1d
@@ -279,15 +276,25 @@ module function predict_2d(self, input) result(res)
279276 class(network), intent (in out ) :: self
280277 real , intent (in ) :: input(:,:)
281278 real , allocatable :: res(:)
282- integer :: num_layers
279+ integer :: n, num_layers
283280
284281 num_layers = size (self % layers)
285282
283+ ! predict is run in inference mode only;
284+ ! set all dropout layers' training mode to false, and
285+ ! return to training mode after inference.
286+ call self % set_training_mode(.false. )
286287 call self % forward(input)
288+ call self % set_training_mode(.true. )
287289
288290 select type (output_layer => self % layers(num_layers) % p)
289291 type is (dense_layer)
290292 res = output_layer % output
293+ type is (flatten_layer)
294+ res = output_layer % output
295+ class default
296+ error stop ' network % output not implemented for ' // &
297+ trim (self % layers(num_layers) % name) // ' layer'
291298 end select
292299
293300 end function predict_2d
@@ -302,15 +309,11 @@ module function predict_3d(self, input) result(res)
302309 num_layers = size (self % layers)
303310
304311 ! predict is run in inference mode only;
305- ! set all dropout layers' training mode to false.
306- do n = 2 , num_layers
307- select type (this_layer => self % layers(n) % p)
308- type is (dropout_layer)
309- this_layer % training = .false.
310- end select
311- end do
312-
312+ ! set all dropout layers' training mode to false, and
313+ ! return to training mode after inference.
314+ call self % set_training_mode(.false. )
313315 call self % forward(input)
316+ call self % set_training_mode(.true. )
314317
315318 select type (output_layer => self % layers(num_layers) % p)
316319 type is (conv2d_layer)
@@ -321,7 +324,8 @@ module function predict_3d(self, input) result(res)
321324 type is (flatten_layer)
322325 res = output_layer % output
323326 class default
324- error stop ' network % output not implemented for this output layer'
327+ error stop ' network % output not implemented for ' // &
328+ trim (self % layers(num_layers) % name) // ' layer'
325329 end select
326330
327331 end function predict_3d
@@ -338,13 +342,9 @@ module function predict_batch_1d(self, input) result(res)
338342 output_size = product (self % layers(num_layers) % layer_shape)
339343
340344 ! predict is run in inference mode only;
341- ! set all dropout layers' training mode to false.
342- do n = 2 , num_layers
343- select type (this_layer => self % layers(n) % p)
344- type is (dropout_layer)
345- this_layer % training = .false.
346- end select
347- end do
345+ ! set all dropout layers' training mode to false, and
346+ ! return to training mode after inference.
347+ call self % set_training_mode(.false. )
348348
349349 allocate (res(output_size, batch_size))
350350
@@ -358,11 +358,16 @@ module function predict_batch_1d(self, input) result(res)
358358 type is (flatten_layer)
359359 res(:,i) = output_layer % output
360360 class default
361- error stop ' network % output not implemented for this output layer'
361+ error stop ' network % output not implemented for ' // &
362+ trim (self % layers(num_layers) % name) // ' layer'
362363 end select
363364
364365 end do batch
365366
367+ ! We are now done with inference;
368+ ! return to training mode for dropout layers.
369+ call self % set_training_mode(.true. )
370+
366371 end function predict_batch_1d
367372
368373
@@ -377,13 +382,9 @@ module function predict_batch_3d(self, input) result(res)
377382 output_size = product (self % layers(num_layers) % layer_shape)
378383
379384 ! predict is run in inference mode only;
380- ! set all dropout layers' training mode to false.
381- do n = 2 , num_layers
382- select type (this_layer => self % layers(n) % p)
383- type is (dropout_layer)
384- this_layer % training = .false.
385- end select
386- end do
385+ ! set all dropout layers' training mode to false, and
386+ ! return to training mode after inference.
387+ call self % set_training_mode(.false. )
387388
388389 allocate (res(output_size, batch_size))
389390
@@ -400,11 +401,16 @@ module function predict_batch_3d(self, input) result(res)
400401 type is (flatten_layer)
401402 res(:,i) = output_layer % output
402403 class default
403- error stop ' network % output not implemented for this output layer'
404+ error stop ' network % output not implemented for ' // &
405+ trim (self % layers(num_layers) % name) // ' layer'
404406 end select
405407
406408 end do batch
407409
410+ ! We are now done with inference;
411+ ! return to training mode for dropout layers.
412+ call self % set_training_mode(.true. )
413+
408414 end function predict_batch_3d
409415
410416
@@ -484,6 +490,18 @@ module subroutine set_params(self, params)
484490 end subroutine set_params
485491
486492
493+ module subroutine set_training_mode (self , training )
494+ class(network), intent (in out ) :: self
495+ logical , intent (in ) :: training
496+ integer :: n
497+ do n = 2 , size (self % layers)
498+ select type (this_layer => self % layers(n) % p); type is(dropout_layer)
499+ this_layer % training = training
500+ end select
501+ end do
502+ end subroutine set_training_mode
503+
504+
487505 module subroutine train (self , input_data , output_data , batch_size , &
488506 epochs , optimizer , loss )
489507 class(network), intent (in out ) :: self
0 commit comments