@@ -221,10 +221,19 @@ module function predict_1d(self, input) result(res)
221221 class(network), intent (in out ) :: self
222222 real , intent (in ) :: input(:)
223223 real , allocatable :: res(:)
224- integer :: num_layers
224+ integer :: n, num_layers
225225
226226 num_layers = size (self % layers)
227227
228+ ! predict is run in inference mode only;
229+ ! set all dropout layers' training mode to false.
230+ do n = 2 , num_layers
231+ select type (this_layer = > self % layers(n) % p)
232+ type is (dropout_layer)
233+ this_layer % training = .false.
234+ end select
235+ end do
236+
228237 call self % forward(input)
229238
230239 select type (output_layer = > self % layers(num_layers) % p)
@@ -245,10 +254,19 @@ module function predict_3d(self, input) result(res)
245254 class(network), intent (in out ) :: self
246255 real , intent (in ) :: input(:,:,:)
247256 real , allocatable :: res(:)
248- integer :: num_layers
257+ integer :: n, num_layers
249258
250259 num_layers = size (self % layers)
251260
261+ ! predict is run in inference mode only;
262+ ! set all dropout layers' training mode to false.
263+ do n = 2 , num_layers
264+ select type (this_layer = > self % layers(n) % p)
265+ type is (dropout_layer)
266+ this_layer % training = .false.
267+ end select
268+ end do
269+
252270 call self % forward(input)
253271
254272 select type (output_layer = > self % layers(num_layers) % p)
0 commit comments