@@ -4,14 +4,14 @@ program test_linear2d_layer
44 implicit none
55
66 logical :: ok = .true.
7- real :: sample_input(2 , 3 , 4 ) = reshape (&
8- [0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 ,&
9- 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 , 0.1 , 0.2 ],&
10- [2 , 3 , 4 ]) ! first batch are 0.1, second 0.2
11- real :: sample_gradient(2 , 3 , 1 ) = reshape ([2 ., 3 ., 2 ., 3 ., 2 ., 3 .], [2 , 3 , 1 ])
7+ real :: sample_input(3 , 4 , 2 ) = reshape (&
8+ [0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 , 0.1 ,&
9+ 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 , 0.2 ],&
10+ [3 , 4 , 2 ]) ! first batch are 0.1, second 0.2
11+ real :: sample_gradient(3 , 1 , 2 ) = reshape ([2 ., 2 ., 2 ., 3 ., 3 ., 3 .], [3 , 1 , 2 ])
1212 type (linear2d_layer) :: linear
1313
14- linear = linear2d_layer(batch_size = 2 , sequence_length= 3 , in_features= 4 , out_features= 1 )
14+ linear = linear2d_layer(sequence_length= 3 , in_features= 4 , out_features= 1 , batch_size = 2 )
1515 call linear % init([4 ])
1616
1717 call test_linear2d_layer_forward(linear, ok, sample_input)
@@ -21,11 +21,11 @@ program test_linear2d_layer
2121 subroutine test_linear2d_layer_forward (linear , ok , input )
2222 type (linear2d_layer), intent (in out ) :: linear
2323 logical , intent (in out ) :: ok
24- real , intent (in ) :: input(2 , 3 , 4 )
24+ real , intent (in ) :: input(3 , 4 , 2 )
2525 real :: output_shape(3 )
2626 real :: output_flat(6 )
27- real :: expected_shape(3 ) = [2 , 3 , 1 ]
28- real :: expected_output_flat(6 ) = [0.15 , 0.19 , 0.15 , 0.19 , 0.15 , 0.19 ]
27+ real :: expected_shape(3 ) = [3 , 1 , 2 ]
28+ real :: expected_output_flat(6 ) = [0.15 , 0.15 , 0.15 , 0.19 , 0.19 , 0.19 ]
2929
3030 call linear % forward(input)
3131
@@ -44,23 +44,23 @@ end subroutine test_linear2d_layer_forward
4444 subroutine test_linear2d_layer_backward (linear , ok , input , gradient )
4545 type (linear2d_layer), intent (in out ) :: linear
4646 logical , intent (in out ) :: ok
47- real , intent (in ) :: input(2 , 3 , 4 )
48- real , intent (in ) :: gradient(2 , 3 , 1 )
47+ real , intent (in ) :: input(3 , 4 , 2 )
48+ real , intent (in ) :: gradient(3 , 1 , 2 )
4949 real :: gradient_shape(3 )
5050 real :: dw_shape(2 )
5151 real :: db_shape(1 )
5252 real :: gradient_flat(24 )
5353 real :: dw_flat(4 )
54- real :: expected_gradient_shape(3 ) = [2 , 3 , 4 ]
54+ real :: expected_gradient_shape(3 ) = [3 , 4 , 2 ]
5555 real :: expected_dw_shape(2 ) = [4 , 1 ]
5656 real :: expected_db_shape(1 ) = [1 ]
5757 real :: expected_gradient_flat(24 ) = [&
58- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 ,&
59- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 ,&
60- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 ,&
61- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 ,&
62- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 ,&
63- 0.200000003 , 0.300000012 , 0.200000003 , 0.300000012 &
58+ 0.200000003 , 0.200000003 , 0.200000003 , 0.200000003 ,&
59+ 0.200000003 , 0.200000003 , 0.200000003 , 0.200000003 ,&
60+ 0.200000003 , 0.200000003 , 0.200000003 , 0.200000003 ,&
61+ 0.300000012 , 0.300000012 , 0.300000012 , 0.300000012 ,&
62+ 0.300000012 , 0.300000012 , 0.300000012 , 0.300000012 ,&
63+ 0.300000012 , 0.300000012 , 0.300000012 , 0.300000012 &
6464 ]
6565 real :: expected_dw_flat(4 )
6666 real :: expected_db(1 ) = [15.0 ]
0 commit comments