-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathL2_Normalization.py
61 lines (55 loc) · 2.33 KB
/
L2_Normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
'''
A custom Keras layer to perform L2-normalization.
Copyright (C) 2018 Pierluigi Ferrari
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
'''
from __future__ import division
import numpy as np
import keras.backend as K
from tensorflow.keras.layers import InputSpec, Layer
class L2Normalization(Layer):
    '''
    Performs L2 normalization on the input tensor with a learnable scaling parameter
    as described in the paper "Parsenet: Looking Wider to See Better" (see references)
    and as used in the original SSD model.

    The input is L2-normalized along its channel axis and then multiplied by a
    per-channel scale `gamma`, which is initialized to `gamma_init` for every channel.

    Arguments:
        gamma_init (int): The initial scaling parameter. Defaults to 20 following the
            SSD paper.

    Input shape:
        4D tensor of shape `(batch, channels, height, width)` if the backend image
        data format is `'channels_first'`, or `(batch, height, width, channels)` if
        it is `'channels_last'`.

    Returns:
        The scaled tensor. Same shape as the input tensor.

    References:
        http://cs.unc.edu/~wliu/papers/parsenet.pdf
    '''

    def __init__(self, gamma_init=20, **kwargs):
        # `K.image_data_format()` reports 'channels_last' or 'channels_first'.
        # Comparing against the legacy string 'tf' never matches, which would pin
        # the channel axis to 1 regardless of the configured data format.
        if K.image_data_format() == 'channels_last':
            self.axis = 3
        else:
            self.axis = 1
        self.gamma_init = gamma_init
        super(L2Normalization, self).__init__(**kwargs)

    def build(self, input_shape):
        '''
        Creates the per-channel scale variable `gamma`, with one entry for each
        element of `input_shape` along the channel axis.
        '''
        self.input_spec = [InputSpec(shape=input_shape)]
        gamma = self.gamma_init * np.ones((input_shape[self.axis],))
        # NOTE(review): `gamma` is created with `K.variable` and is not registered
        # through `add_weight`; whether it ends up in the layer's trainable weights
        # depends on the Keras version's attribute tracking - confirm.
        self.gamma = K.variable(gamma, name='{}_gamma'.format(self.name))
        super(L2Normalization, self).build(input_shape)

    def call(self, x, mask=None):
        '''
        L2-normalizes `x` along the channel axis and applies the per-channel scale.
        `mask` is accepted but not used.
        '''
        output = K.l2_normalize(x, self.axis)
        gamma = self.gamma
        if self.axis == 1:
            # `gamma` is 1-D with one entry per channel. A plain multiply would
            # broadcast it against the last axis (width) of a channels-first
            # tensor, so reshape it to line up with axis 1 instead.
            gamma = K.reshape(gamma, (1, -1, 1, 1))
        return output * gamma

    def get_config(self):
        '''
        Returns the base layer config extended with `gamma_init`, so the layer can
        be re-created with the same initial scale.
        '''
        config = {
            'gamma_init': self.gamma_init
        }
        base_config = super(L2Normalization, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))