@@ -597,12 +597,26 @@ module subroutine train(self, input_data, output_data, batch_size, &
597597 ! If not provided, we default to SGD with its default settings.
598598 if (present (optimizer)) then
599599 self % optimizer = optimizer
600+
601+ do n = 1 , size (self % layers)
602+ self % layers(n) % optimizer = optimizer
603+ end do
604+
600605 else
601606 self % optimizer = sgd()
607+
608+ do n = 1 , size (self % layers)
609+ self % layers(n) % optimizer = sgd()
610+ end do
611+
602612 end if
603613
604614 call self % optimizer % init(self % get_num_params())
605615
616+ do n = 1 , size (self % layers)
617+ call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
618+ end do
619+
606620 ! Passing the loss instance is optional.
607621 ! If not provided, we default to quadratic().
608622 if (present (loss)) then
@@ -662,10 +676,26 @@ module subroutine update(self, optimizer, batch_size)
662676 if (.not. allocated (self % optimizer)) then
663677 if (present (optimizer)) then
664678 self % optimizer = optimizer
679+
680+ do n = 1 , size (self % layers)
681+ self % layers(n) % optimizer = optimizer
682+ end do
683+
665684 else
666685 self % optimizer = sgd()
686+
687+ do n = 1 , size (self % layers)
688+ self % layers(n) % optimizer = sgd()
689+ end do
690+
667691 end if
692+
668693 call self % optimizer % init(self % get_num_params())
694+
695+ do n = 1 , size (self % layers)
696+ call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
697+ end do
698+
669699 end if
670700
671701 if (present (batch_size)) then
@@ -699,29 +729,29 @@ module subroutine update(self, optimizer, batch_size)
699729 type is (dense_layer)
700730 call this_layer % get_params_ptr(weights, biases)
701731 call this_layer % get_gradients_ptr(dw, db)
702- call self % optimizer % minimize(weights, dw / batch_size_)
703- call self % optimizer % minimize(biases, db / batch_size_)
732+ call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
733+ call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
704734 this_layer % dw = 0
705735 this_layer % db = 0
706736 type is (conv1d_layer)
707737 call this_layer % get_params_ptr(weights, biases)
708738 call this_layer % get_gradients_ptr(dw, db)
709- call self % optimizer % minimize(weights, dw / batch_size_)
710- call self % optimizer % minimize(biases, db / batch_size_)
739+ call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
740+ call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
711741 this_layer % dw = 0
712742 this_layer % db = 0
713743 type is (conv2d_layer)
714744 call this_layer % get_params_ptr(weights, biases)
715745 call this_layer % get_gradients_ptr(dw, db)
716- call self % optimizer % minimize(weights, dw / batch_size_)
717- call self % optimizer % minimize(biases, db / batch_size_)
746+ call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
747+ call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
718748 this_layer % dw = 0
719749 this_layer % db = 0
720750 type is (locally_connected1d_layer)
721751 call this_layer % get_params_ptr(weights, biases)
722752 call this_layer % get_gradients_ptr(dw, db)
723- call self % optimizer % minimize(weights, dw / batch_size_)
724- call self % optimizer % minimize(biases, db / batch_size_)
753+ call self % layers(n) % optimizer % minimize(weights, dw / batch_size_)
754+ call self % layers(n) % optimizer % minimize(biases, db / batch_size_)
725755 this_layer % dw = 0
726756 this_layer % db = 0
727757 end select
0 commit comments