Skip to content

Commit 11515cf

Browse files
committed
bug fixes; integrating reshape_generalized in environment
1 parent b4e2303 commit 11515cf

File tree

5 files changed

+22
-7
lines changed

5 files changed

+22
-7
lines changed

example/cnn_mnist_1d.f90

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
program cnn_mnist
22

33
use nf, only: network, sgd, &
4-
input, conv2d, maxpool2d, flatten, dense, reshape, locally_connected_1d, &
4+
input, conv2d, maxpool2d, flatten, dense, reshape, reshape_generalized, locally_connected_1d, &
55
load_mnist, label_digits, softmax, relu
66

77
implicit none
@@ -20,8 +20,11 @@ program cnn_mnist
2020

2121
net = network([ &
2222
input(784), &
23-
reshape([1,28,28]), &
24-
locally_connected_1d(filters=8, kernel_size=2, activation=relu()), &
23+
reshape([1, 28, 28]), &
24+
conv2d(filters=8, kernel_size=3, activation=relu()), &
25+
maxpool2d(pool_size=2), &
26+
conv2d(filters=16, kernel_size=3, activation=relu()), &
27+
maxpool2d(pool_size=2), &
2528
dense(10, activation=softmax()) &
2629
])
2730

src/nf/nf_layer_constructors.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -194,7 +194,7 @@ module function reshape(output_shape) result(res)
194194
end function reshape
195195

196196
module function reshape_generalized(output_shape) result(res)
197-
integer, intent(in) :: output_shape
197+
integer, intent(in) :: output_shape(:)
198198
type(layer) :: res
199199

200200
end function reshape_generalized

src/nf/nf_layer_submodule.f90

Lines changed: 13 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,6 +9,7 @@
99
use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
1010
use nf_maxpool2d_layer, only: maxpool2d_layer
1111
use nf_reshape_layer, only: reshape3d_layer
12+
use nf_reshape_layer_generalized, only: reshape_generalized_layer
1213
use nf_optimizers, only: optimizer_base_type
1314

1415
contains
@@ -293,6 +294,8 @@ elemental module function get_num_params(self) result(num_params)
293294
num_params = 0
294295
type is (reshape3d_layer)
295296
num_params = 0
297+
type is (reshape_generalized_layer)
298+
num_params = 0
296299
class default
297300
error stop 'Unknown layer type.'
298301
end select
@@ -318,6 +321,8 @@ module function get_params(self) result(params)
318321
! No parameters to get.
319322
type is (reshape3d_layer)
320323
! No parameters to get.
324+
type is (reshape_generalized_layer)
325+
! No parameters to get.
321326
class default
322327
error stop 'Unknown layer type.'
323328
end select
@@ -343,6 +348,8 @@ module function get_gradients(self) result(gradients)
343348
! No gradients to get.
344349
type is (reshape3d_layer)
345350
! No gradients to get.
351+
type is (reshape_generalized_layer)
352+
! No gradients to get.
346353
class default
347354
error stop 'Unknown layer type.'
348355
end select
@@ -399,7 +406,12 @@ module subroutine set_params(self, params)
399406
! No parameters to set.
400407
write(stderr, '(a)') 'Warning: calling set_params() ' &
401408
// 'on a zero-parameter layer; nothing to do.'
402-
409+
410+
type is (reshape_generalized_layer)
411+
! No parameters to set.
412+
write(stderr, '(a)') 'Warning: calling set_params() ' &
413+
// 'on a zero-parameter layer; nothing to do.'
414+
403415
class default
404416
error stop 'Unknown layer type.'
405417
end select

src/nf/nf_network_submodule.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
1111
use nf_layer, only: layer
12-
use nf_layer_constructors, only: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape
12+
use nf_layer_constructors, only: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape, reshape_generalized
1313
use nf_loss, only: quadratic
1414
use nf_optimizers, only: optimizer_base_type, sgd
1515
use nf_parallel, only: tile_indices

src/nf/nf_reshape_generalized_submodule.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@ pure module function reshape_layer_cons(output_shape) result(res)
1111
type(reshape_generalized_layer) :: res
1212

1313
! Check if output_shape is scalar (size 1)
14-
if (size(output_shape) == 1) then
14+
if (size(output_shape) == 0) then
1515
allocate(res % output_shape(1))
1616
res % output_shape = output_shape
1717
else

0 commit comments

Comments
 (0)