@@ -115,10 +115,11 @@ module function network_from_layers(layers) result(res)
115115 end function network_from_layers
116116
117117
! NOTE(review): this span is a rendered unified diff (the leading old/new line
! numbers and +/- markers are part of the captured text, not Fortran syntax).
! Underlying file lines 125-140 are elided between the two hunks, so this view
! of backward() is incomplete; comments below annotate only the visible logic.
!
! Change under review: backward() gains an optional `gradient(:)` argument.
! When present, it is used directly as the output-layer gradient (presumably
! to let a caller chain networks or supply a custom upstream gradient -- TODO
! confirm against callers); when absent, behavior matches the old code path,
! deriving the gradient from self % loss % derivative(...).
118- module subroutine backward (self , output , loss )
118+ module subroutine backward (self , output , loss , gradient )
119119 class(network), intent (in out ) :: self
120120 real , intent (in ) :: output(:)
121121 class(loss_type), intent (in ), optional :: loss
! New optional external gradient for the output layer; guarded with
! present(gradient) below before use.
122+ real , intent (in ), optional :: gradient(:)
122123 integer :: n, num_layers
123124
124125 ! Passing the loss instance is optional. If not provided, and if the
! NOTE(review): hunk boundary -- file lines 125-140 (loss default handling,
! per the comment above) are not visible in this view.
@@ -140,58 +141,71 @@ module subroutine backward(self, output, loss)
140141
141142 ! Iterate backward over layers, from the output layer
142143 ! to the first non-input layer
! Old structure (removed): a single loop from num_layers down to 2 with a
! per-iteration `if (n == num_layers)` branch separating the output layer
! from hidden layers.
143- do n = num_layers, 2 , - 1
144-
145- if (n == num_layers) then
146- ! Output layer; apply the loss function
147- select type (this_layer = > self % layers(n) % p)
148- type is (dense_layer)
149- call self % layers(n) % backward( &
150- self % layers(n - 1 ), &
151- self % loss % derivative(output, this_layer % output) &
152- )
153- type is (flatten_layer)
154- call self % layers(n) % backward( &
155- self % layers(n - 1 ), &
156- self % loss % derivative(output, this_layer % output) &
157- )
158- end select
159- else
160- ! Hidden layer; take the gradient from the next layer
161- select type (next_layer = > self % layers(n + 1 ) % p)
162- type is (dense_layer)
163- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
164- type is (dropout_layer)
165- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
166- type is (conv2d_layer)
167- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
168- type is (flatten_layer)
169- if (size (self % layers(n) % layer_shape) == 2 ) then
170- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient_2d)
171- else
172- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient_3d)
173- end if
174- type is (maxpool2d_layer)
175- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
176- type is (reshape3d_layer)
177- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
178- type is (linear2d_layer)
179- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
180- type is (self_attention_layer)
181- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
182- type is (maxpool1d_layer)
183- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
184- type is (reshape2d_layer)
185- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
186- type is (conv1d_layer)
187- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
188- type is (locally_connected2d_layer)
189- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
190- type is (layernorm_layer)
191- call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
192- end select
193- end if
194144
! New structure (added): the loop is split. The output layer (n = num_layers)
! is handled once, up front, outside the loop; hidden layers then iterate
! num_layers-1 down to 2. This removes the per-iteration branch and creates a
! clean seam for the new `gradient` argument.
145+ ! Output layer first
146+ n = num_layers
147+ if (present (gradient)) then
148+
149+ ! If the gradient is passed, use it directly for the output layer
! Only dense and flatten output layers are handled; other layer types fall
! through this select type silently, same as the loss branch below --
! NOTE(review): presumably intentional (only these can be output layers),
! but worth confirming.
150+ select type (this_layer = > self % layers(n) % p)
151+ type is (dense_layer)
152+ call self % layers(n) % backward(self % layers(n - 1 ), gradient)
153+ type is (flatten_layer)
154+ call self % layers(n) % backward(self % layers(n - 1 ), gradient)
155+ end select
156+
157+ else
158+
! No external gradient: derive it from the configured loss, exactly as the
! removed code did.
159+ ! Apply the loss function
160+ select type (this_layer = > self % layers(n) % p)
161+ type is (dense_layer)
162+ call self % layers(n) % backward( &
163+ self % layers(n - 1 ), &
164+ self % loss % derivative(output, this_layer % output) &
165+ )
166+ type is (flatten_layer)
167+ call self % layers(n) % backward( &
168+ self % layers(n - 1 ), &
169+ self % loss % derivative(output, this_layer % output) &
170+ )
171+ end select
172+
173+ end if
174+
! Hidden layers: each layer n receives the gradient already computed by
! layer n+1. Branch bodies are identical except for flatten, which chooses
! the 2-d or 3-d gradient buffer based on this layer's own layer_shape rank.
175+ ! Hidden layers; take the gradient from the next layer
176+ do n = num_layers - 1 , 2 , - 1
177+ select type (next_layer = > self % layers(n + 1 ) % p)
178+ type is (dense_layer)
179+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
180+ type is (dropout_layer)
181+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
182+ type is (conv2d_layer)
183+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
184+ type is (flatten_layer)
185+ if (size (self % layers(n) % layer_shape) == 2 ) then
186+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient_2d)
187+ else
188+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient_3d)
189+ end if
190+ type is (maxpool2d_layer)
191+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
192+ type is (reshape3d_layer)
193+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
194+ type is (linear2d_layer)
195+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
196+ type is (self_attention_layer)
197+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
198+ type is (maxpool1d_layer)
199+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
200+ type is (reshape2d_layer)
201+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
202+ type is (conv1d_layer)
203+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
204+ type is (locally_connected2d_layer)
205+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
206+ type is (layernorm_layer)
207+ call self % layers(n) % backward(self % layers(n - 1 ), next_layer % gradient)
208+ end select
195209 end do
196210
197211 end subroutine backward
@@ -497,6 +511,42 @@ module subroutine print_info(self)
497511 end subroutine print_info
498512
499513
! NOTE(review): newly added subroutine (all '+' diff lines; the leading line
! numbers are diff-rendering artifacts, not Fortran). Returns a rank-1 pointer
! view of the final layer's output buffer, without copying.
514+ module subroutine get_output_1d (self , output )
! `target` on self is required so `output` may point at its layer components.
515+ class(network), intent (in ), target :: self
516+ real , pointer , intent (out ) :: output(:)
517+ integer :: last
518+
! Index of the final (output) layer.
519+ last = size (self % layers)
520+
521+ select type (output_layer = > self % layers(last) % p)
! For layer types whose output buffer is (presumably) multi-rank, expose it
! as a flat 1-d view via F2008 pointer bounds remapping,
! output(1:size(buf)) => buf. Rank-remapping pointer assignment requires the
! target to be contiguous -- TODO(review): confirm these output components
! are simply contiguous allocatables.
522+ type is (conv1d_layer)
523+ output(1 :size (output_layer % output)) = > output_layer % output
524+ type is (conv2d_layer)
525+ output(1 :size (output_layer % output)) = > output_layer % output
! Layers below associate directly, so their output components must already
! be rank-1.
526+ type is (dense_layer)
527+ output = > output_layer % output
528+ type is (dropout_layer)
529+ output = > output_layer % output
530+ type is (flatten_layer)
531+ output = > output_layer % output
532+ type is (layernorm_layer)
533+ output(1 :size (output_layer % output)) = > output_layer % output
534+ type is (linear2d_layer)
535+ output(1 :size (output_layer % output)) = > output_layer % output
536+ type is (locally_connected2d_layer)
537+ output(1 :size (output_layer % output)) = > output_layer % output
538+ type is (maxpool1d_layer)
539+ output(1 :size (output_layer % output)) = > output_layer % output
540+ type is (maxpool2d_layer)
541+ output(1 :size (output_layer % output)) = > output_layer % output
! Any layer type not listed above is unsupported: fail loudly with the
! offending layer's name rather than returning a disassociated pointer.
542+ class default
543+ error stop ' network % get_output not implemented for ' // &
544+ trim (self % layers(last) % name) // ' layer'
545+ end select
546+
547+ end subroutine get_output_1d
548+
549+
500550 module function get_num_params (self )
501551 class(network), intent (in ) :: self
502552 integer :: get_num_params
0 commit comments