1212 use nf_reshape_layer, only: reshape3d_layer
1313 use nf_linear2d_layer, only: linear2d_layer
1414 use nf_self_attention_layer, only: self_attention_layer
15+ use nf_layernorm_layer, only: layernorm_layer
1516 use nf_optimizers, only: optimizer_base_type
1617
1718contains
@@ -60,6 +61,8 @@ pure module subroutine backward_1d(self, previous, gradient)
6061 call this_layer % backward(prev_layer % output, gradient)
6162 type is (self_attention_layer)
6263 call this_layer % backward(prev_layer % output, gradient)
64+ type is (layernorm_layer)
65+ call this_layer % backward(prev_layer % output, gradient)
6366 end select
6467
6568 end select
@@ -84,6 +87,8 @@ pure module subroutine backward_2d(self, previous, gradient)
8487 call this_layer % backward(prev_layer % output, gradient)
8588 type is (self_attention_layer)
8689 call this_layer % backward(prev_layer % output, gradient)
90+ type is (layernorm_layer)
91+ call this_layer % backward(prev_layer % output, gradient)
8792 end select
8893
8994 type is (self_attention_layer)
@@ -95,8 +100,18 @@ pure module subroutine backward_2d(self, previous, gradient)
95100 call this_layer % backward(prev_layer % output, gradient)
96101 type is (self_attention_layer)
97102 call this_layer % backward(prev_layer % output, gradient)
103+ type is (layernorm_layer)
104+ call this_layer % backward(prev_layer % output, gradient)
98105 end select
99106
107+ type is (layernorm_layer)
108+
109+ select type (prev_layer = > previous % p)
110+ type is (linear2d_layer)
111+ call this_layer % backward(prev_layer % output, gradient)
112+ type is (self_attention_layer)
113+ call this_layer % backward(prev_layer % output, gradient)
114+ end select
100115 end select
101116
102117 end subroutine backward_2d
@@ -250,26 +265,40 @@ module subroutine forward(self, input)
250265
251266 type is (linear2d_layer)
252267
253- ! Upstream layers permitted: input2d, linear2d
268+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
254269 select type (prev_layer = > input % p)
255270 type is (input2d_layer)
256271 call this_layer % forward(prev_layer % output)
257272 type is (linear2d_layer)
258273 call this_layer % forward(prev_layer % output)
259274 type is (self_attention_layer)
260275 call this_layer % forward(prev_layer % output)
276+ type is (layernorm_layer)
277+ call this_layer % forward(prev_layer % output)
261278 end select
262279
263280 type is (self_attention_layer)
264281
265- ! Upstream layers permitted: input2d, linear2d
282+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
266283 select type (prev_layer = > input % p)
267284 type is (input2d_layer)
268285 call this_layer % forward(prev_layer % output)
269286 type is (linear2d_layer)
270287 call this_layer % forward(prev_layer % output)
271288 type is (self_attention_layer)
272289 call this_layer % forward(prev_layer % output)
290+ type is (layernorm_layer)
291+ call this_layer % forward(prev_layer % output)
292+ end select
293+
294+ type is (layernorm_layer)
295+
296+ ! Upstream layers permitted: linear2d, self_attention
297+ select type (prev_layer = > input % p)
298+ type is (linear2d_layer)
299+ call this_layer % forward(prev_layer % output)
300+ type is (self_attention_layer)
301+ call this_layer % forward(prev_layer % output)
273302 end select
274303
275304 end select
@@ -311,6 +340,8 @@ pure module subroutine get_output_2d(self, output)
311340 allocate (output, source= this_layer % output)
312341 type is (self_attention_layer)
313342 allocate (output, source= this_layer % output)
343+ type is (layernorm_layer)
344+ allocate (output, source= this_layer % output)
314345 class default
315346 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
316347
@@ -354,8 +385,8 @@ impure elemental module subroutine init(self, input)
354385 call this_layer % init(input % layer_shape)
355386 end select
356387
357- ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
358- ! self_attention layers is not known until we receive an input layer.
388+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
389+ ! self_attention, or layernorm layers is not known until we receive an input layer.
359390 select type (this_layer = > self % p)
360391 type is (conv2d_layer)
361392 self % layer_shape = shape (this_layer % output)
@@ -367,6 +398,8 @@ impure elemental module subroutine init(self, input)
367398 self % layer_shape = shape (this_layer % output)
368399 type is (self_attention_layer)
369400 self % layer_shape = shape (this_layer % output)
401+ type is (layernorm_layer)
402+ self % layer_shape = shape (this_layer % output)
370403 type is (maxpool2d_layer)
371404 self % layer_shape = shape (this_layer % output)
372405 end select
@@ -425,6 +458,8 @@ elemental module function get_num_params(self) result(num_params)
425458 num_params = this_layer % get_num_params()
426459 type is (self_attention_layer)
427460 num_params = this_layer % get_num_params()
461+ type is (layernorm_layer)
462+ num_params = this_layer % get_num_params()
428463 class default
429464 error stop ' Unknown layer type.'
430465 end select
@@ -458,6 +493,8 @@ module function get_params(self) result(params)
458493 params = this_layer % get_params()
459494 type is (self_attention_layer)
460495 params = this_layer % get_params()
496+ type is (layernorm_layer)
497+ params = this_layer % get_params()
461498 class default
462499 error stop ' Unknown layer type.'
463500 end select
@@ -491,6 +528,8 @@ module function get_gradients(self) result(gradients)
491528 gradients = this_layer % get_gradients()
492529 type is (self_attention_layer)
493530 gradients = this_layer % get_gradients()
531+ type is (layernorm_layer)
532+ gradients = this_layer % get_gradients()
494533 class default
495534 error stop ' Unknown layer type.'
496535 end select
@@ -549,6 +588,9 @@ module subroutine set_params(self, params)
549588 type is (self_attention_layer)
550589 call this_layer % set_params(params)
551590
591+ type is (layernorm_layer)
592+ call this_layer % set_params(params)
593+
552594 type is (maxpool2d_layer)
553595 ! No parameters to set.
554596 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments