Skip to content

Commit 1b81338

Browse files
committed
linear2d_layer: fix gradient updates
1 parent aa4f8f2 commit 1b81338

File tree

2 files changed

+69
-2
lines changed

2 files changed

+69
-2
lines changed

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -77,7 +77,7 @@ module function get_params(self) result(params)
7777

7878
real, pointer :: w_(:) => null()
7979

80-
w_(1:size(self % weights)) => self % weights
80+
w_(1: product(shape(self % weights))) => self % weights
8181

8282
params = [ &
8383
w_, &
@@ -93,7 +93,7 @@ module function get_gradients(self) result(gradients)
9393

9494
real, pointer :: dw_(:) => null()
9595

96-
dw_(1:size(self % dw)) => self % dw
96+
dw_(1: product(shape(self % dw))) => self % dw
9797

9898
gradients = [ &
9999
dw_, &

test/test_linear2d_layer.f90

Lines changed: 67 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,7 @@ program test_linear2d_layer
1616

1717
call test_linear2d_layer_forward(linear, ok, sample_input)
1818
call test_linear2d_layer_backward(linear, ok, sample_input, sample_gradient)
19+
call test_linear2d_layer_gradient_updates(ok)
1920

2021
contains
2122
subroutine test_linear2d_layer_forward(linear, ok, input)
@@ -100,4 +101,70 @@ subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
100101
write(stderr, '(a)') 'backward returned incorrect db values.. failed'
101102
end if
102103
end subroutine test_linear2d_layer_backward
104+
105+
subroutine test_linear2d_layer_gradient_updates(ok)
106+
logical, intent(in out) :: ok
107+
real :: input(3, 4, 1) = reshape([0.0, 0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12], [3, 4, 1])
108+
real :: gradient(3, 2, 1) = reshape([0.0, 10., 0.2, 3., 0.4, 1.], [3, 2, 1])
109+
type(linear2d_layer) :: linear
110+
111+
integer :: num_parameters
112+
real :: parameters(10)
113+
real :: expected_parameters(10) = [&
114+
0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001, 0.100000001,&
115+
0.109999999, 0.109999999&
116+
]
117+
real :: gradients(10)
118+
real :: expected_gradients(10) = [&
119+
1.03999996, 4.09999990, 7.15999985, 1.12400007, 0.240000010, 1.56000006, 2.88000011, 2.86399961,&
120+
10.1999998, 4.40000010&
121+
]
122+
real :: updated_parameters(10)
123+
real :: updated_weights(8)
124+
real :: updated_biases(2)
125+
real :: expected_weights(8) = [&
126+
0.203999996, 0.509999990, 0.816000044, 0.212400019, 0.124000005, 0.256000012, 0.388000011, 0.386399955&
127+
]
128+
real :: expected_biases(2) = [1.13000000, 0.550000012]
129+
130+
integer :: i
131+
132+
linear = linear2d_layer(sequence_length=3, in_features=4, out_features=2, batch_size=1)
133+
call linear % init([4])
134+
call linear % forward(input)
135+
call linear % backward(input, gradient)
136+
137+
num_parameters = linear % get_num_params()
138+
if (num_parameters /= 10) then
139+
ok = .false.
140+
write(stderr, '(a)') 'incorrect number of parameters.. failed'
141+
end if
142+
143+
parameters = linear % get_params()
144+
if (.not. all(parameters.eq.expected_parameters)) then
145+
ok = .false.
146+
write(stderr, '(a)') 'incorrect parameters.. failed'
147+
end if
148+
149+
gradients = linear % get_gradients()
150+
if (.not. all(gradients.eq.expected_gradients)) then
151+
ok = .false.
152+
write(stderr, '(a)') 'incorrect gradients.. failed'
153+
end if
154+
155+
do i = 1, num_parameters
156+
updated_parameters(i) = parameters(i) + 0.1 * gradients(i)
157+
end do
158+
call linear % set_params(updated_parameters)
159+
updated_weights = reshape(linear % weights, shape(expected_weights))
160+
if (.not. all(updated_weights.eq.expected_weights)) then
161+
ok = .false.
162+
write(stderr, '(a)') 'incorrect updated weights.. failed'
163+
end if
164+
updated_biases = linear % biases
165+
if (.not. all(updated_biases.eq.expected_biases)) then
166+
ok = .false.
167+
write(stderr, '(a)') 'incorrect updated biases.. failed'
168+
end if
169+
end subroutine test_linear2d_layer_gradient_updates
103170
end program test_linear2d_layer

0 commit comments

Comments
 (0)