
Commit e61f29e

Remove optimizer as a component of the network class

1 parent 309ef6e

File tree: 2 files changed (+10 lines, -42 lines)

src/nf/nf_network.f90: 0 additions, 1 deletion
@@ -16,7 +16,6 @@ module nf_network
 
     type(layer), allocatable :: layers(:)
     class(loss_type), allocatable :: loss
-    class(optimizer_base_type), allocatable :: optimizer
 
   contains
 
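The deletion above is the whole network-level change: the network type keeps its layers and loss, while the optimizer now lives only on each layer (it is referenced as self % layers(n) % optimizer in the submodule diff below). A minimal sketch of the resulting ownership, with all unrelated components elided and the module paths assumed from the repository layout:

! Sketch only: component names come from this diff; the real nf_layer and
! nf_network types carry many more components, elided here.
module ownership_sketch
  use nf_optimizers, only: optimizer_base_type
  use nf_loss, only: loss_type
  implicit none

  type :: layer
    ! Per-layer optimizer, allocated and initialized on the first update().
    class(optimizer_base_type), allocatable :: optimizer
  end type layer

  type :: network
    type(layer), allocatable :: layers(:)
    class(loss_type), allocatable :: loss
    ! The network-level optimizer component was removed by this commit.
  end type network
end module ownership_sketch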
src/nf/nf_network_submodule.f90: 10 additions, 41 deletions
@@ -574,27 +574,8 @@ module subroutine train(self, input_data, output_data, batch_size, &
     integer :: i, j, n
     integer :: istart, iend, indices(2)
 
-    ! Passing the optimizer instance is optional.
-    ! If not provided, we default to SGD with its default settings.
-    if (present(optimizer)) then
-      self % optimizer = optimizer
-
-      do n = 1, size(self % layers)
-        self % layers(n) % optimizer = optimizer
-      end do
-
-    else
-      self % optimizer = sgd()
-
-      do n = 1, size(self % layers)
-        self % layers(n) % optimizer = sgd()
-      end do
-
-    end if
-
-    do n = 1, size(self % layers)
-      call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
-    end do
+    ! The optional optimizer instance is passed through to the update() method
+    ! where it is optional as well.
 
     ! Passing the loss instance is optional.
     ! If not provided, we default to quadratic().
@@ -628,7 +609,7 @@ module subroutine train(self, input_data, output_data, batch_size, &
         call self % backward(output_data(:,j))
       end do
 
-      call self % update(batch_size=batch_size)
+      call self % update(optimizer=optimizer, batch_size=batch_size)
 
     end do batch_loop
   end do epoch_loop
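From the caller's side nothing changes: the optimizer argument to train() stays optional, and train() now simply forwards it (present or not) to update(), which owns the defaulting logic. A hedged usage sketch, assuming the public nf module interface (network, input, dense, sgd) and made-up shapes and hyperparameters:

! Hypothetical example; layer sizes, epochs, and learning rate are
! illustrative only.
program train_with_optimizer
  use nf, only: dense, input, network, sgd
  implicit none
  type(network) :: net
  real :: x(3, 100), y(2, 100)

  call random_number(x)
  call random_number(y)
  net = network([input(3), dense(5), dense(2)])

  ! optimizer is optional; if omitted, update() falls back to a default
  ! sgd() instance, per the comments in this diff.
  call net % train(x, y, batch_size=10, epochs=5, optimizer=sgd(learning_rate=0.01))
end program train_with_optimizer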
@@ -645,34 +626,22 @@ module subroutine update(self, optimizer, batch_size)
     real, pointer :: weights(:), biases(:), dw(:), db(:)
     integer :: n
 
-    ! Passing the optimizer instance is optional. If not provided, and if the
-    ! optimizer has not already been set, we default to the default SGD. The
-    ! instantiation and initialization below of the optimizer is normally done
-    ! at the beginning of the network % train() method. However, if the user
-    ! wants to call network % update() directly, for example if they use their
-    ! own custom mini-batching routine, we initialize the optimizer here as
-    ! well. If it's initialized already, this step is a cheap no-op.
-    if (.not. allocated(self % optimizer)) then
+    ! You can optionally pass an optimizer instance to the update() method.
+    ! This is necessary if you're not using the train() method, for example if
+    ! you're using your own custom mini-batching routine and calling the
+    ! forward(), backward(), and update() methods directly.
+    if (.not. allocated(self % layers(1) % optimizer)) then
       if (present(optimizer)) then
-        self % optimizer = optimizer
-
         do n = 1, size(self % layers)
           self % layers(n) % optimizer = optimizer
+          call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
         end do
-
       else
-        self % optimizer = sgd()
-
         do n = 1, size(self % layers)
           self % layers(n) % optimizer = sgd()
+          call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
         end do
-
       end if
-
-      do n = 1, size(self % layers)
-        call self % layers(n) % optimizer % init(self % layers(n) % get_num_params())
-      end do
-
     end if
 
     if (present(batch_size)) then
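The rewritten comment names the scenario this branch supports: driving the network without train(). A hedged sketch of such a custom mini-batching loop, assuming the same public nf interface as above and that backward() falls back to the default quadratic() loss when none has been set:

! Hypothetical example; shapes, batch size, and learning rate are
! illustrative only.
program custom_minibatch
  use nf, only: dense, input, network, sgd
  implicit none
  type(network) :: net
  real :: x(3, 100), y(2, 100)
  integer :: i, j, istart, iend
  integer, parameter :: batch_size = 10

  call random_number(x)
  call random_number(y)
  net = network([input(3), dense(5), dense(2)])

  do i = 1, size(x, dim=2) / batch_size
    istart = (i - 1) * batch_size + 1
    iend = istart + batch_size - 1
    do j = istart, iend
      call net % forward(x(:,j))
      call net % backward(y(:,j))
    end do
    ! On the first call, update() allocates and initializes each layer's
    ! optimizer from the instance passed here; later calls hit the
    ! already-allocated guard and just apply the update.
    call net % update(optimizer=sgd(learning_rate=0.01), batch_size=batch_size)
  end do
end program custom_minibatch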
