11program cnn_mnist_1d
22
33 use nf, only: network, sgd, &
4- input, conv1d, conv2d, maxpool1d, maxpool2d , flatten, dense, reshape, reshape2d, locally_connected_1d, &
4+ input, conv1d, maxpool1d, flatten, dense, reshape, reshape2d, locally_connected_1d, &
55 load_mnist, label_digits, softmax, relu
66
77 implicit none
@@ -12,7 +12,7 @@ program cnn_mnist_1d
1212 real , allocatable :: validation_images(:,:), validation_labels(:)
1313 real , allocatable :: testing_images(:,:), testing_labels(:)
1414 integer :: n
15- integer , parameter :: num_epochs = 25
15+ integer , parameter :: num_epochs = 250
1616
1717 call load_mnist(training_images, training_labels, &
1818 validation_images, validation_labels, &
@@ -37,7 +37,7 @@ program cnn_mnist_1d
3737 label_digits(training_labels), &
3838 batch_size= 16 , &
3939 epochs= 1 , &
40- optimizer= sgd(learning_rate= 0.005 ) &
40+ optimizer= sgd(learning_rate= 0.01 ) &
4141 )
4242
4343 print ' (a,i2,a,f5.2,a)' , ' Epoch ' , n, ' done, Accuracy: ' , accuracy( &
0 commit comments