
Commit 90d3d6c

multihead_attention: tidy mha up

1 parent b6a0915

1 file changed: +9 additions, -10 deletions

src/nf/nf_multihead_attention.f90

Lines changed: 9 additions & 10 deletions
@@ -12,8 +12,7 @@ module nf_multihead_attention_layer
   type, extends(base_layer) :: multihead_attention_layer

     !! Concrete implementation of a multihead attention layer type
-
-    integer :: batch_size, sequence_length, model_dimension, n_heads, head_size
+    integer :: sequence_length, model_dimension, n_heads, head_size

     type(linear2d_layer) :: query_layer
     type(linear2d_layer) :: key_layer
@@ -45,14 +44,14 @@ module nf_multihead_attention_layer
     procedure :: get_params
     procedure :: get_gradients
     procedure :: set_params
-    procedure :: init
-
+    procedure :: init_base
+    procedure :: init => init_base ! in case general MHA needs to be used
   end type multihead_attention_layer

   interface multihead_attention_layer
-    module function multihead_attention_layer_cons(batch_size, sequence_length, model_dimension, n_heads) result(res)
+    module function multihead_attention_layer_cons(sequence_length, model_dimension, n_heads) result(res)
       !! This function returns the `multihead_attention_layer` instance.
-      integer, intent(in) :: batch_size, sequence_length, model_dimension, n_heads
+      integer, intent(in) :: sequence_length, model_dimension, n_heads
       type(multihead_attention_layer) :: res
     end function multihead_attention_layer_cons
   end interface multihead_attention_layer
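
A minimal usage sketch of the new constructor signature, based only on the interface shown above; the numeric values are arbitrary examples, not taken from the commit. The layer geometry is now given per sample, with no batch_size argument.

program demo_mha_constructor
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none
  type(multihead_attention_layer) :: mha
  ! Construct with per-sample geometry only; batch_size is no longer an argument.
  mha = multihead_attention_layer(sequence_length=16, model_dimension=64, n_heads=8)
end program demo_mha_constructor
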
@@ -270,7 +269,7 @@ end function split_heads

   module subroutine create_attention_matrix(self, query, key)
     !! Create attention matrix for query and key
-    !! Output dimensions: n_heads, sequence_length, sequence_length, batch_size
+    !! Output dimensions: sequence_length, sequence_length, n_heads
     class(multihead_attention_layer) :: self
     real :: query(:, :, :)
     real :: key(:, :, :)
@@ -311,7 +310,7 @@ end subroutine normalize_attention_matrix

   module subroutine scaled_dot_product_attention(self, value)
     !! Create scaled dot product attention
-    !! Output dims: n_heads, sequence_length, head_size, batch_size
+    !! Output dims: sequence_length, head_size, n_heads
     class(multihead_attention_layer) :: self
     real :: value(:, :, :)
     integer :: head
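
Both doc-comment updates drop the batch dimension and move n_heads to the last axis. The sketch below is purely illustrative of what that ordering means for per-head slices in Fortran's column-major storage; the array name and sizes are hypothetical, not components named in this commit.

program demo_attention_shape
  implicit none
  real, allocatable :: attention(:, :, :)
  integer, parameter :: sequence_length = 16, n_heads = 8  ! example values
  integer :: head
  ! New axis order: (sequence_length, sequence_length, n_heads)
  allocate(attention(sequence_length, sequence_length, n_heads))
  do head = 1, n_heads
    ! attention(:, :, head) is a contiguous per-head block in column-major storage
    attention(:, :, head) = 0.0
  end do
  print *, shape(attention)
end program demo_attention_shape
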
@@ -417,7 +416,7 @@ module subroutine set_params(self, params)
     self % output_layer % biases = params(i: j)
   end subroutine set_params

-  module subroutine init(self, input_shape)
+  module subroutine init_base(self, input_shape)
     class(multihead_attention_layer), intent(in out) :: self
     integer, intent(in) :: input_shape(:)

@@ -431,5 +430,5 @@ module subroutine init(self, input_shape)
     allocate(self % k_input(self % sequence_length, self % model_dimension))
     allocate(self % v_input(self % sequence_length, self % model_dimension))
     allocate(self % o_input(self % sequence_length, self % model_dimension))
-  end subroutine init
+  end subroutine init_base
 end module nf_multihead_attention_layer
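
The `init => init_base` binding and its inline comment ("in case general MHA needs to be used") suggest that an extending type can override `init` while reusing the shared allocations. A hypothetical sketch follows; the extending type below is not part of this commit.

module demo_attention_ext
  use nf_multihead_attention_layer, only: multihead_attention_layer
  implicit none

  ! Hypothetical extending type, shown only to illustrate the new binding.
  type, extends(multihead_attention_layer) :: my_attention_layer
  contains
    procedure :: init => my_init
  end type my_attention_layer

contains

  subroutine my_init(self, input_shape)
    class(my_attention_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)
    ! Reuse the general MHA allocations, then do any subtype-specific setup here.
    call self % init_base(input_shape)
  end subroutine my_init

end module demo_attention_ext
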
