Skip to content

Commit a272634

Browse files
committed
Timing info of dropout
1 parent a542e7c commit a272634

File tree

1 file changed

+55
-0
lines changed

1 file changed

+55
-0
lines changed

test/test_dropout_layer.f90

Lines changed: 55 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -177,6 +177,61 @@ program test_dropout_layer
177177

178178
end block training
179179

180+
! The following timing test is not part of the unit tests, but it's a good
181+
! way to see the performance difference between a network with and without
182+
! dropout.
183+
timing: block
184+
integer, parameter :: layer_size = 100
185+
integer, parameter :: num_iterations = 1000
186+
real :: x(layer_size), y(layer_size)
187+
integer :: n
188+
type(network) :: net1, net2
189+
real :: t1, t2
190+
real :: accumulated_time1 = 0
191+
real :: accumulated_time2 = 0
192+
193+
net1 = network([ &
194+
input(layer_size), &
195+
dense(layer_size), &
196+
dense(layer_size) &
197+
])
198+
199+
net2 = network([ &
200+
input(layer_size), &
201+
dense(layer_size), &
202+
dropout(0.5), &
203+
dense(layer_size) &
204+
])
205+
206+
call random_number(y)
207+
208+
! Network without dropout
209+
do n = 1, num_iterations
210+
call random_number(x)
211+
call cpu_time(t1)
212+
call net1 % forward(x)
213+
call net1 % backward(y)
214+
call net1 % update()
215+
call cpu_time(t2)
216+
accumulated_time1 = accumulated_time1 + (t2 - t1)
217+
end do
218+
219+
! Network with dropout
220+
do n = 1, num_iterations
221+
call random_number(x)
222+
call cpu_time(t1)
223+
call net2 % forward(x)
224+
call net2 % backward(y)
225+
call net2 % update()
226+
call cpu_time(t2)
227+
accumulated_time2 = accumulated_time2 + (t2 - t1)
228+
end do
229+
230+
! Uncomment the following prints to see the timing results.
231+
!print '(a, f9.6, a, f9.6, a)', 'No dropout time: ', accumulated_time1, ' seconds'
232+
!print '(a, f9.6, a, f9.6, a)', 'Dropout time: ', accumulated_time2, ' seconds'
233+
234+
end block timing
180235

181236
if (ok) then
182237
print '(a)', 'test_dropout_layer: All tests passed.'

0 commit comments

Comments
 (0)