Commit a5030cb

multihead_attention: add more comments
1 parent be36d93 commit a5030cb

2 files changed: +19 -0 lines changed


src/nf/nf_cross_attention_layer.f90

Lines changed: 9 additions & 0 deletions
@@ -7,6 +7,11 @@ module nf_cross_attention_layer
   implicit none

   type, extends(multihead_attention_layer) :: cross_attention_layer
+    !! Cross Attention Layer
+    !! Source:
+    !! Bahdanau, D. (2014)
+    !! Neural machine translation by jointly learning to align and translate.
+    !! https://arxiv.org/pdf/1409.0473
     real, allocatable :: gradient(:, :, :)
   contains
     procedure :: forward
@@ -50,6 +55,7 @@ module function cross_attention_layer_cons(sequence_length, model_dimension, n_h
   end function cross_attention_layer_cons

   module subroutine backward(self, input, gradient)
+    !! Cross Attention Back propagation
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
     real, intent(in) :: gradient(:, :)
@@ -60,6 +66,9 @@ module subroutine backward(self, input, gradient)
   end subroutine backward

   module subroutine forward(self, input)
+    !! Cross Attention Forward propagation
+    !! Input Shape (kind, sequence_length, model_dimension)
+    !! where kind is 1 for Query and 2 for Key-Value
     class(cross_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :, :)
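
The new comments document the expected input layout for cross attention: a rank-3 array of shape (kind, sequence_length, model_dimension), where the slice kind = 1 carries the Query sequence and kind = 2 carries the Key-Value sequence. A minimal, hypothetical usage sketch follows; it is not part of the commit. It assumes a public constructor interface cross_attention_layer(sequence_length, model_dimension, n_heads) wrapping cross_attention_layer_cons (the third dummy name is truncated in the hunk header, so n_heads is a guess), and that no separate init step is required before forward.

! Hypothetical sketch, not from the diff: exercises the documented
! (kind, sequence_length, model_dimension) input layout of forward().
program cross_attention_sketch
  use nf_cross_attention_layer, only: cross_attention_layer
  implicit none

  type(cross_attention_layer) :: attn
  real :: input(2, 4, 8)   ! kind=1 Query, kind=2 Key-Value; seq_len=4, d_model=8

  ! Assumed constructor interface: (sequence_length, model_dimension, n_heads);
  ! the third argument name is a guess. Depending on the implementation,
  ! an init/allocation step may also be needed before forward.
  attn = cross_attention_layer(4, 8, 2)

  call random_number(input(1, :, :))   ! Query sequence
  call random_number(input(2, :, :))   ! Key-Value sequence

  ! forward's rank-3 signature, input(:, :, :), is shown in the diff above
  call attn % forward(input)
end program cross_attention_sketch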

src/nf/nf_self_attention_layer.f90

Lines changed: 10 additions & 0 deletions
@@ -7,6 +7,11 @@ module nf_self_attention_layer
   implicit none

   type, extends(multihead_attention_layer) :: self_attention_layer
+    !! Self Attention Layer
+    !! Source:
+    !! Parikh, A. P., Taeckstroem, O., Das, D., & Uszkoreit, J. (2016)
+    !! A decomposable attention model for natural language inference.
+    !! https://arxiv.org/pdf/1606.01933
     real, allocatable :: gradient(:, :)
   contains
     procedure :: forward
@@ -50,6 +55,8 @@ module function self_attention_layer_cons(sequence_length, model_dimension, n_he
   end function self_attention_layer_cons

   module subroutine backward(self, input, gradient)
+    !! Self Attention back propagation
+    !! Returns sum of Query, Key and Value gradients
     class(self_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
     real, intent(in) :: gradient(:, :)
@@ -62,6 +69,9 @@ module subroutine backward(self, input, gradient)
   end subroutine backward

   module subroutine forward(self, input)
+    !! Self Attention forward propagation
+    !! Passes input three times into MultiHead Attention
+    !! Input Shape: (sequence_length, model_dimension)
     class(self_attention_layer), intent(in out) :: self
     real, intent(in) :: input(:, :)
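
Here the comments state that self attention feeds the same input into the multihead attention three times (Query = Key = Value = input), and that backward accumulates the sum of the Query, Key and Value gradients (stored in the rank-2 gradient component shown in the type). A minimal sketch under the same assumptions as above (public constructor interface self_attention_layer(sequence_length, model_dimension, n_heads) with a guessed third argument name, no extra init step); it is illustrative only, not part of the commit.

! Hypothetical sketch, not from the diff: the same rank-2 input of shape
! (sequence_length, model_dimension) serves as Query, Key and Value.
program self_attention_sketch
  use nf_self_attention_layer, only: self_attention_layer
  implicit none

  type(self_attention_layer) :: attn
  real :: input(4, 8)      ! (sequence_length=4, model_dimension=8)
  real :: upstream(4, 8)   ! gradient flowing back from the next layer

  ! Assumed constructor interface; the third argument name is a guess
  attn = self_attention_layer(4, 8, 2)

  call random_number(input)
  call attn % forward(input)            ! input feeds Q, K and V internally

  ! backward's signature (input(:, :), gradient(:, :)) is shown in the diff;
  ! per the new comment, the summed Q/K/V gradient ends up in attn % gradient
  call random_number(upstream)
  call attn % backward(input, upstream)
end program self_attention_sketch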
