1+ program cnn_mnist_1d
2+
3+ use nf, only: network, sgd, &
4+ input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected1d, &
5+ load_mnist, label_digits, softmax, relu
6+
7+ implicit none
8+
9+ type (network) :: net
10+
11+ real , allocatable :: training_images(:,:), training_labels(:)
12+ real , allocatable :: validation_images(:,:), validation_labels(:)
13+ real , allocatable :: testing_images(:,:), testing_labels(:)
14+ integer :: n
15+ integer , parameter :: num_epochs = 250
16+
17+ call load_mnist(training_images, training_labels, &
18+ validation_images, validation_labels, &
19+ testing_images, testing_labels)
20+
21+ net = network([ &
22+ input(784 ), &
23+ reshape2d([28 , 28 ]), &
24+ locally_connected1d(filters= 8 , kernel_size= 3 , activation= relu()), &
25+ maxpool1d(pool_size= 2 ), &
26+ locally_connected1d(filters= 16 , kernel_size= 3 , activation= relu()), &
27+ maxpool1d(pool_size= 2 ), &
28+ dense(10 , activation= softmax()) &
29+ ])
30+
31+ call net % print_info()
32+
33+ epochs: do n = 1 , num_epochs
34+
35+ call net % train( &
36+ training_images, &
37+ label_digits(training_labels), &
38+ batch_size= 16 , &
39+ epochs= 1 , &
40+ optimizer= sgd(learning_rate= 0.01 ) &
41+ )
42+
43+ print ' (a,i2,a,f5.2,a)' , ' Epoch ' , n, ' done, Accuracy: ' , accuracy( &
44+ net, validation_images, label_digits(validation_labels)) * 100 , ' %'
45+
46+ end do epochs
47+
48+ print ' (a,f5.2,a)' , ' Testing accuracy: ' , &
49+ accuracy(net, testing_images, label_digits(testing_labels)) * 100 , ' %'
50+
51+ contains
52+
53+ real function accuracy (net , x , y )
54+ type (network), intent (in out ) :: net
55+ real , intent (in ) :: x(:,:), y(:,:)
56+ integer :: i, good
57+ good = 0
58+ do i = 1 , size (x, dim= 2 )
59+ if (all (maxloc (net % predict(x(:,i))) == maxloc (y(:,i)))) then
60+ good = good + 1
61+ end if
62+ end do
63+ accuracy = real (good) / size (x, dim= 2 )
64+ end function accuracy
65+
66+ end program cnn_mnist_1d
67+
0 commit comments