@@ -112,7 +112,7 @@ program test_dropout_layer
112112 ! Now we're gonna run the forward pass and check that the dropout indeed
113113 ! drops according to the requested dropout rate.
114114 forward_pass: block
115- real :: input_data(5 )
115+ real :: input_data(4 )
116116 real :: output_data(size (input_data))
117117 integer :: n
118118
@@ -121,43 +121,49 @@ program test_dropout_layer
121121 dropout(0.5 ) &
122122 ])
123123
124- call random_number (input_data)
125124 do n = 1 , 10000
126- output_data = net % predict(input_data)
125+
126+ call random_number (input_data)
127+ call net % forward(input_data)
128+
127129 ! Check that sum of output matches sum of input within small tolerance
128- if (abs (sum (output_data) - sum (input_data)) > 1e-6 ) then
129- ok = .false.
130- exit
131- end if
130+ select type (layer1_p = > net % layers(2 ) % p)
131+ type is (dropout_layer)
132+ if (abs (sum (layer1_p % output) - sum (input_data)) > 1e-6 ) then
133+ ok = .false.
134+ exit
135+ end if
136+ end select
137+
132138 end do
133- if (.not. ok) then
134- write (stderr, ' (a)' ) ' dropout layer output sum should match input sum within tolerance.. failed'
135- end if
139+
140+ if (.not. ok) write (stderr, ' (a)' ) &
141+ ' dropout layer output sum should match input sum within tolerance.. failed'
142+
136143 end block forward_pass
137144
138145
139146 training: block
140- real :: x(100 ), y(5 )
141- real :: tolerance = 1e-3
147+ real :: x(20 ), y(5 )
148+ real :: tolerance = 1e-4
142149 integer :: n
143- integer , parameter :: num_iterations = 10000
150+ integer , parameter :: num_iterations = 100000
144151
145152 call random_number (x)
146153 y = [0.12345 , 0.23456 , 0.34567 , 0.45678 , 0.56789 ]
147154
148155 net = network([ &
149- input(100 ), &
150- dropout(0.5 ), &
156+ input(20 ), &
157+ dense(20 ), &
158+ dropout(0.2 ), &
151159 dense(5 ) &
152160 ])
153161
154162 do n = 1 , num_iterations
155163 call net % forward(x)
156164 call net % backward(y)
157165 call net % update()
158-
159166 if (all (abs (net % predict(x) - y) < tolerance)) exit
160-
161167 end do
162168
163169 if (.not. n <= num_iterations) then
0 commit comments