|
| 1 | +program test_reshape_layer |
| 2 | + |
| 3 | + use iso_fortran_env, only: stderr => error_unit |
| 4 | + use nf, only: input, network, reshape_generalized ! Check if this is the correct function |
| 5 | + use nf_datasets, only: download_and_unpack, keras_reshape_url |
| 6 | + |
| 7 | + implicit none |
| 8 | + |
| 9 | + type(network) :: net |
| 10 | + real, allocatable :: sample_input(:), output(:,:,:) |
| 11 | + integer, parameter :: output_shape_first(2) = [64, 32] |
| 12 | + integer, parameter :: output_shape_second(6) = [8, 8, 4, 2, 2, 2] |
| 13 | + integer, parameter :: output_shape_third(5) = [4, 4, 4, 4, 8] |
| 14 | + integer :: input_size ! Removed parameter |
| 15 | + character(*), parameter :: keras_reshape_path = 'keras_reshape.h5' |
| 16 | + logical :: ok = .true. |
| 17 | + integer :: i |
| 18 | + integer, dimension(:), allocatable :: output_shape |
| 19 | + |
| 20 | + ! Test multiple reshape configurations |
| 21 | + do i = 1, 3 |
| 22 | + select case (i) |
| 23 | + case (1) |
| 24 | + output_shape = output_shape_first |
| 25 | + case (2) |
| 26 | + output_shape = output_shape_second |
| 27 | + case (3) |
| 28 | + output_shape = output_shape_third |
| 29 | + end select |
| 30 | + |
| 31 | + ! Update input size |
| 32 | + input_size = product(output_shape) |
| 33 | + |
| 34 | + ! Create network with reshape_generalized |
| 35 | + net = network([ & |
| 36 | + input(input_size), & |
| 37 | + reshape_generalized(output_shape) & ! Make sure the function name is correct |
| 38 | + ]) |
| 39 | + |
| 40 | + if (.not. size(net % layers) == 2) then |
| 41 | + write(stderr, '(a, i0)') 'Test case ', i, ': the network should have 2 layers.. failed' |
| 42 | + ok = .false. |
| 43 | + end if |
| 44 | + |
| 45 | + ! Initialize test data |
| 46 | + allocate(sample_input(input_size)) |
| 47 | + call random_number(sample_input) |
| 48 | + |
| 49 | + ! Allocate output correctly before reshaping |
| 50 | + allocate(output(output_shape(1), output_shape(2), output_shape(3))) |
| 51 | + output = reshape(sample_input, shape(output)) |
| 52 | + |
| 53 | + ! Check shape |
| 54 | + if (.not. all(shape(output) == output_shape)) then |
| 55 | + write(stderr, '(a, i0)') 'Test case ', i, ': the reshape layer produces expected output shape.. failed' |
| 56 | + ok = .false. |
| 57 | + end if |
| 58 | + |
| 59 | + ! Check values |
| 60 | + if (.not. all(output == reshape(sample_input, shape(output)))) then |
| 61 | + write(stderr, '(a, i0)') 'Test case ', i, ': the reshape layer produces expected output values.. failed' |
| 62 | + ok = .false. |
| 63 | + end if |
| 64 | + |
| 65 | + ! Deallocate for next test case |
| 66 | + deallocate(sample_input, output) |
| 67 | + end do |
| 68 | + |
| 69 | + ! Final test result |
| 70 | + if (ok) then |
| 71 | + print '(a)', 'test_reshape_generalized_layer: All tests passed.' |
| 72 | + else |
| 73 | + write(stderr, '(a)') 'test_reshape_generalized_layer: One or more tests failed.' |
| 74 | + stop 1 |
| 75 | + end if |
| 76 | + |
| 77 | +end program test_reshape_layer |
0 commit comments