Commit be36d93

multihead_attention: add cross attention
1 parent eef37e3

2 files changed (+124, -0 lines)

Lines changed: 76 additions & 0 deletions
@@ -0,0 +1,76 @@
module nf_cross_attention_layer
  use iso_fortran_env, only: stderr => error_unit
  use nf_activation, only: softmax
  use nf_linear2d_layer, only: linear2d_layer
  use nf_multihead_attention_layer, only: multihead_attention_layer

  implicit none

  type, extends(multihead_attention_layer) :: cross_attention_layer
    real, allocatable :: gradient(:, :, :)
  contains
    procedure :: forward
    procedure :: backward
    procedure :: init
  end type cross_attention_layer

  interface cross_attention_layer
    module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
      !! This function returns the `cross_attention_layer` instance.
      integer, intent(in) :: sequence_length, model_dimension, n_heads
      type(cross_attention_layer) :: res
    end function cross_attention_layer_cons
  end interface cross_attention_layer

contains
  module function cross_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
    !! This function returns the `cross_attention_layer` instance.
    integer, intent(in) :: sequence_length, model_dimension, n_heads
    type(cross_attention_layer) :: res
    res % sequence_length = sequence_length
    res % model_dimension = model_dimension
    res % n_heads = n_heads

    if (mod(model_dimension, n_heads) /= 0) then
      write(stderr, '(a)') 'Model dimension must be divisible by the number of heads'
      error stop
    end if
    res % head_size = model_dimension / n_heads

    res % query_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % key_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % value_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    res % output_layer = linear2d_layer(sequence_length, model_dimension, model_dimension)
    call res % query_layer % init([0])
    call res % key_layer % init([0])
    call res % value_layer % init([0])
    call res % output_layer % init([0])

    res % softmax_func = softmax()
  end function cross_attention_layer_cons

  module subroutine backward(self, input, gradient)
    !! Cross attention back propagation.
    !! The query gradient is written to `self % gradient(1, :, :)`; the key and
    !! value projections share the second input slot, so their gradients are
    !! summed into `self % gradient(2, :, :)`.
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)
    real, intent(in) :: gradient(:, :)

    call self % common_backward(input(1, :, :), gradient)
    self % gradient(1, :, :) = self % query_layer % gradient
    self % gradient(2, :, :) = self % key_layer % gradient + self % value_layer % gradient
  end subroutine backward

  module subroutine forward(self, input)
    !! Cross attention forward propagation.
    !! `input(1, :, :)` provides the query sequence and `input(2, :, :)` the
    !! key-value sequence.
    class(cross_attention_layer), intent(in out) :: self
    real, intent(in) :: input(:, :, :)

    call self % common_forward(input(1, :, :), input(2, :, :), input(2, :, :))
  end subroutine forward

  module subroutine init(self, input_shape)
    class(cross_attention_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

    call self % init_base(input_shape)
    allocate(self % gradient(2, self % sequence_length, self % model_dimension))
  end subroutine init
end module nf_cross_attention_layer
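
The new layer packs both sequences into one rank-3 array: input(1, :, :) carries the query sequence and input(2, :, :) the key-value sequence, and backward returns the input gradients in the same layout through the allocatable gradient component. Below is a minimal usage sketch, not part of the commit, assuming the module is compiled together with its nf_* dependencies; the program and variable names (cross_attention_sketch, attn, decoder_seq, encoder_seq, d_output) are illustrative only.

  ! Minimal sketch of driving cross_attention_layer directly.
  program cross_attention_sketch
    use nf_cross_attention_layer, only: cross_attention_layer
    implicit none

    type(cross_attention_layer) :: attn
    real :: decoder_seq(3, 4), encoder_seq(3, 4)  ! (sequence_length, model_dimension)
    real :: packed_input(2, 3, 4)                 ! slot 1: query, slot 2: key-value
    real :: d_output(3, 4)                        ! upstream gradient w.r.t. the layer output

    ! model_dimension must be divisible by n_heads (checked by the constructor).
    attn = cross_attention_layer(sequence_length=3, model_dimension=4, n_heads=2)
    call attn % init([0])

    call random_number(decoder_seq)
    call random_number(encoder_seq)
    d_output = 1.

    ! Queries come from the first slot, keys and values from the second.
    packed_input(1, :, :) = decoder_seq
    packed_input(2, :, :) = encoder_seq

    call attn % forward(packed_input)
    print *, 'output shape:', shape(attn % output)

    call attn % backward(packed_input, d_output)
    print *, 'query gradient sample:    ', attn % gradient(1, 1, 1)
    print *, 'key-value gradient sample:', attn % gradient(2, 1, 1)
  end program cross_attention_sketch

Because the key and value projections both read from the second slot, their gradients are accumulated into a single array, which is why gradient(2, :, :) is assigned the sum of key_layer % gradient and value_layer % gradient in the backward routine above.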

test/test_multihead_attention_layer.f90

Lines changed: 48 additions & 0 deletions
@@ -2,6 +2,7 @@ program test_multihead_attention_layer
   use iso_fortran_env, only: stderr => error_unit
   use nf_multihead_attention_layer, only: multihead_attention_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_cross_attention_layer, only: cross_attention_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_optimizers, only: sgd
   implicit none
@@ -26,6 +27,7 @@ program test_multihead_attention_layer
   call test_multihead_attention_update_gradients(attention, ok)
   call test_multihead_attention_forward_reallife_shape(ok)
   call test_self_attention(ok)
+  call test_cross_attention(ok)

 contains
   subroutine test_multihead_attention_split_heads(attention, input, ok, output)
@@ -318,4 +320,50 @@ subroutine test_self_attention(ok)
       write(stderr, '(a)') 'backward returned incorrect values.. failed'
     end if
   end subroutine test_self_attention
+
+  subroutine test_cross_attention(ok)
+    logical, intent(in out) :: ok
+    type(cross_attention_layer) :: attention
+    real :: query(2, 3) = reshape([-1., 0., 17., .4, 5., .6], [2, 3])
+    real :: key_value(2, 3) = reshape([0.1, -.2, 0.3, 4., 15., 0.5], [2, 3])
+    real :: input(2, 2, 3)
+    real :: output(2, 2, 3)
+    real :: output_flat(6)
+    real :: expected_output_flat(6) = [&
+        0.600311756, 0.471662223, 0.600311756, 0.471662223, 0.600311756, 0.471662223&
+    ]
+    real :: gradient(2, 3) = reshape([1., 2., .17, 4., .5, 6.], [2, 3])
+    real :: query_gradient_flat(6)
+    real :: key_value_gradient_flat(6)
+    real :: expected_query_gradient_flat(6) = [&
+        1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245, 1.48406753E-03, 0.184446245&
+    ]
+    real :: expected_key_value_gradient_flat(6) = [&
+        0.303095698, 0.107004307, 0.303095698, 0.107004307, 0.303095698, 0.107004307&
+    ]
+    input(1, :, :) = query
+    input(2, :, :) = key_value
+
+    attention = cross_attention_layer(sequence_length=2, model_dimension=3, n_heads=1)
+    call attention % init([0])
+
+    call attention % forward(input)
+    output_flat = reshape(attention % output, shape(output_flat))
+    if (.not. all(output_flat.eq.expected_output_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'forward returned incorrect values.. failed'
+    end if
+
+    call attention % backward(input, gradient)
+    query_gradient_flat = reshape(attention % gradient(1, :, :), shape(query_gradient_flat))
+    if (.not. all(query_gradient_flat.eq.expected_query_gradient_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect query values.. failed'
+    end if
+    key_value_gradient_flat = reshape(attention % gradient(2, :, :), shape(key_value_gradient_flat))
+    if (.not. all(key_value_gradient_flat.eq.expected_key_value_gradient_flat)) then
+      ok = .false.
+      write(stderr, '(a)') 'backward returned incorrect key-value values.. failed'
+    end if
+  end subroutine test_cross_attention
 end program test_multihead_attention_layer
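
For reference, the quantity the new test exercises is ordinary scaled dot-product attention, with the query taken from the first input slot and the keys and values from the second. The formula below is a sketch in conventional notation; the 1/sqrt(d_head) scaling is assumed to live in the inherited multihead_attention_layer and is not shown in this diff.

  % Cross attention per head: Q is projected from input(1,:,:),
  % K and V from input(2,:,:).
  \[
    \mathrm{CrossAttention}(Q, K, V) = \mathrm{softmax}\!\left(\frac{Q K^{\top}}{\sqrt{d_{\text{head}}}}\right) V
  \]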

0 commit comments
