Commit c406b42

Toward merge networks (#231)
* Minimal concatenated input example
* Update example of merging 2 networks to feed into a 3rd network
* Allow passing gradient to network % backward() to bypass loss function
* Add network % get_output() subroutine that returns a pointer to the outputs
* Allow getting output pointer for all layers
1 parent 00acae2 commit c406b42
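
Taken together, these changes let two networks feed a third: get_output exposes a pointer to a network's last-layer output, and the new optional gradient argument to backward lets a downstream network's input-side gradient drive an upstream network's backward pass, bypassing the loss function. A minimal sketch, condensed from the example added in this commit (layer sizes, inputs, targets, and the sgd learning rate are illustrative, not part of the commit):

program merge_sketch
  use nf, only: dense, input, network, sgd
  use nf_dense_layer, only: dense_layer
  implicit none

  type(network) :: branch, head
  real, pointer :: y_branch(:)
  real :: y_true(4)

  branch = network([input(3), dense(2)])   ! upstream branch
  head = network([input(2), dense(4)])     ! downstream head; its input size matches the branch output
  y_true = [0.1, 0.2, 0.3, 0.4]            ! illustrative targets

  call branch % forward([0.1, 0.3, 0.5])
  call branch % get_output(y_branch)       ! pointer to the branch output, not a copy
  call head % forward(y_branch)

  ! Compute gradients on the head from the loss, then push the head's
  ! first hidden layer gradient into the branch, bypassing the loss.
  call head % backward(y_true)
  select type (next_layer => head % layers(2) % p)
    type is (dense_layer)
      ! y_true is unused on this code path because gradient is present
      call branch % backward(y_true, gradient=next_layer % gradient)
  end select

  call branch % update(optimizer=sgd(learning_rate=1.))
  call head % update(optimizer=sgd(learning_rate=1.))

end program merge_sketch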

File tree: 3 files changed, +196 -57 lines changed

example/merge_networks.f90

Lines changed: 73 additions & 0 deletions
@@ -0,0 +1,73 @@
program merge_networks
  use nf, only: dense, input, network, sgd
  use nf_dense_layer, only: dense_layer
  implicit none

  type(network) :: net1, net2, net3
  real, allocatable :: x1(:), x2(:)
  real, pointer :: y1(:), y2(:)
  real, allocatable :: y(:)
  integer, parameter :: num_iterations = 500
  integer :: n, nn
  integer :: net1_output_size, net2_output_size

  x1 = [0.1, 0.3, 0.5]
  x2 = [0.2, 0.4]
  y = [0.123456, 0.246802, 0.369258, 0.482604, 0.505050, 0.628406, 0.741852]

  net1 = network([ &
    input(3), &
    dense(2), &
    dense(3), &
    dense(2) &
  ])

  net2 = network([ &
    input(2), &
    dense(5), &
    dense(3) &
  ])

  net1_output_size = product(net1 % layers(size(net1 % layers)) % layer_shape)
  net2_output_size = product(net2 % layers(size(net2 % layers)) % layer_shape)

  ! Network 3
  net3 = network([ &
    input(net1_output_size + net2_output_size), &
    dense(7) &
  ])

  do n = 1, num_iterations

    ! Forward propagate two network branches
    call net1 % forward(x1)
    call net2 % forward(x2)

    ! Get outputs of net1 and net2, concatenate, and pass to net3
    call net1 % get_output(y1)
    call net2 % get_output(y2)
    call net3 % forward([y1, y2])

    ! First compute the gradients on net3, then pass the gradients from the first
    ! hidden layer on net3 to net1 and net2, and compute their gradients.
    call net3 % backward(y)

    select type (next_layer => net3 % layers(2) % p)
      type is (dense_layer)
        call net1 % backward(y, gradient=next_layer % gradient(1:net1_output_size))
        call net2 % backward(y, gradient=next_layer % gradient(net1_output_size+1:size(next_layer % gradient)))
    end select

    ! Gradients are now computed on all networks and we can update the weights
    call net1 % update(optimizer=sgd(learning_rate=1.))
    call net2 % update(optimizer=sgd(learning_rate=1.))
    call net3 % update(optimizer=sgd(learning_rate=1.))

    if (mod(n, 50) == 0) then
      print *, "Iteration ", n, ", output RMSE = ", &
        sqrt(sum((net3 % predict([net1 % predict(x1), net2 % predict(x2)]) - y)**2) / size(y))
    end if

  end do

end program merge_networks
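
One property worth noting in the example above: get_output associates a pointer with the last layer's output array rather than returning a copy, so the values seen through y1 and y2 always reflect the most recent forward pass. A hypothetical fragment reusing the names above (not part of the commit) to illustrate the aliasing:

    call net1 % forward(x1)
    call net1 % get_output(y1)
    print *, y1   ! net1's output for x1
    call net1 % forward(2 * x1)
    print *, y1   ! same pointer, now net1's output for 2*x1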

src/nf/nf_network.f90

Lines changed: 21 additions & 5 deletions
@@ -33,6 +33,7 @@ module nf_network
     procedure, private :: forward_1d_int
     procedure, private :: forward_2d
     procedure, private :: forward_3d
+    procedure, private :: get_output_1d
     procedure, private :: predict_1d
     procedure, private :: predict_1d_int
     procedure, private :: predict_2d
@@ -42,6 +43,7 @@ module nf_network
 
     generic :: evaluate => evaluate_batch_1d
     generic :: forward => forward_1d, forward_1d_int, forward_2d, forward_3d
+    generic :: get_output => get_output_1d
     generic :: predict => predict_1d, predict_1d_int, predict_2d, predict_3d
     generic :: predict_batch => predict_batch_1d, predict_batch_3d
 
@@ -131,7 +133,7 @@ end subroutine forward_3d
 
   end interface forward
 
-  interface output
+  interface predict
 
     module function predict_1d(self, input) result(res)
       !! Return the output of the network given the input 1-d array.
@@ -169,9 +171,10 @@ module function predict_3d(self, input) result(res)
       real, allocatable :: res(:)
         !! Output of the network
     end function predict_3d
-  end interface output
 
-  interface output_batch
+  end interface predict
+
+  interface predict_batch
     module function predict_batch_1d(self, input) result(res)
       !! Return the output of the network given an input batch of 3-d data.
       class(network), intent(in out) :: self
@@ -191,11 +194,18 @@ module function predict_batch_3d(self, input) result(res)
       real, allocatable :: res(:,:)
         !! Output of the network; the last dimension is the batch
     end function predict_batch_3d
-  end interface output_batch
+  end interface predict_batch
+
+  interface get_output
+    module subroutine get_output_1d(self, output)
+      class(network), intent(in), target :: self
+      real, pointer, intent(out) :: output(:)
+    end subroutine get_output_1d
+  end interface get_output
 
   interface
 
-    module subroutine backward(self, output, loss)
+    module subroutine backward(self, output, loss, gradient)
       !! Apply one backward pass through the network.
       !! This changes the state of layers on the network.
       !! Typically used only internally from the `train` method,
@@ -206,6 +216,12 @@ module subroutine backward(self, output, loss)
         !! Output data
       class(loss_type), intent(in), optional :: loss
         !! Loss instance to use. If not provided, the default is quadratic().
+      real, intent(in), optional :: gradient(:)
+        !! Gradient to use for the output layer.
+        !! If not provided, the gradient in the last layer is computed using
+        !! the loss function.
+        !! Passing the gradient is useful for merging/concatenating multiple
+        !! networks.
     end subroutine backward
 
     module integer function get_num_params(self)
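
To illustrate the new gradient argument in isolation: the following sketch supplies a hand-computed loss derivative instead of relying on the built-in loss. The y_pred - y form is an assumed quadratic-loss derivative used for illustration only; it is not verified here to match the library's quadratic() exactly.

program backward_gradient_sketch
  use nf, only: dense, input, network
  implicit none

  type(network) :: net
  real, pointer :: y_pred(:)
  real, allocatable :: grad(:)

  net = network([input(3), dense(2)])
  call net % forward([0.1, 0.2, 0.3])

  ! Path 1: the gradient of the default loss is computed internally.
  call net % backward([0., 1.])

  ! Path 2: supply the output-layer gradient directly, bypassing the loss.
  call net % get_output(y_pred)
  grad = y_pred - [0., 1.]   ! assumed quadratic-loss derivative
  call net % backward([0., 1.], gradient=grad)

end program backward_gradient_sketch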

src/nf/nf_network_submodule.f90

Lines changed: 102 additions & 52 deletions
@@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res)
   end function network_from_layers
 
 
-  module subroutine backward(self, output, loss)
+  module subroutine backward(self, output, loss, gradient)
     class(network), intent(in out) :: self
     real, intent(in) :: output(:)
     class(loss_type), intent(in), optional :: loss
+    real, intent(in), optional :: gradient(:)
     integer :: n, num_layers
 
     ! Passing the loss instance is optional. If not provided, and if the
@@ -140,58 +141,71 @@ module subroutine backward(self, output, loss, gradient)
 
     ! Iterate backward over layers, from the output layer
     ! to the first non-input layer
-    do n = num_layers, 2, -1
-
-      if (n == num_layers) then
-        ! Output layer; apply the loss function
-        select type(this_layer => self % layers(n) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-          type is(flatten_layer)
-            call self % layers(n) % backward( &
-              self % layers(n - 1), &
-              self % loss % derivative(output, this_layer % output) &
-            )
-        end select
-      else
-        ! Hidden layer; take the gradient from the next layer
-        select type(next_layer => self % layers(n + 1) % p)
-          type is(dense_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(dropout_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(flatten_layer)
-            if (size(self % layers(n) % layer_shape) == 2) then
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
-            else
-              call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
-            end if
-          type is(maxpool2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape3d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(linear2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(self_attention_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(maxpool1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(reshape2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(conv1d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(locally_connected2d_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-          type is(layernorm_layer)
-            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
-        end select
-      end if
 
+    ! Output layer first
+    n = num_layers
+    if (present(gradient)) then
+
+      ! If the gradient is passed, use it directly for the output layer
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+        type is(flatten_layer)
+          call self % layers(n) % backward(self % layers(n - 1), gradient)
+      end select
+
+    else
+
+      ! Apply the loss function
+      select type(this_layer => self % layers(n) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+        type is(flatten_layer)
+          call self % layers(n) % backward( &
+            self % layers(n - 1), &
+            self % loss % derivative(output, this_layer % output) &
+          )
+      end select
+
+    end if
+
+    ! Hidden layers; take the gradient from the next layer
+    do n = num_layers - 1, 2, -1
+      select type(next_layer => self % layers(n + 1) % p)
+        type is(dense_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(dropout_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(flatten_layer)
+          if (size(self % layers(n) % layer_shape) == 2) then
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_2d)
+          else
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient_3d)
+          end if
+        type is(maxpool2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape3d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(linear2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(self_attention_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(maxpool1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(reshape2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(conv1d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(locally_connected2d_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+        type is(layernorm_layer)
+          call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+      end select
     end do
 
   end subroutine backward
@@ -497,6 +511,42 @@ module subroutine print_info(self)
   end subroutine print_info
 
 
+  module subroutine get_output_1d(self, output)
+    class(network), intent(in), target :: self
+    real, pointer, intent(out) :: output(:)
+    integer :: last
+
+    last = size(self % layers)
+
+    select type(output_layer => self % layers(last) % p)
+      type is (conv1d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (conv2d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (dense_layer)
+        output => output_layer % output
+      type is (dropout_layer)
+        output => output_layer % output
+      type is (flatten_layer)
+        output => output_layer % output
+      type is (layernorm_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (linear2d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (locally_connected2d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (maxpool1d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      type is (maxpool2d_layer)
+        output(1:size(output_layer % output)) => output_layer % output
+      class default
+        error stop 'network % get_output not implemented for ' // &
+          trim(self % layers(last) % name) // ' layer'
+    end select
+
+  end subroutine get_output_1d
+
+
   module function get_num_params(self)
     class(network), intent(in) :: self
     integer :: get_num_params
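
A side note on the pointer syntax in get_output_1d: the bounds-spec form output(1:size(...)) => ... is Fortran 2008 pointer rank remapping. It lets the rank-1 output pointer view a contiguous rank-2 or rank-3 layer output without copying, which is why the multi-dimensional layers need it while dense_layer and the other rank-1 layers can use a plain =>. A self-contained sketch of the mechanism (not part of the commit):

program rank_remap_sketch
  implicit none
  real, target :: a(2, 3)   ! stand-in for a rank-2 layer output
  real, pointer :: flat(:)

  a = reshape([1., 2., 3., 4., 5., 6.], [2, 3])

  ! Rank remapping (Fortran 2008): a rank-1 view of a simply
  ! contiguous rank-2 target, in column-major storage order.
  flat(1:size(a)) => a

  print *, flat   ! prints 1.0 2.0 3.0 4.0 5.0 6.0
end program rank_remap_sketch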
