Skip to content

Commit 38c998f

Browse files
committed
Minimal concatenated input example
1 parent 89a255b commit 38c998f

File tree

1 file changed

+63
-0
lines changed

1 file changed

+63
-0
lines changed

example/concatenate.f90

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,63 @@
1+
program concatenate
2+
use nf, only: dense, input, network, sgd
3+
implicit none
4+
5+
type(network) :: net1, net2, net3
6+
real, allocatable :: x1(:), y1(:)
7+
real, allocatable :: x2(:), y2(:)
8+
real, allocatable :: x3(:), y3(:)
9+
integer, parameter :: num_iterations = 500
10+
integer :: n
11+
12+
! Network 1
13+
net1 = network([ &
14+
input(3), &
15+
dense(2) &
16+
])
17+
18+
x1 = [0.2, 0.4, 0.6]
19+
y1 = [0.123456, 0.246802]
20+
21+
do n = 1, num_iterations
22+
call net1 % forward(x1)
23+
call net1 % backward(y1)
24+
call net1 % update(optimizer=sgd(learning_rate=1.))
25+
end do
26+
27+
print *, "net1 output: ", net1 % predict(x1)
28+
29+
! Network 2
30+
net2 = network([ &
31+
input(3), &
32+
dense(3) &
33+
])
34+
35+
x2 = [0.7, 0.5, 0.3]
36+
y2 = [0.369258, 0.482604, 0.505050]
37+
38+
do n = 1, num_iterations
39+
call net2 % forward(x2)
40+
call net2 % backward(y2)
41+
call net2 % update(optimizer=sgd(learning_rate=1.))
42+
end do
43+
44+
print *, "net2 output: ", net2 % predict(x2)
45+
46+
! Network 3
47+
net3 = network([ &
48+
input(size(net1 % predict(x1)) + size(net2 % predict(x2))), &
49+
dense(5) &
50+
])
51+
52+
x3 = [net1 % predict(x1), net2 % predict(x2)]
53+
y3 = [0.111111, 0.222222, 0.333333, 0.444444, 0.555555]
54+
55+
do n = 1, num_iterations
56+
call net3 % forward(x3)
57+
call net3 % backward(y3)
58+
call net3 % update(optimizer=sgd(learning_rate=1.))
59+
end do
60+
61+
print *, "net3 output: ", net3 % predict(x3)
62+
63+
end program concatenate

0 commit comments

Comments
 (0)