Skip to content

Commit 183e82f

Browse files
committed
Expand tests
1 parent 0350c7d commit 183e82f

File tree

1 file changed

+23
-17
lines changed

1 file changed

+23
-17
lines changed

test/test_dropout_layer.f90

Lines changed: 23 additions & 17 deletions
Original file line numberDiff line numberDiff line change
@@ -112,7 +112,7 @@ program test_dropout_layer
112112
! Now we're gonna run the forward pass and check that the dropout indeed
113113
! drops according to the requested dropout rate.
114114
forward_pass: block
115-
real :: input_data(5)
115+
real :: input_data(4)
116116
real :: output_data(size(input_data))
117117
integer :: n
118118

@@ -121,43 +121,49 @@ program test_dropout_layer
121121
dropout(0.5) &
122122
])
123123

124-
call random_number(input_data)
125124
do n = 1, 10000
126-
output_data = net % predict(input_data)
125+
126+
call random_number(input_data)
127+
call net % forward(input_data)
128+
127129
! Check that sum of output matches sum of input within small tolerance
128-
if (abs(sum(output_data) - sum(input_data)) > 1e-6) then
129-
ok = .false.
130-
exit
131-
end if
130+
select type(layer1_p => net % layers(2) % p)
131+
type is(dropout_layer)
132+
if (abs(sum(layer1_p % output) - sum(input_data)) > 1e-6) then
133+
ok = .false.
134+
exit
135+
end if
136+
end select
137+
132138
end do
133-
if (.not. ok) then
134-
write(stderr, '(a)') 'dropout layer output sum should match input sum within tolerance.. failed'
135-
end if
139+
140+
if (.not. ok) write(stderr, '(a)') &
141+
'dropout layer output sum should match input sum within tolerance.. failed'
142+
136143
end block forward_pass
137144

138145

139146
training: block
140-
real :: x(100), y(5)
141-
real :: tolerance = 1e-3
147+
real :: x(20), y(5)
148+
real :: tolerance = 1e-4
142149
integer :: n
143-
integer, parameter :: num_iterations = 10000
150+
integer, parameter :: num_iterations = 100000
144151

145152
call random_number(x)
146153
y = [0.12345, 0.23456, 0.34567, 0.45678, 0.56789]
147154

148155
net = network([ &
149-
input(100), &
150-
dropout(0.5), &
156+
input(20), &
157+
dense(20), &
158+
dropout(0.2), &
151159
dense(5) &
152160
])
153161

154162
do n = 1, num_iterations
155163
call net % forward(x)
156164
call net % backward(y)
157165
call net % update()
158-
159166
if (all(abs(net % predict(x) - y) < tolerance)) exit
160-
161167
end do
162168

163169
if (.not. n <= num_iterations) then

0 commit comments

Comments
 (0)