Commit aa4f8f2

make linear2d_layer with batch as last dimension (performance)
1 parent 5b3a4cb commit aa4f8f2
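
Why batch-last helps: Fortran stores arrays in column-major order, so the leftmost subscript varies fastest in memory. With batch_size as the last dimension, each per-sample slice a(:, :, i) is one contiguous block that matmul can read in place, whereas a(i, :, :) is strided and typically forces a copy into a temporary. A minimal standalone sketch (not part of this commit) that checks this with the is_contiguous intrinsic:

program layout_demo
  implicit none
  ! batch first: (batch, seq, features); batch last: (seq, features, batch)
  real :: batch_first(4, 3, 2), batch_last(3, 2, 4)
  ! Column-major storage: batch_last(:, :, 1) occupies the first 3*2 reals,
  ! so the slice is contiguous; batch_first(1, :, :) hops 4 reals at a time.
  print *, is_contiguous(batch_last(:, :, 1))   ! T
  print *, is_contiguous(batch_first(1, :, :))  ! F
end program layout_demo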

File tree

4 files changed: +14 -14

src/nf/nf_layer_constructors.f90
src/nf/nf_layer_constructors_submodule.f90
src/nf/nf_linear2d_layer.f90
src/nf/nf_linear2d_layer_submodule.f90


src/nf/nf_layer_constructors.f90

Lines changed: 1 addition & 1 deletion
@@ -166,7 +166,7 @@ module function reshape(output_shape) result(res)
      !! Resulting layer instance
  end function reshape

- module function linear2d(batch_size, sequence_length, in_features, out_features) result(res)
+ module function linear2d(sequence_length, in_features, out_features, batch_size) result(res)
    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
    type(layer) :: res
  end function linear2d
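
A call site under the new positional order would look like the following (a hedged sketch; only linear2d and type(layer) are taken from this diff, the surrounding context is assumed):

  ! old order: linear2d(batch_size, sequence_length, in_features, out_features)
  ! new order: linear2d(sequence_length, in_features, out_features, batch_size)
  type(layer) :: l
  l = linear2d(sequence_length=16, in_features=32, out_features=64, batch_size=8)

Callers that pass keyword arguments, as above, are unaffected by the reordering; purely positional call sites need updating.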

src/nf/nf_layer_constructors_submodule.f90

Lines changed: 3 additions & 3 deletions
@@ -135,13 +135,13 @@ module function reshape(output_shape) result(res)

  end function reshape

- module function linear2d(batch_size, sequence_length, in_features, out_features) result(res)
+ module function linear2d(sequence_length, in_features, out_features, batch_size) result(res)
    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
    type(layer) :: res

    res % name = 'linear2d'
-   res % layer_shape = [batch_size, sequence_length, out_features]
-   allocate(res % p, source=linear2d_layer(batch_size, sequence_length, in_features, out_features))
+   res % layer_shape = [sequence_length, out_features, batch_size]
+   allocate(res % p, source=linear2d_layer(sequence_length, in_features, out_features, batch_size))

  end function linear2d

end submodule nf_layer_constructors_submodule
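
The reordered layer_shape now mirrors the allocation order used in init further down, so shape queries report (sequence_length, out_features, batch_size). A hypothetical check (assuming layer_shape is reachable at the call site):

  type(layer) :: l
  l = linear2d(sequence_length=16, in_features=32, out_features=64, batch_size=8)
  print '(3i4)', l % layer_shape   ! 16 64 8 under the new ordering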

src/nf/nf_linear2d_layer.f90

Lines changed: 2 additions & 2 deletions
@@ -9,7 +9,7 @@ module nf_linear2d_layer
  public :: linear2d_layer

  type, extends(base_layer) :: linear2d_layer
-   integer :: batch_size, sequence_length, in_features, out_features
+   integer :: sequence_length, in_features, out_features, batch_size

    real, allocatable :: weights(:, :)
    real, allocatable :: biases(:)

@@ -32,7 +32,7 @@ module nf_linear2d_layer

  interface linear2d_layer
    module function linear2d_layer_cons(&
-     batch_size, sequence_length, in_features, out_features&
+     sequence_length, in_features, out_features, batch_size&
    ) result(res)
      integer, intent(in) :: batch_size, sequence_length, in_features, out_features
      type(linear2d_layer) :: res

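Only the argument list changes here; the declaration statement below it can keep the old name order because, in Fortran, the positional order of dummy arguments is fixed by the function statement, not by the order in which the names are declared. A minimal illustration:

  ! positional order is f(a, b) regardless of the declaration order inside
  function f(a, b) result(c)
    integer, intent(in) :: b, a   ! declared "out of order": still f(a, b)
    integer :: c
    c = 10*a + b
  end function f
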
src/nf/nf_linear2d_layer_submodule.f90

Lines changed: 8 additions & 8 deletions
@@ -3,7 +3,7 @@
  implicit none
contains
  module function linear2d_layer_cons(&
-   batch_size, sequence_length, in_features, out_features&
+   sequence_length, in_features, out_features, batch_size&
  ) result(res)
    integer, intent(in) :: batch_size, sequence_length, in_features, out_features
    type(linear2d_layer) :: res

@@ -18,8 +18,8 @@ module subroutine init(self, input_shape)
    class(linear2d_layer), intent(in out) :: self
    integer, intent(in) :: input_shape(:)

-   allocate(self % output(self % batch_size, self % sequence_length, self % out_features))
-   allocate(self % gradient(self % batch_size, self % sequence_length, self % in_features))
+   allocate(self % output(self % sequence_length, self % out_features, self % batch_size))
+   allocate(self % gradient(self % sequence_length, self % in_features, self % batch_size))

    allocate(self % weights(self % in_features, self % out_features))
    self % weights = 0.1

@@ -39,10 +39,10 @@ pure module subroutine forward(self, input)
    integer :: i, j

    do concurrent(i = 1: self % batch_size)
-     self % output(i, :, :) = matmul(input(i, :, :), self % weights)
+     self % output(:, :, i) = matmul(input(:, :, i), self % weights)
    end do
    do concurrent(i = 1: self % batch_size, j = 1: self % sequence_length)
-     self % output(i, j, :) = self % output(i, j, :) + self % biases
+     self % output(j, :, i) = self % output(j, :, i) + self % biases
    end do
  end subroutine forward

@@ -55,9 +55,9 @@ pure module subroutine backward(self, input, gradient)
    integer :: i

    do concurrent(i = 1: self % batch_size)
-     self % dw = self % dw + matmul(transpose(input(i, :, :)), gradient(i, :, :))
-     self % db = self % db + sum(gradient(i, :, :), 1)
-     self % gradient(i, :, :) = matmul(gradient(i, :, :), transpose(self % weights))
+     self % dw = self % dw + matmul(transpose(input(:, :, i)), gradient(:, :, i))
+     self % db = self % db + sum(gradient(:, :, i), 1)
+     self % gradient(:, :, i) = matmul(gradient(:, :, i), transpose(self % weights))
    end do
  end subroutine backward
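
The payoff is in forward and backward: matmul(input(:, :, i), self % weights) now consumes one contiguous slice per iteration instead of a strided one. A hypothetical micro-benchmark sketch of the two layouts (standalone; sizes and names invented for illustration, not taken from the commit):

program batch_dim_bench
  implicit none
  integer, parameter :: batch = 64, seq = 128, feat = 128
  real, allocatable :: a_first(:, :, :), a_last(:, :, :), w(:, :), out(:, :)
  real :: t0, t1, acc
  integer :: i

  allocate(a_first(batch, seq, feat), a_last(seq, feat, batch))
  allocate(w(feat, feat), out(seq, feat))
  call random_number(a_first); call random_number(a_last); call random_number(w)
  acc = 0

  call cpu_time(t0)
  do i = 1, batch
    out = matmul(a_first(i, :, :), w)   ! strided slice: usually copied to a temporary
    acc = acc + out(1, 1)
  end do
  call cpu_time(t1)
  print '("batch first: ", f8.4, " s")', t1 - t0

  call cpu_time(t0)
  do i = 1, batch
    out = matmul(a_last(:, :, i), w)    ! contiguous slice: read in place
    acc = acc + out(1, 1)
  end do
  call cpu_time(t1)
  print '("batch last:  ", f8.4, " s")', t1 - t0

  print '("checksum: ", f0.3)', acc     ! keep the compiler from eliding the loops
end program batch_dim_bench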
