Skip to content

Commit 5ae7e9d

Browse files
committed
WIP dropout tests
1 parent e9772a0 commit 5ae7e9d

File tree

1 file changed

+39
-1
lines changed

1 file changed

+39
-1
lines changed

test/test_dropout_layer.f90

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
program test_dropout_layer
22
use iso_fortran_env, only: stderr => error_unit
3-
use nf, only: dropout, input, layer, network
3+
use nf, only: dense, dropout, input, layer, network
44
use nf_dropout_layer, only: dropout_layer
55
type(layer) :: layer1
66
type(network) :: net
@@ -120,6 +120,44 @@ program test_dropout_layer
120120
end if
121121
end block forward_pass
122122

123+
124+
training: block
125+
real :: x(10), y(5)
126+
real :: tolerance = 1e-3
127+
integer :: n
128+
integer, parameter :: num_iterations = 100000
129+
130+
call random_number(x)
131+
y = [0.1234, 0.2345, 0.3456, 0.4567, 0.5678]
132+
133+
net = network([ &
134+
input(10), &
135+
dropout(0.5, training=.true.), &
136+
dense(5) &
137+
])
138+
139+
do n = 1, num_iterations
140+
!select type(dropout_l => net % layers(2) % p)
141+
! type is(dropout_layer)
142+
! print *, dropout_l % training, dropout_l % mask
143+
!end select
144+
call net % forward(x)
145+
call net % backward(y)
146+
call net % update()
147+
!print *, n, net % predict(x)
148+
149+
if (all(abs(net % predict(x) - y) < tolerance)) exit
150+
end do
151+
152+
if (.not. n <= num_iterations) then
153+
write(stderr, '(a)') &
154+
'dense network should converge in simple training.. failed'
155+
ok = .false.
156+
end if
157+
158+
end block training
159+
160+
123161
if (ok) then
124162
print '(a)', 'test_dropout_layer: All tests passed.'
125163
else

0 commit comments

Comments
 (0)