@@ -31,6 +31,7 @@ module nf_multihead_attention_layer
     real, allocatable :: q_input(:, :)
     real, allocatable :: k_input(:, :)
     real, allocatable :: v_input(:, :)
+    real, allocatable :: o_input(:, :)
   contains

     procedure :: backward
@@ -146,7 +147,7 @@ module subroutine backward(self, input, gradient)
     ! calculate output layer delta
     ! FIXME: remove reshapes when linear2d situation is resolved
     call self % output_layer % backward(&
-        reshape(input, [self % sequence_length, self % model_dimension, 1]),&
+        reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]),&
         reshape(gradient, [self % sequence_length, self % model_dimension, 1])&
     )

@@ -265,8 +266,8 @@ module subroutine forward(self, query, key, value)
     call self % scaled_dot_product_attention(v)

     ! FIXME: remove reshapes when linear2d situation is resolved
-    call self % output_layer % forward(&
-        reshape(self % combine_heads(self % sdpa), [self % sequence_length, self % model_dimension, 1]))
+    self % o_input = self % combine_heads(self % sdpa)
+    call self % output_layer % forward(reshape(self % o_input, [self % sequence_length, self % model_dimension, 1]))
     self % output = reshape(self % output_layer % output, [self % sequence_length, self % model_dimension])

     ! free temp vars from memory
@@ -449,5 +450,6 @@ module subroutine init(self, input_shape)
     allocate(self % q_input(self % sequence_length, self % model_dimension))
     allocate(self % k_input(self % sequence_length, self % model_dimension))
     allocate(self % v_input(self % sequence_length, self % model_dimension))
+    allocate(self % o_input(self % sequence_length, self % model_dimension))
   end subroutine init
 end module nf_multihead_attention_layer
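
The pattern here is caching the tensor a sublayer actually consumed during `forward` (the combined heads, stored in `o_input`) so that the sublayer's `backward` can be driven with that same tensor rather than the attention block's external `input`. Below is a minimal, self-contained sketch of that idea in plain Fortran; it is illustrative only and does not use the neural-fortran API, and all names (`o_input`, `dweights`, etc.) are stand-ins.

```fortran
! Minimal sketch (not part of the patch): cache the tensor the inner layer
! really saw in forward, then use that cached copy when computing gradients.
program cache_forward_input_demo
  implicit none
  real :: x(4, 3)          ! external input to the enclosing block
  real :: combined(4, 3)   ! what the inner (output) layer actually receives
  real :: o_input(4, 3)    ! cached copy, analogous to self % o_input
  real :: upstream(4, 2)   ! gradient arriving from the next layer
  real :: dweights(3, 2)   ! weight gradient of the inner layer

  call random_number(x)
  call random_number(upstream)

  ! forward: the inner layer consumes a transformed tensor, not x itself
  combined = 2.0 * x       ! stand-in for combine_heads(sdpa)
  o_input = combined       ! cache it for the backward pass

  ! backward: the weight gradient must use the cached o_input; using x here
  ! would compile (same shape) but silently give the wrong gradient
  dweights = matmul(transpose(o_input), upstream)

  print *, 'dW shape:', shape(dweights)
end program cache_forward_input_demo
```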