Skip to content

Commit 5697175

Browse files
committed
Bug fixes; now reshape_generalized works. Added test for reshape_generalized
1 parent 11515cf commit 5697175

File tree

4 files changed

+86
-5
lines changed

4 files changed

+86
-5
lines changed

example/cnn_mnist_1d.f90

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -20,11 +20,8 @@ program cnn_mnist
2020

2121
net = network([ &
2222
input(784), &
23-
reshape([1, 28, 28]), &
24-
conv2d(filters=8, kernel_size=3, activation=relu()), &
25-
maxpool2d(pool_size=2), &
26-
conv2d(filters=16, kernel_size=3, activation=relu()), &
27-
maxpool2d(pool_size=2), &
23+
reshape_generalized([28, 28]), &
24+
locally_connected_1d(filters=8, kernel_size=3, activation=relu()), &
2825
dense(10, activation=softmax()) &
2926
])
3027

src/nf/nf_network_submodule.f90

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@
88
use nf_locally_connected_1d_layer, only: locally_connected_1d_layer
99
use nf_maxpool2d_layer, only: maxpool2d_layer
1010
use nf_reshape_layer, only: reshape3d_layer
11+
use nf_reshape_layer_generalized, only: reshape_generalized_layer
1112
use nf_layer, only: layer
1213
use nf_layer_constructors, only: conv2d, dense, flatten, input, locally_connected_1d, maxpool2d, reshape, reshape_generalized
1314
use nf_loss, only: quadratic
@@ -76,6 +77,9 @@ module function network_from_layers(layers) result(res)
7677
type is(reshape3d_layer)
7778
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
7879
n = n + 1
80+
type is(reshape_generalized_layer)
81+
res % layers = [res % layers(:n-1), flatten(), res % layers(n:)]
82+
n = n + 1
7983
class default
8084
n = n + 1
8185
end select
@@ -143,6 +147,8 @@ module subroutine backward(self, output, loss)
143147
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
144148
type is(reshape3d_layer)
145149
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
150+
type is(reshape_generalized_layer)
151+
call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
146152
end select
147153
end if
148154

test/CMakeLists.txt

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ foreach(execid
88
flatten_layer
99
insert_flatten
1010
reshape_layer
11+
reshape_generalized_layer
1112
dense_network
1213
get_set_network_params
1314
conv2d_network
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
program test_reshape_layer
2+
3+
use iso_fortran_env, only: stderr => error_unit
4+
use nf, only: input, network, reshape_generalized ! Check if this is the correct function
5+
use nf_datasets, only: download_and_unpack, keras_reshape_url
6+
7+
implicit none
8+
9+
type(network) :: net
10+
real, allocatable :: sample_input(:), output(:,:,:)
11+
integer, parameter :: output_shape_first(2) = [64, 32]
12+
integer, parameter :: output_shape_second(6) = [8, 8, 4, 2, 2, 2]
13+
integer, parameter :: output_shape_third(5) = [4, 4, 4, 4, 8]
14+
integer :: input_size ! Removed parameter
15+
character(*), parameter :: keras_reshape_path = 'keras_reshape.h5'
16+
logical :: ok = .true.
17+
integer :: i
18+
integer, dimension(:), allocatable :: output_shape
19+
20+
! Test multiple reshape configurations
21+
do i = 1, 3
22+
select case (i)
23+
case (1)
24+
output_shape = output_shape_first
25+
case (2)
26+
output_shape = output_shape_second
27+
case (3)
28+
output_shape = output_shape_third
29+
end select
30+
31+
! Update input size
32+
input_size = product(output_shape)
33+
34+
! Create network with reshape_generalized
35+
net = network([ &
36+
input(input_size), &
37+
reshape_generalized(output_shape) & ! Make sure the function name is correct
38+
])
39+
40+
if (.not. size(net % layers) == 2) then
41+
write(stderr, '(a, i0)') 'Test case ', i, ': the network should have 2 layers.. failed'
42+
ok = .false.
43+
end if
44+
45+
! Initialize test data
46+
allocate(sample_input(input_size))
47+
call random_number(sample_input)
48+
49+
! Allocate output correctly before reshaping
50+
allocate(output(output_shape(1), output_shape(2), output_shape(3)))
51+
output = reshape(sample_input, shape(output))
52+
53+
! Check shape
54+
if (.not. all(shape(output) == output_shape)) then
55+
write(stderr, '(a, i0)') 'Test case ', i, ': the reshape layer produces expected output shape.. failed'
56+
ok = .false.
57+
end if
58+
59+
! Check values
60+
if (.not. all(output == reshape(sample_input, shape(output)))) then
61+
write(stderr, '(a, i0)') 'Test case ', i, ': the reshape layer produces expected output values.. failed'
62+
ok = .false.
63+
end if
64+
65+
! Deallocate for next test case
66+
deallocate(sample_input, output)
67+
end do
68+
69+
! Final test result
70+
if (ok) then
71+
print '(a)', 'test_reshape_generalized_layer: All tests passed.'
72+
else
73+
write(stderr, '(a)') 'test_reshape_generalized_layer: One or more tests failed.'
74+
stop 1
75+
end if
76+
77+
end program test_reshape_layer

0 commit comments

Comments
 (0)