1313 use nf_linear2d_layer, only: linear2d_layer
1414 use nf_self_attention_layer, only: self_attention_layer
1515 use nf_embedding_layer, only: embedding_layer
16+ use nf_layernorm_layer, only: layernorm_layer
1617 use nf_optimizers, only: optimizer_base_type
1718
1819contains
@@ -47,7 +48,7 @@ pure module subroutine backward_1d(self, previous, gradient)
4748
4849 type is (flatten_layer)
4950
50- ! Upstream layers permitted: input2d, input3d, conv2d, maxpool2d
51+ ! Upstream layers permitted: input2d, input3d, conv2d, layernorm, maxpool2d
5152 select type (prev_layer = > previous % p)
5253 type is (input2d_layer)
5354 call this_layer % backward(prev_layer % output, gradient)
@@ -63,6 +64,8 @@ pure module subroutine backward_1d(self, previous, gradient)
6364 call this_layer % backward(prev_layer % output, gradient)
6465 type is (embedding_layer)
6566 call this_layer % backward(prev_layer % output, gradient)
67+ type is (layernorm_layer)
68+ call this_layer % backward(prev_layer % output, gradient)
6669 end select
6770
6871 end select
@@ -89,6 +92,8 @@ pure module subroutine backward_2d(self, previous, gradient)
8992 call this_layer % backward(prev_layer % output, gradient)
9093 type is (self_attention_layer)
9194 call this_layer % backward(prev_layer % output, gradient)
95+ type is (layernorm_layer)
96+ call this_layer % backward(prev_layer % output, gradient)
9297 end select
9398
9499 type is (self_attention_layer)
@@ -102,8 +107,18 @@ pure module subroutine backward_2d(self, previous, gradient)
102107 call this_layer % backward(prev_layer % output, gradient)
103108 type is (self_attention_layer)
104109 call this_layer % backward(prev_layer % output, gradient)
110+ type is (layernorm_layer)
111+ call this_layer % backward(prev_layer % output, gradient)
105112 end select
106113
114+ type is (layernorm_layer)
115+
116+ select type (prev_layer = > previous % p)
117+ type is (linear2d_layer)
118+ call this_layer % backward(prev_layer % output, gradient)
119+ type is (self_attention_layer)
120+ call this_layer % backward(prev_layer % output, gradient)
121+ end select
107122 end select
108123
109124 end subroutine backward_2d
@@ -241,6 +256,8 @@ module subroutine forward(self, input)
241256 call this_layer % forward(prev_layer % output)
242257 type is (linear2d_layer)
243258 call this_layer % forward(prev_layer % output)
259+ type is (layernorm_layer)
260+ call this_layer % forward(prev_layer % output)
244261 end select
245262
246263 type is (reshape3d_layer)
@@ -257,7 +274,7 @@ module subroutine forward(self, input)
257274
258275 type is (linear2d_layer)
259276
260- ! Upstream layers permitted: input2d, linear2d
277+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
261278 select type (prev_layer = > input % p)
262279 type is (input2d_layer)
263280 call this_layer % forward(prev_layer % output)
@@ -267,11 +284,13 @@ module subroutine forward(self, input)
267284 call this_layer % forward(prev_layer % output)
268285 type is (self_attention_layer)
269286 call this_layer % forward(prev_layer % output)
287+ type is (layernorm_layer)
288+ call this_layer % forward(prev_layer % output)
270289 end select
271290
272291 type is (self_attention_layer)
273292
274- ! Upstream layers permitted: input2d, linear2d
293+ ! Upstream layers permitted: input2d, linear2d, self_attention, layernorm
275294 select type (prev_layer = > input % p)
276295 type is (input2d_layer)
277296 call this_layer % forward(prev_layer % output)
@@ -281,6 +300,18 @@ module subroutine forward(self, input)
281300 call this_layer % forward(prev_layer % output)
282301 type is (self_attention_layer)
283302 call this_layer % forward(prev_layer % output)
303+ type is (layernorm_layer)
304+ call this_layer % forward(prev_layer % output)
305+ end select
306+
307+ type is (layernorm_layer)
308+
309+ ! Upstream layers permitted: linear2d, self_attention
310+ select type (prev_layer = > input % p)
311+ type is (linear2d_layer)
312+ call this_layer % forward(prev_layer % output)
313+ type is (self_attention_layer)
314+ call this_layer % forward(prev_layer % output)
284315 end select
285316
286317 end select
@@ -324,6 +355,8 @@ pure module subroutine get_output_2d(self, output)
324355 allocate (output, source= this_layer % output)
325356 type is (self_attention_layer)
326357 allocate (output, source= this_layer % output)
358+ type is (layernorm_layer)
359+ allocate (output, source= this_layer % output)
327360 class default
328361 error stop ' 2-d output can only be read from an input2d or linear2d layer.'
329362
@@ -367,8 +400,8 @@ impure elemental module subroutine init(self, input)
367400 call this_layer % init(input % layer_shape)
368401 end select
369402
370- ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d, or
371- ! self_attention layers is not known until we receive an input layer.
403+ ! The shape of conv2d, dropout, flatten, linear2d, maxpool2d,
404+ ! self_attention or layernorm layers is not known until we receive an input layer.
372405 select type (this_layer = > self % p)
373406 type is (conv2d_layer)
374407 self % layer_shape = shape (this_layer % output)
@@ -380,6 +413,8 @@ impure elemental module subroutine init(self, input)
380413 self % layer_shape = shape (this_layer % output)
381414 type is (self_attention_layer)
382415 self % layer_shape = shape (this_layer % output)
416+ type is (layernorm_layer)
417+ self % layer_shape = shape (this_layer % output)
383418 type is (maxpool2d_layer)
384419 self % layer_shape = shape (this_layer % output)
385420 end select
@@ -440,6 +475,8 @@ elemental module function get_num_params(self) result(num_params)
440475 num_params = this_layer % get_num_params()
441476 type is (embedding_layer)
442477 num_params = this_layer % get_num_params()
478+ type is (layernorm_layer)
479+ num_params = this_layer % get_num_params()
443480 class default
444481 error stop ' Unknown layer type.'
445482 end select
@@ -475,6 +512,8 @@ module function get_params(self) result(params)
475512 params = this_layer % get_params()
476513 type is (embedding_layer)
477514 params = this_layer % get_params()
515+ type is (layernorm_layer)
516+ params = this_layer % get_params()
478517 class default
479518 error stop ' Unknown layer type.'
480519 end select
@@ -510,6 +549,8 @@ module function get_gradients(self) result(gradients)
510549 gradients = this_layer % get_gradients()
511550 type is (embedding_layer)
512551 gradients = this_layer % get_gradients()
552+ type is (layernorm_layer)
553+ gradients = this_layer % get_gradients()
513554 class default
514555 error stop ' Unknown layer type.'
515556 end select
@@ -570,6 +611,9 @@ module subroutine set_params(self, params)
570611 type is (embedding_layer)
571612 call this_layer % set_params(params)
572613
614+ type is (layernorm_layer)
615+ call this_layer % set_params(params)
616+
573617 type is (maxpool2d_layer)
574618 ! No parameters to set.
575619 write (stderr, ' (a)' ) ' Warning: calling set_params() ' &
0 commit comments