@@ -574,27 +574,8 @@ module subroutine train(self, input_data, output_data, batch_size, &
     integer :: i, j, n
     integer :: istart, iend, indices(2)
 
-    ! Passing the optimizer instance is optional.
-    ! If not provided, we default to SGD with its default settings.
-    if (present(optimizer)) then
-      self % optimizer = optimizer
-
-      do n = 1, size(self % layers)
-        self % layers(n) % optimizer = optimizer
-      end do
-
-    else
-      self % optimizer = sgd()
-
-      do n = 1, size(self % layers)
-        self % layers(n) % optimizer = sgd()
-      end do
-
-    end if
-
-    do n = 1, size(self % layers)
-      call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
-    end do
+    ! The optional optimizer instance is passed through to the update() method
+    ! where it is optional as well.
 
     ! Passing the loss instance is optional.
     ! If not provided, we default to quadratic().
@@ -628,7 +609,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
           call self % backward(output_data(:,j))
         end do
 
-        call self % update(batch_size=batch_size)
+        call self % update(optimizer=optimizer, batch_size=batch_size)
 
       end do batch_loop
     end do epoch_loop
@@ -645,34 +626,22 @@ module subroutine update(self, optimizer, batch_size)
     real, pointer :: weights(:), biases(:), dw(:), db(:)
     integer :: n
 
-    ! Passing the optimizer instance is optional. If not provided, and if the
-    ! optimizer has not already been set, we default to the default SGD. The
-    ! instantiation and initialization below of the optimizer is normally done
-    ! at the beginning of the network % train() method. However, if the user
-    ! wants to call network % update() directly, for example if they use their
-    ! own custom mini-batching routine, we initialize the optimizer here as
-    ! well. If it's initialized already, this step is a cheap no-op.
-    if (.not. allocated(self % optimizer)) then
+    ! You can optionally pass an optimizer instance to the update() method.
+    ! This is necessary if you're not using the train() method, for example if
+    ! you're using your own custom mini-batching routine and calling the
+    ! forward(), backward(), and update() methods directly.
+    if (.not. allocated(self % layers(1) % optimizer)) then
       if (present(optimizer)) then
-        self % optimizer = optimizer
-
         do n = 1, size(self % layers)
           self % layers(n) % optimizer = optimizer
+          call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
         end do
-
       else
-        self % optimizer = sgd()
-
         do n = 1, size(self % layers)
           self % layers(n) % optimizer = sgd()
+          call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
         end do
-
       end if
-
-      do n = 1, size(self % layers)
-        call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
-      end do
-
     end if
 
     if (present(batch_size)) then
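
For reference, the comment added in update() describes the workflow this change enables: driving the network without train(). Below is a minimal sketch of such a custom mini-batch loop calling forward(), backward(), and update() directly. It assumes neural-fortran's `nf` module exports `network`, `dense`, `input`, and `sgd`; the architecture, data, and hyperparameters are made up for illustration.

```fortran
program custom_batch_sketch
  ! Hypothetical example: a hand-rolled mini-batch loop that bypasses train().
  use nf, only: network, dense, input, sgd
  implicit none

  type(network) :: net
  real :: x(3, 100), y(2, 100)
  integer :: i, j

  call random_number(x)
  call random_number(y)

  ! Made-up architecture: 3 inputs, one hidden layer, 2 outputs.
  net = network([input(3), dense(8), dense(2)])

  ! One epoch over mini-batches of 10 samples each.
  do i = 1, 100, 10
    do j = i, i + 9
      call net % forward(x(:,j))
      call net % backward(y(:,j))
    end do
    ! The optimizer passed here is stored and initialized per layer on the
    ! first call; later calls reuse the already-allocated per-layer optimizers.
    call net % update(optimizer=sgd(learning_rate=0.01), batch_size=10)
  end do

end program custom_batch_sketch
```

Because update() only installs the optimizer when the layers don't yet have one, passing it on every iteration is harmless; only the first call takes effect.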