Skip to content

Commit ee7fdc9

Browse files
committed
dropout % backward() doesn't need input from the previous layer
1 parent 8961f75 commit ee7fdc9

File tree

3 files changed

+4
-22
lines changed

3 files changed

+4
-22
lines changed

src/nf/nf_dropout_layer.f90

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -48,14 +48,12 @@ end function dropout_layer_cons
4848

4949
interface
5050

51-
pure module subroutine backward(self, input, gradient)
51+
pure module subroutine backward(self, gradient)
5252
!! Apply the backward gradient descent pass.
5353
!! Only weight and bias gradients are updated in this subroutine,
5454
!! while the weights and biases themselves are untouched.
5555
class(dropout_layer), intent(in out) :: self
5656
!! Dropout layer instance
57-
real, intent(in) :: input(:)
58-
!! Input from the previous layer
5957
real, intent(in) :: gradient(:)
6058
!! Gradient from the next layer
6159
end subroutine backward

src/nf/nf_dropout_layer_submodule.f90

Lines changed: 2 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -59,18 +59,10 @@ module subroutine forward(self, input)
5959
end subroutine forward
6060

6161

62-
pure module subroutine backward(self, input, gradient)
62+
pure module subroutine backward(self, gradient)
6363
class(dropout_layer), intent(in out) :: self
64-
real, intent(in) :: input(:)
6564
real, intent(in) :: gradient(:)
66-
67-
if (self % training) then
68-
! Backpropagate gradient through dropout mask
69-
self % gradient = gradient * self % mask * self % scale
70-
else
71-
! In inference mode, pass through the gradient unchanged
72-
self % gradient = gradient
73-
end if
65+
self % gradient = gradient * self % mask * self % scale
7466
end subroutine backward
7567

7668
end submodule nf_dropout_layer_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 1 addition & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -40,16 +40,8 @@ pure module subroutine backward_1d(self, previous, gradient)
4040
end select
4141

4242
type is(dropout_layer)
43-
4443
! Upstream layers permitted: input1d, dense, dropout, flatten
45-
select type(prev_layer => previous % p)
46-
type is(input1d_layer)
47-
call this_layer % backward(prev_layer % output, gradient)
48-
type is(dense_layer)
49-
call this_layer % backward(prev_layer % output, gradient)
50-
type is(flatten_layer)
51-
call this_layer % backward(prev_layer % output, gradient)
52-
end select
44+
call this_layer % backward(gradient)
5345

5446
type is(flatten_layer)
5547

0 commit comments

Comments
 (0)