Skip to content

Commit 796ae74

Browse files
committed
Enable forward pass for dropout; backward pass TODO
1 parent 75ef184 commit 796ae74

File tree

5 files changed

+51
-3
lines changed

5 files changed

+51
-3
lines changed

src/nf/nf_dropout_layer_submodule.f90

Lines changed: 8 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -35,17 +35,24 @@ end subroutine init
3535
module subroutine forward(self, input)
3636
class(dropout_layer), intent(in out) :: self
3737
real, intent(in) :: input(:)
38+
real :: scale
3839

3940
! Generate random mask for dropout
4041
call random_number(self % mask)
4142
where (self % mask < self % dropout_rate)
4243
self % mask = 0
4344
elsewhere
44-
self % mask = 1 / (1 - self % dropout_rate) ! Scale to preserve expected value
45+
self % mask = 1
4546
end where
4647

4748
! Apply dropout mask
4849
self % output = input * self % mask
50+
51+
! Scale output and mask to preserve the input sum
52+
scale = sum(input) / sum(self % output)
53+
self % output = self % output * scale
54+
self % mask = self % mask * scale
55+
4956
end subroutine forward
5057

5158

src/nf/nf_layer.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -76,7 +76,7 @@ end subroutine backward_3d
7676

7777
interface
7878

79-
pure module subroutine forward(self, input)
79+
module subroutine forward(self, input)
8080
!! Apply a forward pass on the layer.
8181
!! This changes the internal state of the layer.
8282
!! This is normally called internally by the `network % forward`

src/nf/nf_layer_submodule.f90

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -107,7 +107,7 @@ pure module subroutine backward_3d(self, previous, gradient)
107107
end subroutine backward_3d
108108

109109

110-
pure module subroutine forward(self, input)
110+
module subroutine forward(self, input)
111111
implicit none
112112
class(layer), intent(in out) :: self
113113
class(layer), intent(in) :: input
@@ -126,6 +126,18 @@ pure module subroutine forward(self, input)
126126
call this_layer % forward(prev_layer % output)
127127
end select
128128

129+
type is(dropout_layer)
130+
131+
! Upstream layers permitted: input1d, dense, flatten
132+
select type(prev_layer => input % p)
133+
type is(input1d_layer)
134+
call this_layer % forward(prev_layer % output)
135+
type is(dense_layer)
136+
call this_layer % forward(prev_layer % output)
137+
type is(flatten_layer)
138+
call this_layer % forward(prev_layer % output)
139+
end select
140+
129141
type is(conv2d_layer)
130142

131143
! Upstream layers permitted: input3d, conv2d, maxpool2d, reshape3d

src/nf/nf_network_submodule.f90

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@
22

33
use nf_conv2d_layer, only: conv2d_layer
44
use nf_dense_layer, only: dense_layer
5+
use nf_dropout_layer, only: dropout_layer
56
use nf_flatten_layer, only: flatten_layer
67
use nf_input1d_layer, only: input1d_layer
78
use nf_input3d_layer, only: input3d_layer
@@ -227,6 +228,8 @@ module function predict_1d(self, input) result(res)
227228
select type(output_layer => self % layers(num_layers) % p)
228229
type is(dense_layer)
229230
res = output_layer % output
231+
type is(dropout_layer)
232+
res = output_layer % output
230233
type is(flatten_layer)
231234
res = output_layer % output
232235
class default

test/test_dropout_layer.f90

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -65,6 +65,32 @@ program test_dropout_layer
6565

6666
end select
6767

68+
! Now we're gonna run the forward pass and check that the dropout indeed
69+
! drops according to the requested dropout rate.
70+
forward_pass: block
71+
real :: input_data(5)
72+
real :: output_data(size(input_data))
73+
integer :: n
74+
75+
net = network([ &
76+
input(size(input_data)), &
77+
dropout(0.5) &
78+
])
79+
80+
call random_number(input_data)
81+
do n = 1, 10000
82+
output_data = net % predict(input_data)
83+
! Check that sum of output matches sum of input within small tolerance
84+
if (abs(sum(output_data) - sum(input_data)) > 1e-5) then
85+
ok = .false.
86+
exit
87+
end if
88+
end do
89+
if (.not. ok) then
90+
write(stderr, '(a)') 'dropout layer output sum should match input sum within 1% tolerance.. failed'
91+
end if
92+
end block forward_pass
93+
6894
if (ok) then
6995
print '(a)', 'test_dropout_layer: All tests passed.'
7096
else

0 commit comments

Comments
 (0)