
Commit c984b15

disable dropout in inference mode (net % predict); TODO enable in net % train

1 parent 59cc7e1

File tree

6 files changed: +53 −22 lines changed


src/nf/nf_dropout_layer.f90

Lines changed: 4 additions & 2 deletions

@@ -22,7 +22,7 @@ module nf_dropout_layer
 
     real :: dropout_rate ! probability of dropping a neuron
     real :: scale ! scale factor to preserve the input sum
-    logical :: training = .true.
+    logical :: training = .false. ! set to .true. in training mode
 
   contains
 
@@ -33,11 +33,13 @@ module nf_dropout_layer
   end type dropout_layer
 
   interface dropout_layer
-    module function dropout_layer_cons(rate) &
+    module function dropout_layer_cons(rate, training) &
         result(res)
       !! This function returns the `dropout_layer` instance.
       real, intent(in) :: rate
         !! Dropout rate
+      logical, intent(in), optional :: training
+        !! Training mode (default .false.)
      type(dropout_layer) :: res
        !! dropout_layer instance
    end function dropout_layer_cons
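A minimal usage sketch of the revised constructor (the variable name drop is illustrative, not part of the library):

! Illustrative only: constructing the concrete layer type directly.
use nf_dropout_layer, only: dropout_layer
type(dropout_layer) :: drop

drop = dropout_layer(0.5)                  ! inference mode by default
drop = dropout_layer(0.5, training=.true.) ! opt in to training mode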

src/nf/nf_dropout_layer_submodule.f90

Lines changed: 22 additions & 14 deletions

@@ -4,12 +4,12 @@
 
 contains
 
-  module function dropout_layer_cons(rate) result(res)
+  module function dropout_layer_cons(rate, training) result(res)
     real, intent(in) :: rate
+    logical, intent(in), optional :: training
     type(dropout_layer) :: res
-
-    ! Initialize dropout rate
     res % dropout_rate = rate
+    if (present(training)) res % training = training
   end function dropout_layer_cons
 
 
@@ -36,19 +36,27 @@ module subroutine forward(self, input)
     class(dropout_layer), intent(in out) :: self
     real, intent(in) :: input(:)
 
-    ! Generate random mask for dropout
-    call random_number(self % mask)
-    where (self % mask < self % dropout_rate)
-      self % mask = 0
-    elsewhere
-      self % mask = 1
-    end where
+    ! Generate random mask for dropout, training mode only
+    if (self % training) then
+
+      call random_number(self % mask)
+      where (self % mask < self % dropout_rate)
+        self % mask = 0
+      elsewhere
+        self % mask = 1
+      end where
+
+      ! Scale factor to preserve the input sum
+      self % scale = sum(input) / sum(input * self % mask)
+
+      ! Apply dropout mask
+      self % output = input * self % mask * self % scale
 
-    ! Scale factor to preserve the input sum
-    self % scale = sum(input) / sum(input * self % mask)
+    else
+      ! In inference mode, we don't apply dropout; simply pass through the input
+      self % output = input
 
-    ! Apply dropout mask
-    self % output = input * self % mask * self % scale
+    end if
 
   end subroutine forward
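In training mode, the where construct zeroes each element with probability dropout_rate, and scale then rescales the survivors so that the output sum matches the input sum, an input-dependent variant of the usual inverted-dropout factor of 1 / (1 - rate). A standalone sketch of the same arithmetic (demo program, not library code):

program dropout_demo
  implicit none
  real :: input(8), mask(8), output(8), scale
  real, parameter :: rate = 0.5

  call random_number(input)

  ! Draw uniform randoms and binarize them into a keep/drop mask
  call random_number(mask)
  where (mask < rate)
    mask = 0
  elsewhere
    mask = 1
  end where

  ! Rescale survivors so the output sum matches the input sum
  ! (a draw that drops every element would divide by zero; ignored here)
  scale = sum(input) / sum(input * mask)
  output = input * mask * scale

  print *, 'input sum:', sum(input), 'output sum:', sum(output)
end program dropout_demo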

src/nf/nf_layer_constructors.f90

Lines changed: 3 additions & 1 deletion

@@ -85,7 +85,7 @@ module function dense(layer_size, activation) result(res)
       !! Resulting layer instance
   end function dense
 
-  module function dropout(rate) result(res)
+  module function dropout(rate, training) result(res)
     !! Create a dropout layer with a given dropout rate.
     !!
     !! This layer is for randomly disabling neurons during training.
@@ -99,6 +99,8 @@ module function dropout(rate) result(res)
     !! ```
     real, intent(in) :: rate
       !! Dropout rate - fraction of neurons to randomly disable during training
+    logical, intent(in), optional :: training
+      !! Training mode (default .false.)
     type(layer) :: res
       !! Resulting layer instance
   end function dropout
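A hedged sketch of the new argument at the network-definition level, assuming the top-level nf module exports these constructors as it does the existing ones:

! Sketch: layer sizes and the 0.3 rate are arbitrary examples.
use nf, only: dense, dropout, input, network
type(network) :: net

net = network([ &
  input(784), &
  dense(64), &
  dropout(0.3, training=.true.), & ! enable dropout during training
  dense(10) &
])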

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 2 deletions

@@ -64,11 +64,12 @@ module function dense(layer_size, activation) result(res)
   end function dense
 
 
-  module function dropout(rate) result(res)
+  module function dropout(rate, training) result(res)
     real, intent(in) :: rate
+    logical, intent(in), optional :: training
     type(layer) :: res
     res % name = 'dropout'
-    allocate(res % p, source=dropout_layer(rate))
+    allocate(res % p, source=dropout_layer(rate, training))
   end function dropout
 
 
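Passing the possibly-absent training straight through to dropout_layer(rate, training) is standard-conforming Fortran: an absent optional actual argument is itself treated as absent in the callee, so the layer's .false. default still applies when the caller omits it.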

src/nf/nf_layer_submodule.f90

Lines changed: 1 addition & 1 deletion

@@ -304,7 +304,7 @@ elemental module function get_num_params(self) result(num_params)
       type is (dense_layer)
         num_params = this_layer % get_num_params()
       type is (dropout_layer)
-        num_params = size(this_layer % mask)
+        num_params = 0
       type is (conv2d_layer)
         num_params = this_layer % get_num_params()
      type is (maxpool2d_layer)
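Reporting zero parameters is consistent with dropout having nothing to learn: the mask is transient per-forward-pass state rather than a weight tensor, so counting size(this_layer % mask) would have inflated the network's parameter total.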

src/nf/nf_network_submodule.f90

Lines changed: 20 additions & 2 deletions

@@ -221,10 +221,19 @@ module function predict_1d(self, input) result(res)
     class(network), intent(in out) :: self
     real, intent(in) :: input(:)
     real, allocatable :: res(:)
-    integer :: num_layers
+    integer :: n, num_layers
 
     num_layers = size(self % layers)
 
+    ! predict is run in inference mode only;
+    ! set all dropout layers' training mode to false.
+    do n = 2, num_layers
+      select type(this_layer => self % layers(n) % p)
+        type is(dropout_layer)
+          this_layer % training = .false.
+      end select
+    end do
+
     call self % forward(input)
 
     select type(output_layer => self % layers(num_layers) % p)
@@ -245,10 +254,19 @@ module function predict_3d(self, input) result(res)
     class(network), intent(in out) :: self
     real, intent(in) :: input(:,:,:)
     real, allocatable :: res(:)
-    integer :: num_layers
+    integer :: n, num_layers
 
     num_layers = size(self % layers)
 
+    ! predict is run in inference mode only;
+    ! set all dropout layers' training mode to false.
+    do n = 2, num_layers
+      select type(this_layer => self % layers(n) % p)
+        type is(dropout_layer)
+          this_layer % training = .false.
+      end select
+    end do
+
     call self % forward(input)
 
     select type(output_layer => self % layers(num_layers) % p)
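Per the TODO in the commit message, the training path would presumably get a symmetric loop that flips the flag on; a hypothetical sketch, not part of this commit:

! Hypothetical counterpart for net % train (per the commit's TODO):
do n = 2, num_layers
  select type(this_layer => self % layers(n) % p)
    type is(dropout_layer)
      this_layer % training = .true.
  end select
end do

The loops start at 2, presumably because layer 1 is the input layer and can never be a dropout layer.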
