Skip to content

Commit b7a6d06

Browse files
committed
multihead_attention: simple MHA example
1 parent cb717f5 commit b7a6d06

File tree

1 file changed

+37
-0
lines changed

1 file changed

+37
-0
lines changed

example/mha_simple.f90

Lines changed: 37 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,37 @@
1+
program simple
2+
use nf, only: dense, input, network, sgd, self_attention, flatten
3+
implicit none
4+
type(network) :: net
5+
real, allocatable :: x(:, :), y(:)
6+
integer, parameter :: num_iterations = 500
7+
integer :: n
8+
9+
print '("Simple")'
10+
print '(60("="))'
11+
12+
net = network([ &
13+
input(3, 8), &
14+
self_attention(4), &
15+
flatten(), &
16+
dense(2) &
17+
])
18+
19+
call net % print_info()
20+
21+
allocate(x(3, 8))
22+
call random_number(x)
23+
24+
y = [0.123456, 0.246802]
25+
26+
do n = 0, num_iterations
27+
28+
call net % forward(x)
29+
call net % backward(y)
30+
call net % update(optimizer=sgd(learning_rate=1.))
31+
32+
if (mod(n, 50) == 0) &
33+
print '(i4,2(3x,f8.6))', n, net % predict(x)
34+
35+
end do
36+
37+
end program simple

0 commit comments

Comments
 (0)