Skip to content

Commit c6b4d87

Browse files
committed
Bug fix
1 parent a055b20 commit c6b4d87

File tree

3 files changed

+6
-6
lines changed

3 files changed

+6
-6
lines changed

src/nf/nf_conv1d_layer_submodule.f90

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -157,8 +157,8 @@ module function get_params(self) result(params)
157157
class(conv1d_layer), intent(in), target :: self
158158
real, allocatable :: params(:)
159159
real, pointer :: w_(:) => null()
160-
w_(1:size(self % z)) => self % z
161-
params = [ w_]
160+
w_(1:size(self % kernel)) => self % kernel
161+
params = [ w_, self % biases]
162162
end function get_params
163163

164164
module function get_gradients(self) result(gradients)

src/nf/nf_conv2d_layer_submodule.f90

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -195,11 +195,11 @@ module function get_params(self) result(params)
195195

196196
real, pointer :: w_(:) => null()
197197

198-
w_(1:size(self % z)) => self % z
198+
w_(1:size(self % kernel)) => self % kernel
199199

200200
params = [ &
201-
w_ &
202-
!self % biases &
201+
w_, &
202+
self % biases &
203203
]
204204

205205
end function get_params

test/test_conv2d_network.f90

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -60,7 +60,7 @@ program test_conv2d_network
6060
call cnn % forward(sample_input)
6161
call cnn % backward(y)
6262
call cnn % update(optimizer=sgd(learning_rate=1.))
63-
o = cnn % layers(2) % get_params()
63+
o = cnn % layers(3) % get_params()
6464
print *, o
6565
if (all(abs(cnn % predict(sample_input) - y) < tolerance)) exit
6666
end do

0 commit comments

Comments
 (0)