
Commit bcda13d

multihead_attention: fix incorrect dw bug

1 parent 05842ce commit bcda13d
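
What changed: the output projection's backward pass was being fed the attention layer's original `input`, but the weight gradient of a linear layer must be taken with respect to that layer's own forward input, which here is the concatenated head outputs `combine_heads(self % sdpa)`. The fix caches that value in a new `o_input` array during `forward` and passes it to `output_layer % backward`; a gradient sketch follows the backward hunk below.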

1 file changed: +5 -3 lines changed

src/nf/nf_multihead_attention.f90

Lines changed: 5 additions & 3 deletions
@@ -31,6 +31,7 @@ module nf_multihead_attention_layer
     real, allocatable :: q_input(:, :)
     real, allocatable :: k_input(:, :)
     real, allocatable :: v_input(:, :)
+    real, allocatable :: o_input(:, :)
   contains

     procedure :: backward
@@ -146,7 +147,7 @@ module subroutine backward(self, input, gradient)
     ! calculate output layer delta
     ! FIXME: remove reshapes when linear2d situation is resolved
     call self % output_layer % backward(&
-      reshape(input, [self % sequence_length, self % model_dimension, 1]),&
+      reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),&
       reshape(gradient, [self % sequence_length, self % model_dimension, 1])&
     )
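
For intuition, here is a minimal, self-contained sketch of the gradient rule at stake. It is not the neural-fortran API: the program name, the shapes, and the bias-free projection y = matmul(x, w) are assumptions for illustration. For such a projection the weight gradient is dw = matmul(transpose(x), dy), so passing the attention layer's input where the combined heads belong produces a different, wrong dw:

program dw_sketch
  ! Standalone illustration: the weight gradient of a linear projection
  ! y = matmul(x, w) is dw = matmul(transpose(x), dy). Computing it from
  ! the attention layer's input instead of the projection's own forward
  ! input (the combined heads, o_input) gives a different result.
  implicit none
  integer, parameter :: seq_len = 2, model_dim = 3
  real :: attn_input(seq_len, model_dim)  ! what backward() was wrongly given
  real :: o_input(seq_len, model_dim)     ! combine_heads(sdpa): correct operand
  real :: dy(seq_len, model_dim)          ! upstream gradient
  real :: dw_wrong(model_dim, model_dim)
  real :: dw_right(model_dim, model_dim)

  call random_number(attn_input)
  call random_number(o_input)
  call random_number(dy)

  dw_wrong = matmul(transpose(attn_input), dy)
  dw_right = matmul(transpose(o_input), dy)

  ! Nonzero for generic inputs: the two operands give different gradients.
  print *, 'max |dw_wrong - dw_right| =', maxval(abs(dw_wrong - dw_right))
end program dw_sketch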

@@ -265,8 +266,8 @@ module subroutine forward(self, query, key, value)
     call self % scaled_dot_product_attention(v)

     ! FIXME: remove reshapes when linear2d situation is resolved
-    call self % output_layer % forward(&
-        reshape(self % combine_heads(self % sdpa), [self % sequence_length, self % model_dimension, 1]))
+    self % o_input = self % combine_heads(self % sdpa)
+    call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]))
     self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension])

     ! free temp vars from memory
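
Design note: caching `o_input` during `forward` mirrors the existing `q_input`/`k_input`/`v_input` buffers allocated in `init`, giving `backward` the output projection's true forward input even after the temporaries are freed at the end of `forward`.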
@@ -449,5 +450,6 @@ module subroutine init(self, input_shape)
     allocate(self % q_input(self % sequence_length, self % model_dimension))
     allocate(self % k_input(self % sequence_length, self % model_dimension))
     allocate(self % v_input(self % sequence_length, self % model_dimension))
+    allocate(self % o_input(self % sequence_length, self % model_dimension))
   end subroutine init
 end module nf_multihead_attention_layer
