 program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_optimizers, only: sgd
   implicit none
@@ -13,7 +14,7 @@ program test_multihead_attention_layer
   real :: output(3, 2, 2)
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
-  call attention % init([0])
+  call attention % init_base([0])
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
@@ -24,6 +25,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_backward(attention, ok)
   call test_multihead_attention_update_gradients(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
+  call test_self_attention(ok)
 
 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -139,7 +141,7 @@ subroutine test_multihead_attention_forward(attention, ok)
     type(multihead_attention_layer), intent(in out) :: attention
     logical, intent(in out) :: ok
     real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
-    real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size)
+    real :: output(attention % sequence_length, attention % model_dimension)
     real :: output_flat(12)
     integer :: output_shape(2)
     integer :: attn_weights_shape(3)
@@ -194,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
     call random_number(input)
 
     attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
-    call attention % init([0])
+    call attention % init_base([0])
 
     call attention % common_forward(input, input, input)
 
@@ -283,4 +285,37 @@ subroutine test_multihead_attention_update_gradients(attention, ok)
       write(stderr, '(a)') 'incorrect output after parameters update.. failed'
     end if
   end subroutine test_multihead_attention_update_gradients
+
+  subroutine test_self_attention(ok)
+    logical, intent(in out) :: ok
+    type(self_attention_layer) :: attention
+    real :: input(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3])
+    real :: output(2, 3)
+    real :: output_flat(6)
+    real :: expected_output_flat(6) = [&
+      0.772716165, 0.577548742, 0.772716165, 0.577548742, 0.772716165, 0.577548742 &
+    ]
+    real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3])
+    real :: gradient_flat(6)
+    real :: expected_gradient_flat(6) = [&
+      0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040 &
+    ]
+
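+    ! single-head self-attention over a sequence of length 2 with model dimension 3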
+    attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
+    call attention % init([0])
+
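+    ! forward pass: flatten the layer output and compare against the precomputed reference values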
+    call attention % forward(input)
+    output_flat = reshape(attention % output, shape(output_flat))
+    if (.not. all(output_flat .eq. expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward returned incorrect values.. failed'
+    end if
+
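+    ! backward pass: check the input gradient against the precomputed reference values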
+    call attention % backward(input, gradient)
+    gradient_flat = reshape(attention % gradient, shape(gradient_flat))
+    if (.not. all(gradient_flat .eq. expected_gradient_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect values.. failed'
+    end if
+  end subroutine test_self_attention
 end program test_multihead_attention_layer