@@ -16,6 +16,7 @@ program test_linear2d_layer
1616
1717 call test_linear2d_layer_forward(linear, ok, sample_input)
1818 call test_linear2d_layer_backward(linear, ok, sample_input, sample_gradient)
19+ call test_linear2d_layer_gradient_updates(ok)
1920
2021contains
2122 subroutine test_linear2d_layer_forward (linear , ok , input )
@@ -100,4 +101,70 @@ subroutine test_linear2d_layer_backward(linear, ok, input, gradient)
100101 write (stderr, ' (a)' ) ' backward returned incorrect db values.. failed'
101102 end if
102103 end subroutine test_linear2d_layer_backward
104+
105+ subroutine test_linear2d_layer_gradient_updates (ok )
106+ logical , intent (in out ) :: ok
107+ real :: input(3 , 4 , 1 ) = reshape ([0.0 , 0.1 , 0.2 , 0.3 , 0.4 , 0.5 , 0.6 , 0.7 , 0.8 , 0.9 , 0.11 , 0.12 ], [3 , 4 , 1 ])
108+ real :: gradient(3 , 2 , 1 ) = reshape ([0.0 , 10 ., 0.2 , 3 ., 0.4 , 1 .], [3 , 2 , 1 ])
109+ type (linear2d_layer) :: linear
110+
111+ integer :: num_parameters
112+ real :: parameters(10 )
113+ real :: expected_parameters(10 ) = [&
114+ 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 , 0.100000001 ,&
115+ 0.109999999 , 0.109999999 &
116+ ]
117+ real :: gradients(10 )
118+ real :: expected_gradients(10 ) = [&
119+ 1.03999996 , 4.09999990 , 7.15999985 , 1.12400007 , 0.240000010 , 1.56000006 , 2.88000011 , 2.86399961 ,&
120+ 10.1999998 , 4.40000010 &
121+ ]
122+ real :: updated_parameters(10 )
123+ real :: updated_weights(8 )
124+ real :: updated_biases(2 )
125+ real :: expected_weights(8 ) = [&
126+ 0.203999996 , 0.509999990 , 0.816000044 , 0.212400019 , 0.124000005 , 0.256000012 , 0.388000011 , 0.386399955 &
127+ ]
128+ real :: expected_biases(2 ) = [1.13000000 , 0.550000012 ]
129+
130+ integer :: i
131+
132+ linear = linear2d_layer(sequence_length= 3 , in_features= 4 , out_features= 2 , batch_size= 1 )
133+ call linear % init([4 ])
134+ call linear % forward(input)
135+ call linear % backward(input, gradient)
136+
137+ num_parameters = linear % get_num_params()
138+ if (num_parameters /= 10 ) then
139+ ok = .false.
140+ write (stderr, ' (a)' ) ' incorrect number of parameters.. failed'
141+ end if
142+
143+ parameters = linear % get_params()
144+ if (.not. all (parameters.eq. expected_parameters)) then
145+ ok = .false.
146+ write (stderr, ' (a)' ) ' incorrect parameters.. failed'
147+ end if
148+
149+ gradients = linear % get_gradients()
150+ if (.not. all (gradients.eq. expected_gradients)) then
151+ ok = .false.
152+ write (stderr, ' (a)' ) ' incorrect gradients.. failed'
153+ end if
154+
155+ do i = 1 , num_parameters
156+ updated_parameters(i) = parameters(i) + 0.1 * gradients(i)
157+ end do
158+ call linear % set_params(updated_parameters)
159+ updated_weights = reshape (linear % weights, shape (expected_weights))
160+ if (.not. all (updated_weights.eq. expected_weights)) then
161+ ok = .false.
162+ write (stderr, ' (a)' ) ' incorrect updated weights.. failed'
163+ end if
164+ updated_biases = linear % biases
165+ if (.not. all (updated_biases.eq. expected_biases)) then
166+ ok = .false.
167+ write (stderr, ' (a)' ) ' incorrect updated biases.. failed'
168+ end if
169+ end subroutine test_linear2d_layer_gradient_updates
103170end program test_linear2d_layer
0 commit comments