Skip to content

Commit 0667000

Browse files
committed
layernorm: update tests
1 parent ccc180e commit 0667000

File tree

1 file changed

+14
-6
lines changed

1 file changed

+14
-6
lines changed

test/test_layernorm.f90

Lines changed: 14 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -24,6 +24,14 @@ program test_layernorm
2424
end if
2525

2626
contains
27+
function allclose(x, y) result(res)
28+
real, intent(in) :: x(:)
29+
real, intent(in) :: y(:)
30+
logical :: res
31+
32+
res = all(abs(x - y) <= (1e-06 + 1e-05 * abs(y)))
33+
end function allclose
34+
2735
subroutine test_layernorm_forward(layernorm, input, ok)
2836
type(layernorm_layer), intent(in out) :: layernorm
2937
real, intent(in out) :: input(:, :)
@@ -44,7 +52,7 @@ subroutine test_layernorm_forward(layernorm, input, ok)
4452
write(stderr, '(a)') 'forward returned incorrect shape.. failed'
4553
end if
4654
output_flat = reshape(layernorm % output, shape(output_flat))
47-
if (.not. all(output_flat.eq.expected_output_flat)) then
55+
if (.not. allclose(output_flat, expected_output_flat)) then
4856
ok = .false.
4957
write(stderr, '(a)') 'forward returned incorrect values.. failed'
5058
end if
@@ -67,7 +75,7 @@ subroutine test_layernorm_backward(layernorm, input, gradient, ok)
6775
real :: d_gamma(4)
6876
real :: expected_d_gamma(4) = [0.765904069, 0.175162792, 2.16362262, -4.57002449]
6977
real :: d_beta(4)
70-
real :: expected_d_beta(4) = [5.09999990, 6.09999990, 2.19999981, 6.09999990]
78+
real :: expected_d_beta(4) = [5.1, 6.1, 2.2, 6.1]
7179

7280
call layernorm % backward(input, gradient)
7381

@@ -77,16 +85,16 @@ subroutine test_layernorm_backward(layernorm, input, gradient, ok)
7785
write(stderr, '(a)') 'backward returned incorrect gradient shape.. failed'
7886
end if
7987
gradient_flat = reshape(layernorm % gradient, shape(gradient_flat))
80-
if (.not. all(gradient_flat.eq.expected_gradient_flat)) then
88+
if (.not. allclose(gradient_flat, expected_gradient_flat)) then
8189
ok = .false.
8290
write(stderr, '(a)') 'backward returned incorrect gradient values.. failed'
8391
end if
8492

85-
if (.not. all(layernorm % d_gamma.eq.expected_d_gamma)) then
93+
if (.not. allclose(layernorm % d_gamma, expected_d_gamma)) then
8694
ok = .false.
8795
write(stderr, '(a)') 'backward returned incorrect d_gamma values.. failed'
8896
end if
89-
if (.not. all(layernorm % d_beta.eq.expected_d_beta)) then
97+
if (.not. allclose(layernorm % d_beta, expected_d_beta)) then
9098
ok = .false.
9199
write(stderr, '(a)') 'backward returned incorrect d_beta values.. failed'
92100
end if
@@ -135,7 +143,7 @@ subroutine test_layernorm_gradients(input, gradient, ok)
135143
call layernorm % forward(input)
136144

137145
updated_output = reshape(layernorm % output, [12])
138-
if (.not. all(updated_output.eq.expected_updated_output)) then
146+
if (.not. allclose(updated_output, expected_updated_output)) then
139147
ok = .false.
140148
write(stderr, '(a)') 'incorrect output after parameters update.. failed'
141149
end if

0 commit comments

Comments
 (0)