@@ -24,6 +24,14 @@ program test_layernorm
2424 end if
2525
2626contains
27+ function allclose (x , y ) result(res)
28+ real , intent (in ) :: x(:)
29+ real , intent (in ) :: y(:)
30+ logical :: res
31+
32+ res = all (abs (x - y) <= (1e-06 + 1e-05 * abs (y)))
33+ end function allclose
34+
2735 subroutine test_layernorm_forward (layernorm , input , ok )
2836 type (layernorm_layer), intent (in out ) :: layernorm
2937 real , intent (in out ) :: input(:, :)
@@ -44,7 +52,7 @@ subroutine test_layernorm_forward(layernorm, input, ok)
4452 write (stderr, ' (a)' ) ' forward returned incorrect shape.. failed'
4553 end if
4654 output_flat = reshape (layernorm % output, shape (output_flat))
47- if (.not. all (output_flat.eq. expected_output_flat)) then
55+ if (.not. allclose (output_flat, expected_output_flat)) then
4856 ok = .false.
4957 write (stderr, ' (a)' ) ' forward returned incorrect values.. failed'
5058 end if
@@ -67,7 +75,7 @@ subroutine test_layernorm_backward(layernorm, input, gradient, ok)
6775 real :: d_gamma(4 )
6876 real :: expected_d_gamma(4 ) = [0.765904069 , 0.175162792 , 2.16362262 , - 4.57002449 ]
6977 real :: d_beta(4 )
70- real :: expected_d_beta(4 ) = [5.09999990 , 6.09999990 , 2.19999981 , 6.09999990 ]
78+ real :: expected_d_beta(4 ) = [5.1 , 6.1 , 2.2 , 6.1 ]
7179
7280 call layernorm % backward(input, gradient)
7381
@@ -77,16 +85,16 @@ subroutine test_layernorm_backward(layernorm, input, gradient, ok)
7785 write (stderr, ' (a)' ) ' backward returned incorrect gradient shape.. failed'
7886 end if
7987 gradient_flat = reshape (layernorm % gradient, shape (gradient_flat))
80- if (.not. all (gradient_flat.eq. expected_gradient_flat)) then
88+ if (.not. allclose (gradient_flat, expected_gradient_flat)) then
8189 ok = .false.
8290 write (stderr, ' (a)' ) ' backward returned incorrect gradient values.. failed'
8391 end if
8492
85- if (.not. all (layernorm % d_gamma.eq. expected_d_gamma)) then
93+ if (.not. allclose (layernorm % d_gamma, expected_d_gamma)) then
8694 ok = .false.
8795 write (stderr, ' (a)' ) ' backward returned incorrect d_gamma values.. failed'
8896 end if
89- if (.not. all (layernorm % d_beta.eq. expected_d_beta)) then
97+ if (.not. allclose (layernorm % d_beta, expected_d_beta)) then
9098 ok = .false.
9199 write (stderr, ' (a)' ) ' backward returned incorrect d_beta values.. failed'
92100 end if
@@ -135,7 +143,7 @@ subroutine test_layernorm_gradients(input, gradient, ok)
135143 call layernorm % forward(input)
136144
137145 updated_output = reshape (layernorm % output, [12 ])
138- if (.not. all (updated_output.eq. expected_updated_output)) then
146+ if (.not. allclose (updated_output, expected_updated_output)) then
139147 ok = .false.
140148 write (stderr, ' (a)' ) ' incorrect output after parameters update.. failed'
141149 end if
0 commit comments