
Commit 0e11f10

Define optimizer instance per layer to preserve memory across layers
1 parent 2160f97 commit 0e11f10
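
Note on the rationale: stateful optimizers (SGD with momentum, RMSprop, Adam) carry per-parameter buffers between updates. Giving each layer its own optimizer instance keeps those buffers sized to, and persistent with, that layer's parameters. Below is a minimal, hypothetical Fortran sketch of the idea, not the library code: an SGD-with-momentum type whose velocity buffer lives in each instance.

module sgd_momentum_m
  implicit none

  type :: sgd_momentum
    real :: learning_rate = 0.01
    real :: momentum = 0.9
    ! Per-instance state: one velocity entry per parameter.
    real, allocatable :: velocity(:)
  contains
    procedure :: init => sgd_init
    procedure :: minimize => sgd_minimize
  end type

contains

  subroutine sgd_init(self, num_params)
    class(sgd_momentum), intent(inout) :: self
    integer, intent(in) :: num_params
    allocate(self % velocity(num_params), source=0.0)
  end subroutine sgd_init

  subroutine sgd_minimize(self, param, gradient)
    class(sgd_momentum), intent(inout) :: self
    real, intent(inout) :: param(:)
    real, intent(in) :: gradient(:)
    ! The velocity persists between calls; sharing one instance across
    ! layers of different sizes would mismatch or clobber this buffer.
    self % velocity = self % momentum * self % velocity &
                    - self % learning_rate * gradient
    param = param + self % velocity
  end subroutine sgd_minimize

end module sgd_momentum_m

program per_layer_state_demo
  use sgd_momentum_m
  implicit none
  type(sgd_momentum) :: opt_a, opt_b
  real :: w_a(3) = 1.0, w_b(5) = 1.0

  ! One optimizer instance per layer: each velocity buffer is sized to
  ! its layer and accumulates history independently.
  call opt_a % init(size(w_a))
  call opt_b % init(size(w_b))
  call opt_a % minimize(w_a, [0.1, 0.2, 0.3])
  call opt_b % minimize(w_b, [0.1, 0.1, 0.1, 0.1, 0.1])
  print *, w_a
end program per_layer_state_demo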

File tree

2 files changed: +39 −8 lines changed

src/nf/nf_layer.f90

Lines changed: 1 addition & 0 deletions

@@ -22,6 +22,7 @@ module nf_layer
     integer, allocatable :: layer_shape(:)
     integer, allocatable :: input_layer_shape(:)
     logical :: initialized = .false.
+    class(optimizer_base_type), allocatable :: optimizer
 
   contains
 
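The component added above is declared class(optimizer_base_type), allocatable, so a layer can hold any concrete optimizer that extends the base type, and intrinsic assignment allocates it with the dynamic type of the right-hand side. A minimal, self-contained sketch of this pattern, with hypothetical stand-in types rather than the library's:

module component_sketch_m
  implicit none

  type, abstract :: optimizer_base_type
  end type

  type, extends(optimizer_base_type) :: sgd_type
    real :: learning_rate = 0.01
  end type

  type :: layer_sketch
    ! Polymorphic allocatable component, as added to nf_layer here.
    class(optimizer_base_type), allocatable :: optimizer
  end type

end module component_sketch_m

program per_layer_assignment
  use component_sketch_m
  implicit none
  type(layer_sketch) :: layers(3)
  integer :: n

  ! Intrinsic assignment (F2008) allocates the component with the
  ! dynamic type of the right-hand side, one copy per layer.
  do n = 1, size(layers)
    layers(n) % optimizer = sgd_type()
  end do
end program per_layer_assignment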
src/nf/nf_network_submodule.f90

Lines changed: 38 additions & 8 deletions

@@ -597,12 +597,26 @@ module subroutine train(self, input_data, output_data, batch_size, &
     ! If not provided, we default to SGD with its default settings.
     if (present(optimizer)) then
       self % optimizer = optimizer
+
+      do n = 1, size(self % layers)
+        self % layers(n) % optimizer = optimizer
+      end do
+
     else
       self % optimizer = sgd()
+
+      do n = 1, size(self % layers)
+        self % layers(n) % optimizer = sgd()
+      end do
+
     end if
 
     call self % optimizer % init(self % get_num_params())
 
+    do n = 1, size(self % layers)
+      call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
+    end do
+
     ! Passing the loss instance is optional.
     ! If not provided, we default to quadratic().
     if (present(loss)) then
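
With this hunk, an optimizer passed to train() is copied into every layer, and each copy is initialized with that layer's own parameter count; the network-level instance is kept as before. A hedged usage sketch, assuming the public nf module exports network, input, dense, and sgd as in the library's examples:

program train_usage_sketch
  use nf, only: network, input, dense, sgd
  implicit none
  type(network) :: net
  real :: x(3, 100), y(1, 100)

  call random_number(x)
  y(1,:) = sum(x, dim=1)

  net = network([input(3), dense(8), dense(1)])

  ! Each layer receives its own copy of this optimizer; each copy's
  ! state is initialized against that layer's parameter count.
  call net % train(x, y, batch_size=10, epochs=5, &
                   optimizer=sgd(learning_rate=0.01))
end program train_usage_sketch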
@@ -662,10 +676,26 @@ module subroutine update(self, optimizer, batch_size)
     if (.not. allocated(self % optimizer)) then
       if (present(optimizer)) then
         self % optimizer = optimizer
+
+        do n = 1, size(self % layers)
+          self % layers(n) % optimizer = optimizer
+        end do
+
       else
         self % optimizer = sgd()
+
+        do n = 1, size(self % layers)
+          self % layers(n) % optimizer = sgd()
+        end do
+
       end if
+
       call self % optimizer % init(self % get_num_params())
+
+      do n = 1, size(self % layers)
+        call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
+      end do
+
     end if
 
     if (present(batch_size)) then
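
The guard here mirrors train(): if no optimizer was ever supplied, both the network-level and the per-layer instances are allocated and initialized lazily on the first update() call. A minimal sketch of that allocate-on-first-use pattern, with a hypothetical type:

module lazy_init_m
  implicit none

  type :: holder
    real, allocatable :: state(:)
  contains
    procedure :: step
  end type

contains

  subroutine step(self, n)
    class(holder), intent(inout) :: self
    integer, intent(in) :: n
    ! Allocate on first use, so callers need not initialize explicitly.
    if (.not. allocated(self % state)) allocate(self % state(n), source=0.0)
    self % state = self % state + 1.0
  end subroutine step

end module lazy_init_m

program lazy_init_demo
  use lazy_init_m
  implicit none
  type(holder) :: h
  call h % step(4)    ! first call allocates the state
  call h % step(4)    ! later calls reuse it
  print *, h % state  ! 2.0 2.0 2.0 2.0
end program lazy_init_demo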
@@ -699,29 +729,29 @@ module subroutine update(self, optimizer, batch_size)
       type is(dense_layer)
         call this_layer % get_params_ptr(weights, biases)
         call this_layer % get_gradients_ptr(dw, db)
-        call self % optimizer % minimize(weights, dw / batch_size_)
-        call self % optimizer % minimize(biases, db / batch_size_)
+        call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+        call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
         this_layer % dw = 0
         this_layer % db = 0
       type is(conv1d_layer)
         call this_layer % get_params_ptr(weights, biases)
         call this_layer % get_gradients_ptr(dw, db)
-        call self % optimizer % minimize(weights, dw / batch_size_)
-        call self % optimizer % minimize(biases, db / batch_size_)
+        call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+        call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
         this_layer % dw = 0
         this_layer % db = 0
       type is(conv2d_layer)
        call this_layer % get_params_ptr(weights, biases)
         call this_layer % get_gradients_ptr(dw, db)
-        call self % optimizer % minimize(weights, dw / batch_size_)
-        call self % optimizer % minimize(biases, db / batch_size_)
+        call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+        call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
         this_layer % dw = 0
         this_layer % db = 0
       type is(locally_connected1d_layer)
         call this_layer % get_params_ptr(weights, biases)
         call this_layer % get_gradients_ptr(dw, db)
-        call self % optimizer % minimize(weights, dw / batch_size_)
-        call self % optimizer % minimize(biases, db / batch_size_)
+        call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
+        call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
         this_layer % dw = 0
         this_layer % db = 0
       end select
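
Each minimize call now dispatches to the owning layer's optimizer and receives the batch-averaged gradient (dw / batch_size_). For plain SGD the update reduces to the step below; a small arithmetic sketch, not the library code:

program minimize_step_sketch
  implicit none
  real :: weights(4), dw(4)
  real, parameter :: learning_rate = 0.1
  integer, parameter :: batch_size = 8

  weights = 1.0
  dw = 0.8  ! gradient accumulated (summed) over the batch

  ! Averaging by batch_size before the step matches the
  ! dw / batch_size_ expressions in the diff above.
  weights = weights - learning_rate * (dw / batch_size)
  print *, weights  ! 1.0 - 0.1*(0.8/8) = 0.99
end program minimize_step_sketch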
