11program cnn_mnist_1d
22
33 use nf, only: network, sgd, &
4- input, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
4+ input, conv1d, conv2d, maxpool1d, maxpool2d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
55 load_mnist, label_digits, softmax, relu
66
77 implicit none
@@ -12,7 +12,7 @@ program cnn_mnist_1d
1212 real , allocatable :: validation_images(:,:), validation_labels(:)
1313 real , allocatable :: testing_images(:,:), testing_labels(:)
1414 integer :: n
15- integer , parameter :: num_epochs = 10
15+ integer , parameter :: num_epochs = 25
1616
1717 call load_mnist(training_images, training_labels, &
1818 validation_images, validation_labels, &
@@ -21,9 +21,9 @@ program cnn_mnist_1d
2121 net = network([ &
2222 input(784 ), &
2323 reshape2d([28 ,28 ]), &
24- locally_connected_1d (filters= 8 , kernel_size= 3 , activation= relu()), &
24+ conv1d (filters= 8 , kernel_size= 3 , activation= relu()), &
2525 maxpool1d(pool_size= 2 ), &
26- locally_connected_1d (filters= 16 , kernel_size= 3 , activation= relu()), &
26+ conv1d (filters= 16 , kernel_size= 3 , activation= relu()), &
2727 maxpool1d(pool_size= 2 ), &
2828 dense(10 , activation= softmax()) &
2929 ])
@@ -37,7 +37,7 @@ program cnn_mnist_1d
3737 label_digits(training_labels), &
3838 batch_size= 16 , &
3939 epochs= 1 , &
40- optimizer= sgd(learning_rate= 0.003 ) &
40+ optimizer= sgd(learning_rate= 0.005 ) &
4141 )
4242
4343 print ' (a,i2,a,f5.2,a)' , ' Epoch ' , n, ' done, Accuracy: ' , accuracy( &
0 commit comments