Skip to content

Commit d37931d

Browse files
committed
update linear2d_layer tests
1 parent 8aa8278 commit d37931d

File tree

1 file changed

+8
-8
lines changed

1 file changed

+8
-8
lines changed

test/test_linear2d_layer.f90

Lines changed: 8 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -8,7 +8,7 @@ program test_linear2d_layer
88
[0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2,&
99
0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2, 0.1, 0.2],&
1010
[2, 3, 4]) ! first batch are 0.1, second 0.2
11-
real :: sample_gradient(2, 3, 1) = reshape([2., 2., 2., 3., 3., 3.], [2, 3, 1])
11+
real :: sample_gradient(2, 3, 1) = reshape([2., 3., 2., 3., 2., 3.], [2, 3, 1])
1212
type(linear2d_layer) :: linear
1313

1414
linear = linear2d_layer(batch_size=2, sequence_length=3, in_features=4, out_features=1)
@@ -55,17 +55,17 @@ subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
5555
real :: expected_dw_shape(2) = [4, 1]
5656
real :: expected_db_shape(1) = [1]
5757
real :: expected_gradient_flat(24) = [&
58-
0.200000003, 0.200000003, 0.200000003, 0.300000012,&
59-
0.300000012, 0.300000012, 0.200000003, 0.200000003,&
60-
0.200000003, 0.300000012, 0.300000012, 0.300000012,&
61-
0.200000003, 0.200000003, 0.200000003, 0.300000012,&
62-
0.300000012, 0.300000012, 0.200000003, 0.200000003,&
63-
0.200000003, 0.300000012, 0.300000012, 0.300000012&
58+
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
59+
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
60+
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
61+
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
62+
0.200000003, 0.300000012, 0.200000003, 0.300000012,&
63+
0.200000003, 0.300000012, 0.200000003, 0.300000012&
6464
]
6565
real :: expected_dw_flat(4)
6666
real :: expected_db(1) = [15.0]
6767

68-
expected_dw_flat = 2.29999995
68+
expected_dw_flat = 2.40000010
6969

7070
call linear % backward(input, gradient)
7171

0 commit comments

Comments
 (0)