Skip to content

Commit 6f33ebe

Browse files
committed
Tidy up
1 parent 119a6c8 commit 6f33ebe

File tree

2 files changed

+29
-17
lines changed

2 files changed

+29
-17
lines changed

example/linear2d.f90

Lines changed: 8 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -4,14 +4,14 @@ program linear2d_example
44
implicit none
55

66
type(network) :: net
7-
type(mse) :: loss
7+
type(mse) :: loss = mse()
88
real :: x(3, 4) = reshape( &
99
[0.1, 0.2, 0.3, 0.4, 0.5, 0.6, 0.7, 0.8, 0.9, 0.11, 0.12, 0.13], &
1010
[3, 4])
1111
real :: y(3) = [0.12, 0.1, 0.2]
1212
real :: preds(3)
1313
real :: loss_value
14-
integer, parameter :: num_iterations = 500
14+
integer, parameter :: num_iterations = 10000
1515
integer :: n
1616

1717
net = network([ &
@@ -21,19 +21,22 @@ program linear2d_example
2121
])
2222

2323
call net % print_info()
24-
loss = mse()
2524

2625
do n = 1, num_iterations
26+
2727
call net % forward(x)
2828
call net % backward(y, loss)
2929
call net % update(optimizer=sgd(learning_rate=0.01))
30+
3031
preds = net % predict(x)
31-
print '(i4,3(3x,f8.6))', n, preds
32+
print '(i5,3(3x,f9.6))', n, preds
33+
3234
loss_value = loss % eval (y, preds)
33-
if (loss_value < 0.01) then
35+
if (loss_value < 1e-4) then
3436
print *, 'Loss: ', loss_value
3537
return
3638
end if
39+
3740
end do
3841

3942
end program linear2d_example

src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 21 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,20 +2,23 @@
22
use nf_base_layer, only: base_layer
33
use nf_random, only: random_normal
44
implicit none
5+
56
contains
7+
68
module function linear2d_layer_cons(out_features) result(res)
79
integer, intent(in) :: out_features
810
type(linear2d_layer) :: res
911

1012
res % out_features = out_features
1113
end function linear2d_layer_cons
1214

15+
1316
module subroutine init(self, input_shape)
1417
class(linear2d_layer), intent(in out) :: self
1518
integer, intent(in) :: input_shape(:)
1619

1720
if (size(input_shape) /= 2) then
18-
error stop "Linear2D Layer accepts 2D input"
21+
error stop "linear2d layer requires 2D input."
1922
end if
2023
self % sequence_length = input_shape(1)
2124
self % in_features = input_shape(2)
@@ -30,40 +33,45 @@ module subroutine init(self, input_shape)
3033
call random_normal(self % biases)
3134

3235
allocate(self % dw(self % in_features, self % out_features))
33-
self % dw = 0.0
36+
self % dw = 0
3437
allocate(self % db(self % out_features))
35-
self % db = 0.0
38+
self % db = 0
39+
3640
end subroutine init
3741

42+
3843
pure module subroutine forward(self, input)
3944
class(linear2d_layer), intent(in out) :: self
4045
real, intent(in) :: input(:, :)
4146
integer :: i
4247

43-
self % output(:, :) = matmul(input(:, :), self % weights)
44-
do concurrent(i = 1: self % sequence_length)
45-
self % output(i, :) = self % output(i, :) + self % biases
48+
self % output(:,:) = matmul(input(:,:), self % weights)
49+
do concurrent(i = 1:self % sequence_length)
50+
self % output(i,:) = self % output(i,:) + self % biases
4651
end do
52+
4753
end subroutine forward
4854

55+
4956
pure module subroutine backward(self, input, gradient)
5057
class(linear2d_layer), intent(in out) :: self
51-
real, intent(in) :: input(:, :)
52-
real, intent(in) :: gradient(:, :)
58+
real, intent(in) :: input(:,:)
59+
real, intent(in) :: gradient(:,:)
5360
real :: db(self % out_features)
5461
real :: dw(self % in_features, self % out_features)
5562
integer :: i
5663

57-
self % dw = self % dw + matmul(transpose(input(:, :)), gradient(:, :))
58-
self % db = self % db + sum(gradient(:, :), 1)
59-
self % gradient(:, :) = matmul(gradient(:, :), transpose(self % weights))
64+
self % dw = self % dw + matmul(transpose(input(:,:)), gradient(:,:))
65+
self % db = self % db + sum(gradient(:,:), 1)
66+
self % gradient(:,:) = matmul(gradient(:,:), transpose(self % weights))
6067
end subroutine backward
6168

69+
6270
pure module function get_num_params(self) result(num_params)
6371
class(linear2d_layer), intent(in) :: self
6472
integer :: num_params
6573

66-
! Number of weigths times number of biases
74+
! Number of weights plus number of biases
6775
num_params = self % in_features * self % out_features + self % out_features
6876

6977
end function get_num_params
@@ -122,4 +130,5 @@ module subroutine set_params(self, params)
122130
end associate
123131

124132
end subroutine set_params
133+
125134
end submodule nf_linear2d_layer_submodule

0 commit comments

Comments
 (0)