Skip to content

Commit f0d5ca2

Browse files
committed
Allow getting output pointer for all layers
1 parent 4aea615 commit f0d5ca2

File tree

1 file changed

+17
-3
lines changed

1 file changed

+17
-3
lines changed

src/nf/nf_network_submodule.f90

Lines changed: 17 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -519,12 +519,26 @@ module subroutine get_output_1d(self, output)
519519
last = size(self % layers)
520520

521521
select type(output_layer => self % layers(last) % p)
522-
type is(dense_layer)
522+
type is (conv1d_layer)
523+
output(1:size(output_layer % output)) => output_layer % output
524+
type is(conv2d_layer)
525+
output(1:size(output_layer % output)) => output_layer % output
526+
type is (dense_layer)
523527
output => output_layer % output
524-
type is(dropout_layer)
528+
type is (dropout_layer)
525529
output => output_layer % output
526-
type is(flatten_layer)
530+
type is (flatten_layer)
527531
output => output_layer % output
532+
type is (layernorm_layer)
533+
output(1:size(output_layer % output)) => output_layer % output
534+
type is (linear2d_layer)
535+
output(1:size(output_layer % output)) => output_layer % output
536+
type is (locally_connected2d_layer)
537+
output(1:size(output_layer % output)) => output_layer % output
538+
type is (maxpool1d_layer)
539+
output(1:size(output_layer % output)) => output_layer % output
540+
type is (maxpool2d_layer)
541+
output(1:size(output_layer % output)) => output_layer % output
528542
class default
529543
error stop 'network % get_output not implemented for ' // &
530544
trim(self % layers(last) % name) // ' layer'

0 commit comments

Comments
 (0)