
Commit ccc180e

layernorm: public api
1 parent: bdefd02

File tree

5 files changed: +78, -15 lines


src/nf.f90

Lines changed: 2 additions & 1 deletion
@@ -11,7 +11,8 @@ module nf
     linear2d, &
     maxpool2d, &
     reshape, &
-    self_attention
+    self_attention, &
+    layer_normalization
   use nf_loss, only: mse, quadratic
   use nf_metrics, only: corr, maxabs
   use nf_network, only: network
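
With layer_normalization re-exported here, user code can pull it straight from the nf module. The sketch below is illustrative only and is not part of this commit; the two-argument input() shape, the concrete layer sizes, and the network-level get_num_params() call are assumptions made for the example.

program layernorm_api_demo
  use nf, only: network, input, linear2d, self_attention, layer_normalization
  implicit none
  type(network) :: net

  ! layer_normalization may follow linear2d or self_attention
  ! (see the "Upstream layers permitted" checks in nf_layer_submodule.f90 below).
  net = network([ &
    input(12, 64), &            ! hypothetical (sequence_length, model_dimension)
    self_attention(4), &
    layer_normalization(), &
    linear2d(64), &
    layer_normalization() &
  ])

  print *, 'Trainable parameters:', net % get_num_params()
end program layernorm_api_demo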

src/nf/nf_layer_constructors.f90

Lines changed: 19 additions & 10 deletions
@@ -17,7 +17,8 @@ module nf_layer_constructors
     linear2d, &
     maxpool2d, &
     reshape, &
-    self_attention
+    self_attention, &
+    layer_normalization

   interface input

@@ -222,15 +223,23 @@ module function linear2d(out_features) result(res)
       !! Resulting layer instance
   end function linear2d

-  module function self_attention(num_heads) result(res)
-    !! Rank-2 (sequence_length, out_features) self attention constructor.
-    !! sequence_length and model_dimension are determined at layer initialization, based on the
-    !! output shape of the previous layer.
-    integer, intent(in) :: num_heads
-      !! Number of attention heads
-    type(layer) :: res
-      !! Resulting layer instance
-  end function self_attention
+  module function self_attention(num_heads) result(res)
+    !! Rank-2 (sequence_length, out_features) self attention constructor.
+    !! sequence_length and model_dimension are determined at layer initialization, based on the
+    !! output shape of the previous layer.
+    integer, intent(in) :: num_heads
+      !! Number of attention heads
+    type(layer) :: res
+      !! Resulting layer instance
+  end function self_attention
+
+  module function layer_normalization() result(res)
+    !! Layer Normalization
+    !! (x - mean(x)) / sqrt(variance(x) + eps) * gamma + beta
+    !! Based upon `Ba, Jimmy Lei, Jamie Ryan Kiros, and Geoffrey E. Hinton (2016)`:
+    !! https://arxiv.org/abs/1607.06450v1
+    type(layer) :: res
+  end function layer_normalization

   end interface

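For reference, here is the formula from the new constructor's docstring written out as code. This is a minimal, self-contained sketch of the math only, assuming per-feature-vector normalization with learned gamma and beta and a small eps; it is not the nf_layernorm_layer implementation, which is outside this commit.

module layernorm_math_sketch
  implicit none
contains

  ! (x - mean(x)) / sqrt(variance(x) + eps) * gamma + beta, applied to one feature vector
  pure function layernorm_row(x, gamma, beta, eps) result(y)
    real, intent(in) :: x(:)      ! one row of a (sequence_length, model_dimension) activation
    real, intent(in) :: gamma(:)  ! learned scale, same size as x
    real, intent(in) :: beta(:)   ! learned shift, same size as x
    real, intent(in) :: eps       ! small constant for numerical stability
    real :: y(size(x))
    real :: mu, var

    mu = sum(x) / size(x)
    var = sum((x - mu)**2) / size(x)
    y = (x - mu) / sqrt(var + eps) * gamma + beta
  end function layernorm_row

end module layernorm_math_sketch
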
src/nf/nf_layer_constructors_submodule.f90

Lines changed: 8 additions & 0 deletions
@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_layernorm_layer, only: layernorm_layer
   use nf_activation, only: activation_function, relu, sigmoid

   implicit none
@@ -179,4 +180,11 @@ module function self_attention(num_heads) result(res)
     allocate(res % p, source=self_attention_layer(num_heads))
   end function self_attention

+  module function layer_normalization() result(res)
+    type(layer) :: res
+
+    res % name = 'layer_normalization'
+    allocate(res % p, source=layernorm_layer())
+  end function layer_normalization
+
 end submodule nf_layer_constructors_submodule

src/nf/nf_layer_submodule.f90

Lines changed: 46 additions & 4 deletions
@@ -12,6 +12,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_layernorm_layer, only: layernorm_layer
   use nf_optimizers, only: optimizer_base_type

 contains
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
           call this_layer % backward(prev_layer % output, gradient)
         type is(self_attention_layer)
           call this_layer % backward(prev_layer % output, gradient)
+        type is(layernorm_layer)
+          call this_layer % backward(prev_layer % output, gradient)
       end select

   end select
@@ -84,6 +87,8 @@ pure module subroutine backward_2d(self, previous, gradient)
           call this_layer % backward(prev_layer % output, gradient)
         type is(self_attention_layer)
           call this_layer % backward(prev_layer % output, gradient)
+        type is(layernorm_layer)
+          call this_layer % backward(prev_layer % output, gradient)
       end select

     type is(self_attention_layer)
@@ -95,8 +100,18 @@ pure module subroutine backward_2d(self, previous, gradient)
           call this_layer % backward(prev_layer % output, gradient)
         type is(self_attention_layer)
           call this_layer % backward(prev_layer % output, gradient)
+        type is(layernorm_layer)
+          call this_layer % backward(prev_layer % output, gradient)
       end select

+    type is(layernorm_layer)
+
+      select type(prev_layer => previous % p)
+        type is(linear2d_layer)
+          call this_layer % backward(prev_layer % output, gradient)
+        type is(self_attention_layer)
+          call this_layer % backward(prev_layer % output, gradient)
+      end select
   end select

 end subroutine backward_2d
@@ -250,26 +265,40 @@ module subroutine forward(self, input)

     type is(linear2d_layer)

-      ! Upstream layers permitted: input2d, linear2d
+      ! Upstream layers permitted: input2d, linear2d, self_attention, layer_normalization
       select type(prev_layer => input % p)
         type is(input2d_layer)
           call this_layer % forward(prev_layer % output)
         type is(linear2d_layer)
           call this_layer % forward(prev_layer % output)
         type is(self_attention_layer)
           call this_layer % forward(prev_layer % output)
+        type is(layernorm_layer)
+          call this_layer % forward(prev_layer % output)
       end select

     type is(self_attention_layer)

-      ! Upstream layers permitted: input2d, linear2d
+      ! Upstream layers permitted: input2d, linear2d, self_attention, layer_normalization
       select type(prev_layer => input % p)
         type is(input2d_layer)
           call this_layer % forward(prev_layer % output)
         type is(linear2d_layer)
           call this_layer % forward(prev_layer % output)
         type is(self_attention_layer)
           call this_layer % forward(prev_layer % output)
+        type is(layernorm_layer)
+          call this_layer % forward(prev_layer % output)
+      end select
+
+    type is(layernorm_layer)
+
+      ! Upstream layers permitted: linear2d, self_attention
+      select type(prev_layer => input % p)
+        type is(linear2d_layer)
+          call this_layer % forward(prev_layer % output)
+        type is(self_attention_layer)
+          call this_layer % forward(prev_layer % output)
       end select

   end select
@@ -311,6 +340,8 @@ pure module subroutine get_output_2d(self, output)
         allocate(output, source=this_layer % output)
       type is(self_attention_layer)
         allocate(output, source=this_layer % output)
+      type is(layernorm_layer)
+        allocate(output, source=this_layer % output)
       class default
         error stop '2-d output can only be read from an input2d or linear2d layer.'

@@ -354,8 +385,8 @@ impure elemental module subroutine init(self, input)
       call this_layer % init(input % layer_shape)
   end select

-  ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
-  ! self_attention layers is not known until we receive an input layer.
+  ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
+  ! self_attention or layernorm layers is not known until we receive an input layer.
   select type(this_layer => self % p)
     type is(conv2d_layer)
       self % layer_shape = shape(this_layer % output)
@@ -367,6 +398,8 @@ impure elemental module subroutine init(self, input)
       self % layer_shape = shape(this_layer % output)
     type is(self_attention_layer)
       self % layer_shape = shape(this_layer % output)
+    type is(layernorm_layer)
+      self % layer_shape = shape(this_layer % output)
     type is(maxpool2d_layer)
       self % layer_shape = shape(this_layer % output)
   end select
@@ -425,6 +458,8 @@ elemental module function get_num_params(self) result(num_params)
       num_params = this_layer % get_num_params()
     type is (self_attention_layer)
       num_params = this_layer % get_num_params()
+    type is (layernorm_layer)
+      num_params = this_layer % get_num_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -458,6 +493,8 @@ module function get_params(self) result(params)
       params = this_layer % get_params()
     type is (self_attention_layer)
       params = this_layer % get_params()
+    type is (layernorm_layer)
+      params = this_layer % get_params()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -491,6 +528,8 @@ module function get_gradients(self) result(gradients)
       gradients = this_layer % get_gradients()
     type is (self_attention_layer)
       gradients = this_layer % get_gradients()
+    type is (layernorm_layer)
+      gradients = this_layer % get_gradients()
     class default
       error stop 'Unknown layer type.'
   end select
@@ -549,6 +588,9 @@ module subroutine set_params(self, params)
     type is (self_attention_layer)
       call this_layer % set_params(params)

+    type is (layernorm_layer)
+      call this_layer % set_params(params)
+
     type is (maxpool2d_layer)
       ! No parameters to set.
       write(stderr, '(a)') 'Warning: calling set_params() ' &

src/nf/nf_network_submodule.f90

Lines changed: 3 additions & 0 deletions
@@ -11,6 +11,7 @@
   use nf_reshape_layer, only: reshape3d_layer
   use nf_linear2d_layer, only: linear2d_layer
   use nf_self_attention_layer, only: self_attention_layer
+  use nf_layernorm_layer, only: layernorm_layer
   use nf_layer, only: layer
   use nf_layer_constructors, only: conv2d, dense, flatten, input, maxpool2d, reshape
   use nf_loss, only: quadratic
@@ -163,6 +164,8 @@ module subroutine backward(self, output, loss)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
           type is(self_attention_layer)
             call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
+          type is(layernorm_layer)
+            call self % layers(n) % backward(self % layers(n - 1), next_layer % gradient)
         end select
       end if
