
Commit eef37e3

multihead_attention: self attention
1 parent 90d3d6c commit eef37e3

2 files changed: +116 -3 lines changed

src/nf/nf_self_attention_layer.f90

Lines changed: 78 additions & 0 deletions
@@ -0,0 +1,78 @@
module nf_self_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_activation, only: softmax
  use nf_linear2d_layer, only: linear2d_layer
  use nf_multihead_attention_layer, only: multihead_attention_layer

  implicit none

  type, extends(multihead_attention_layer) :: self_attention_layer
    real, allocatable :: gradient(:, :)
  contains
    procedure :: forward
    procedure :: backward
    procedure :: init
  end type self_attention_layer

  interface self_attention_layer
    module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
      !! This function returns the `self_attention_layer` instance.
      integer, intent(in) :: sequence_length, model_dimension, n_heads
      type(self_attention_layer) :: res
    end function self_attention_layer_cons
  end interface self_attention_layer

contains
  module function self_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
    !! This function returns the `self_attention_layer` instance.
    integer, intent(in) :: sequence_length, model_dimension, n_heads
    type(self_attention_layer) :: res
    res % sequence_length = sequence_length
    res % model_dimension = model_dimension
    res % n_heads = n_heads

    if (mod(model_dimension, n_heads) /= 0) then
      write(stderr, '(a)') 'Model dimension must be divisible by the number of heads'
      error stop
    end if
    res % head_size = model_dimension / n_heads

    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    call res % query_layer % init([0])
    call res % key_layer % init([0])
    call res % value_layer % init([0])
    call res % output_layer % init([0])

    res % softmax_func = softmax()
  end function self_attention_layer_cons

  module subroutine backward(self, input, gradient)
    class(self_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :)
    real, intent(in) :: gradient(:, :)

    call self % common_backward(input, gradient)
    self % gradient = &
        self % query_layer % gradient &
        + self % key_layer % gradient &
        + self % value_layer % gradient
  end subroutine backward

  module subroutine forward(self, input)
    class(self_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :)

    call self % common_forward(input, input, input)
  end subroutine forward

  module subroutine init(self, input_shape)
    class(self_attention_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    call self % init_base(input_shape)
    allocate(self % gradient(self % sequence_length, self % model_dimension))
  end subroutine init
end module nf_self_attention_layer
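Why `backward` sums three gradients: the forward pass feeds the same input to the query, key, and value projections (`common_forward(input, input, input)`), so by the chain rule the gradient with respect to that shared input is the sum of the gradients flowing back through the three projections, which is what the subroutine accumulates into `self % gradient`. Writing X for the shared input and X_Q, X_K, X_V for the inputs seen by the three projections (notation introduced here, not part of the commit):

\[
  \frac{\partial L}{\partial X}
    = \frac{\partial L}{\partial X_Q}
    + \frac{\partial L}{\partial X_K}
    + \frac{\partial L}{\partial X_V},
  \qquad X_Q = X_K = X_V = X .
\]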

test/test_multihead_attention_layer.f90

Lines changed: 38 additions & 3 deletions
@@ -1,6 +1,7 @@
 program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
+  use nf_self_attention_layer, only: self_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_optimizers, only: sgd
   implicit none
@@ -13,7 +14,7 @@ program test_multihead_attention_layer
   real :: output(3, 2, 2)
 
   attention = multihead_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
-  call attention % init([0])
+  call attention % init_base([0])
 
   call test_multihead_attention_split_heads(attention, sample_input, ok, split_heads_output)
   call test_multihead_attention_create_attention_matrix(attention, split_heads_output, ok)
@@ -24,6 +25,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_backward(attention, ok)
   call test_multihead_attention_update_gradients(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
+  call test_self_attention(ok)
 
 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -139,7 +141,7 @@ subroutine test_multihead_attention_forward(attention, ok)
     type(multihead_attention_layer), intent(in out) :: attention
     logical, intent(in out) :: ok
     real :: input(3, 4) = reshape([0.0, 10.1, 0.2, 10.3, 0.4, 10.5, 0.6, 10.7, 10.8, 0.9, 0.11, 0.12], [3, 4])
-    real :: output(attention % sequence_length, attention % model_dimension, attention % batch_size)
+    real :: output(attention % sequence_length, attention % model_dimension)
     real :: output_flat(12)
     integer :: output_shape(2)
     integer :: attn_weights_shape(3)
@@ -194,7 +196,7 @@ subroutine test_multihead_attention_forward_reallife_shape(ok)
     call random_number(input)
 
     attention = multihead_attention_layer(sequence_length=148, model_dimension=512, n_heads=8)
-    call attention % init([0])
+    call attention % init_base([0])
 
     call attention % common_forward(input, input, input)
 
@@ -283,4 +285,37 @@ subroutine test_multihead_attention_update_gradients(attention, ok)
       write(stderr, '(a)') 'incorrect output after parameters update.. failed'
     end if
   end subroutine test_multihead_attention_update_gradients
+
+  subroutine test_self_attention(ok)
+    logical, intent(in out) :: ok
+    type(self_attention_layer) :: attention
+    real :: input(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3])
+    real :: output(2, 3)
+    real :: output_flat(6)
+    real :: expected_output_flat(6) = [&
+        0.772716165, 0.577548742, 0.772716165, 0.577548742, 0.772716165, 0.577548742&
+    ]
+    real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3])
+    real :: gradient_flat(6)
+    real :: expected_gradient_flat(6) = [&
+        0.350671142, 0.607403040, 0.350671142, 0.607403040, 0.350671142, 0.607403040&
+    ]
+
+    attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
+    call attention % init([0])
+
+    call attention % forward(input)
+    output_flat = reshape(attention % output, shape(output_flat))
+    if (.not. all(output_flat.eq.expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward returned incorrect values.. failed'
+    end if
+
+    call attention % backward(input, gradient)
+    gradient_flat = reshape(attention % gradient, shape(gradient_flat))
+    if (.not. all(gradient_flat.eq.expected_gradient_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect values.. failed'
+    end if
+  end subroutine test_self_attention
 end program test_multihead_attention_layer
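For context, a minimal standalone driver in the same spirit as test_self_attention above. This is a sketch, not part of the commit: it assumes the nf_self_attention_layer module from this commit is built and linked, the program name is made up, and the shapes and values are illustrative.

program self_attention_usage_sketch
  ! Minimal sketch of the new layer's API, mirroring test_self_attention above.
  ! Assumes nf_self_attention_layer from this commit is available on the build path.
  use nf_self_attention_layer, only: self_attention_layer
  implicit none

  type(self_attention_layer) :: attention
  real :: input(2, 3)
  real :: gradient(2, 3)

  ! Illustrative values; any (sequence_length, model_dimension) array works.
  input = reshape([-1., 0., 17., .4, 5., .6], [2, 3])
  gradient = reshape([1., 2., .17, 4., .5, 6.], [2, 3])

  ! Construct and initialize; model_dimension must be divisible by n_heads.
  attention = self_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
  call attention % init([0])

  ! Forward pass uses the same input for query, key, and value.
  call attention % forward(input)
  print *, 'output shape:   ', shape(attention % output)

  ! Backward pass accumulates the input gradient in attention % gradient.
  call attention % backward(input, gradient)
  print *, 'gradient shape: ', shape(attention % gradient)
end program self_attention_usage_sketch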
