  implicit none
contains
  module function linear2d_layer_cons( &
-     batch_size, sequence_length, in_features, out_features &
+     sequence_length, in_features, out_features, batch_size &
  ) result(res)
    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
    type(linear2d_layer) :: res
@@ -18,8 +18,8 @@ module subroutine init(self, input_shape)
    class(linear2d_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

-   allocate(self % output(self % batch_size, self % sequence_length, self % out_features))
-   allocate(self % gradient(self % batch_size, self % sequence_length, self % in_features))
+   allocate(self % output(self % sequence_length, self % out_features, self % batch_size))
+   allocate(self % gradient(self % sequence_length, self % in_features, self % batch_size))

    allocate(self % weights(self % in_features, self % out_features))
    self % weights = 0.1
@@ -39,10 +39,10 @@ pure module subroutine forward(self, input)
    integer :: i, j

    do concurrent(i = 1:self % batch_size)
-     self % output(i, :, :) = matmul(input(i, :, :), self % weights)
+     self % output(:, :, i) = matmul(input(:, :, i), self % weights)
    end do
    do concurrent(i = 1:self % batch_size, j = 1:self % sequence_length)
-     self % output(i, j, :) = self % output(i, j, :) + self % biases
+     self % output(j, :, i) = self % output(j, :, i) + self % biases
    end do
  end subroutine forward

@@ -55,9 +55,9 @@ pure module subroutine backward(self, input, gradient)
    integer :: i

    do concurrent(i = 1:self % batch_size)
-     self % dw = self % dw + matmul(transpose(input(i, :, :)), gradient(i, :, :))
-     self % db = self % db + sum(gradient(i, :, :), 1)
-     self % gradient(i, :, :) = matmul(gradient(i, :, :), transpose(self % weights))
+     self % dw = self % dw + matmul(transpose(input(:, :, i)), gradient(:, :, i))
+     self % db = self % db + sum(gradient(:, :, i), 1)
+     self % gradient(:, :, i) = matmul(gradient(:, :, i), transpose(self % weights))
    end do
  end subroutine backward

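For reference, a minimal standalone sketch (not part of the commit) of the (sequence_length, features, batch_size) layout this change adopts in forward: with the batch index last, each input(:, :, i) slice is a contiguous (sequence_length, in_features) matrix that matmul can consume directly. All names and sizes below are illustrative, not taken from the library.

program layout_sketch
  implicit none
  integer, parameter :: sequence_length = 4, in_features = 3, out_features = 2, batch_size = 5
  real :: input(sequence_length, in_features, batch_size)
  real :: weights(in_features, out_features)
  real :: biases(out_features)
  real :: output(sequence_length, out_features, batch_size)
  integer :: i, j

  call random_number(input)
  call random_number(weights)
  call random_number(biases)

  ! Batch index last: each (sequence_length, in_features) slice is contiguous,
  ! so the per-sample matmul operates on a contiguous section of memory.
  do concurrent(i = 1:batch_size)
    output(:, :, i) = matmul(input(:, :, i), weights)   ! (4,3) x (3,2) -> (4,2)
  end do
  ! Broadcast the bias over every sequence position of every batch element.
  do concurrent(i = 1:batch_size, j = 1:sequence_length)
    output(j, :, i) = output(j, :, i) + biases
  end do

  print *, shape(output)   ! prints: 4 2 5
end program layout_sketch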