Skip to content

Commit 5b3a4cb

Browse files
committed
update linear2d_layer tests for batch last
1 parent d37931d commit 5b3a4cb

File tree

1 file changed

+18
-18
lines changed

1 file changed

+18
-18
lines changed

test/test_linear2d_layer.f90

Lines changed: 18 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ program test_linear2d_layer
44
implicit none
55

66
logical :: ok = .true.
7-
real :: sample_input(2, 3, 4) = reshape(&
8-
[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2,&
9-
0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2],&
10-
[2, 3, 4]) ! first batch are 0.1, second 0.2
11-
real :: sample_gradient(2, 3, 1) = reshape([2., 3., 2., 3., 2., 3.], [2, 3, 1])
7+
real :: sample_input(3, 4, 2) = reshape(&
8+
[0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1, 0.1,&
9+
0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2, 0.2],&
10+
[3, 4, 2]) ! first batch are 0.1, second 0.2
11+
real :: sample_gradient(3, 1, 2) = reshape([2., 2., 2., 3., 3., 3.], [3, 1, 2])
1212
type(linear2d_layer) :: linear
1313

14-
linear = linear2d_layer(batch_size=2, sequence_length=3, in_features=4, out_features=1)
14+
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=1, batch_size=2)
1515
call linear % init([4])
1616

1717
call test_linear2d_layer_forward(linear, ok, sample_input)
@@ -21,11 +21,11 @@ program test_linear2d_layer
2121
subroutine test_linear2d_layer_forward(linear, ok, input)
2222
type(linear2d_layer), intent(in out) :: linear
2323
logical, intent(in out) :: ok
24-
real, intent(in) :: input(2, 3, 4)
24+
real, intent(in) :: input(3, 4, 2)
2525
real :: output_shape(3)
2626
real :: output_flat(6)
27-
real :: expected_shape(3) = [2, 3, 1]
28-
real :: expected_output_flat(6) = [0.15, 0.19, 0.15, 0.19, 0.15, 0.19]
27+
real :: expected_shape(3) = [3, 1, 2]
28+
real :: expected_output_flat(6) = [0.15, 0.15, 0.15, 0.19, 0.19, 0.19]
2929

3030
call linear % forward(input)
3131

@@ -44,23 +44,23 @@ end subroutine test_linear2d_layer_forward
4444
subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
4545
type(linear2d_layer), intent(in out) :: linear
4646
logical, intent(in out) :: ok
47-
real, intent(in) :: input(2, 3, 4)
48-
real, intent(in) :: gradient(2, 3, 1)
47+
real, intent(in) :: input(3, 4, 2)
48+
real, intent(in) :: gradient(3, 1, 2)
4949
real :: gradient_shape(3)
5050
real :: dw_shape(2)
5151
real :: db_shape(1)
5252
real :: gradient_flat(24)
5353
real :: dw_flat(4)
54-
real :: expected_gradient_shape(3) = [2, 3, 4]
54+
real :: expected_gradient_shape(3) = [3, 4, 2]
5555
real :: expected_dw_shape(2) = [4, 1]
5656
real :: expected_db_shape(1) = [1]
5757
real :: expected_gradient_flat(24) = [&
58-
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
59-
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
60-
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
61-
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
62-
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
63-
0.200000003, 0.300000012, 0.200000003, 0.300000012&
58+
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
59+
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
60+
0.200000003, 0.200000003, 0.200000003, 0.200000003,&
61+
0.300000012, 0.300000012, 0.300000012, 0.300000012,&
62+
0.300000012, 0.300000012, 0.300000012, 0.300000012,&
63+
0.300000012, 0.300000012, 0.300000012, 0.300000012&
6464
]
6565
real :: expected_dw_flat(4)
6666
real :: expected_db(1) = [15.0]

0 commit comments

Comments
 (0)