@@ -12,8 +12,7 @@ module nf_multihead_attention_layer
   type, extends(base_layer) :: multihead_attention_layer
 
     !! Concrete implementation of a multihead attention layer type
-
-    integer :: batch_size, sequence_length, model_dimension, n_heads, head_size
+    integer :: sequence_length, model_dimension, n_heads, head_size
 
     type(linear2d_layer) :: query_layer
     type(linear2d_layer) :: key_layer
@@ -45,14 +44,14 @@ module nf_multihead_attention_layer
     procedure :: get_params
     procedure :: get_gradients
     procedure :: set_params
-    procedure :: init
-
+    procedure :: init_base
+    procedure :: init => init_base ! in case general MHA needs to be used
   end type multihead_attention_layer
 
   interface multihead_attention_layer
-    module function multihead_attention_layer_cons(batch_size, sequence_length, model_dimension, n_heads) result(res)
+    module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
       !! This function returns the `multihead_attention_layer` instance.
-      integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads
+      integer, intent(in) :: sequence_length, model_dimension, n_heads
       type(multihead_attention_layer) :: res
     end function multihead_attention_layer_cons
   end interface multihead_attention_layer
@@ -270,7 +269,7 @@ end function split_heads
 
   module subroutine create_attention_matrix(self, query, key)
     !! Create attention matrix for query and key
-    !! Output dimensions: n_heads, sequence_length, sequence_length, batch_size
+    !! Output dimensions: sequence_length, sequence_length, n_heads
     class(multihead_attention_layer) :: self
     real :: query(:, :, :)
     real :: key(:, :, :)
@@ -311,7 +310,7 @@ end subroutine normalize_attention_matrix
 
   module subroutine scaled_dot_product_attention(self, value)
     !! Create scaled dot product attention
-    !! Output dims: n_heads, sequence_length, head_size, batch_size
+    !! Output dims: sequence_length, head_size, n_heads
     class(multihead_attention_layer) :: self
     real :: value(:, :, :)
     integer :: head
@@ -417,7 +416,7 @@ module subroutine set_params(self, params)
     self % output_layer % biases = params(i: j)
   end subroutine set_params
 
-  module subroutine init(self, input_shape)
+  module subroutine init_base(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)
 
@@ -431,5 +430,5 @@ module subroutine init(self, input_shape)
     allocate(self % k_input(self % sequence_length, self % model_dimension))
     allocate(self % v_input(self % sequence_length, self % model_dimension))
     allocate(self % o_input(self % sequence_length, self % model_dimension))
-  end subroutine init
+  end subroutine init_base
 end module nf_multihead_attention_layer
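
Reviewer note, not part of the diff: a minimal usage sketch of the refactored layer, assuming only the names visible above (`multihead_attention_layer`, `init => init_base`, and the constructor without `batch_size`). The numeric values and the contents of `input_shape` are illustrative assumptions, not taken from this PR.

  program mha_usage_sketch
    use nf_multihead_attention_layer, only: multihead_attention_layer
    implicit none
    type(multihead_attention_layer) :: attn

    ! Constructor no longer takes batch_size (see the interface hunk above).
    attn = multihead_attention_layer( &
      sequence_length=64, model_dimension=512, n_heads=8)

    ! init forwards to init_base; the exact shape expected in input_shape
    ! is an assumption here, shown only to illustrate the call.
    call attn % init([64, 512])
  end program mha_usage_sketch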