-
Notifications
You must be signed in to change notification settings - Fork 0
/
layer_normalization.py
101 lines (92 loc) · 4.15 KB
/
layer_normalization.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
import keras
from keras import backend as K
"""
Copied from https://github.com/CyberZHG/keras-layer-normalization/blob/master/keras_layer_normalization/layer_normalization.py
"""
class LayerNormalization(keras.layers.Layer):
    """Layer normalization over the last axis of the input.

    Every slice along the last axis is shifted to zero mean and divided by its
    standard deviation, then optionally multiplied by a trainable ``gamma`` and
    offset by a trainable ``beta``. Both weights have the shape of the last
    input dimension. Output shape and mask are the same as the input's.

    See: [Layer Normalization](https://arxiv.org/pdf/1607.06450.pdf)
    """

    def __init__(self,
                 center=True,
                 scale=True,
                 epsilon=None,
                 gamma_initializer='ones',
                 beta_initializer='zeros',
                 gamma_regularizer=None,
                 beta_regularizer=None,
                 gamma_constraint=None,
                 beta_constraint=None,
                 **kwargs):
        """Configure the layer; weights are created later in ``build``.

        :param center: Add an offset parameter (``beta``) if it is True.
        :param scale: Add a scale parameter (``gamma``) if it is True.
        :param epsilon: Small constant added to the variance before the square
            root is taken. When None, ``K.epsilon() * K.epsilon()`` is used.
        :param gamma_initializer: Initializer for the gamma weight.
        :param beta_initializer: Initializer for the beta weight.
        :param gamma_regularizer: Optional regularizer for the gamma weight.
        :param beta_regularizer: Optional regularizer for the beta weight.
        :param gamma_constraint: Optional constraint for the gamma weight.
        :param beta_constraint: Optional constraint for the beta weight.
        :param kwargs: Forwarded unchanged to the base ``Layer`` constructor.
        """
        super(LayerNormalization, self).__init__(**kwargs)
        # The incoming mask is passed through untouched (see compute_mask).
        self.supports_masking = True
        self.center = center
        self.scale = scale
        if epsilon is None:
            # Default: the backend fuzz factor squared. Epsilon is added to
            # the variance (a squared quantity) before the square root.
            # NOTE(review): a value this small may underflow in reduced
            # precision; pass an explicit epsilon if that applies.
            epsilon = K.epsilon() * K.epsilon()
        self.epsilon = epsilon
        # Resolve identifiers (strings, dicts, callables) into Keras objects.
        self.gamma_initializer = keras.initializers.get(gamma_initializer)
        self.beta_initializer = keras.initializers.get(beta_initializer)
        self.gamma_regularizer = keras.regularizers.get(gamma_regularizer)
        self.beta_regularizer = keras.regularizers.get(beta_regularizer)
        self.gamma_constraint = keras.constraints.get(gamma_constraint)
        self.beta_constraint = keras.constraints.get(beta_constraint)
        # Placeholders; the weights are only created in build() when enabled.
        self.gamma, self.beta = None, None

    def get_config(self):
        """Return the layer configuration as a dict.

        The base layer's config is extended with this layer's constructor
        arguments; initializers, regularizers and constraints are stored in
        their serialized form.
        """
        config = {
            'center': self.center,
            'scale': self.scale,
            'epsilon': self.epsilon,
            'gamma_initializer': keras.initializers.serialize(self.gamma_initializer),
            'beta_initializer': keras.initializers.serialize(self.beta_initializer),
            'gamma_regularizer': keras.regularizers.serialize(self.gamma_regularizer),
            'beta_regularizer': keras.regularizers.serialize(self.beta_regularizer),
            'gamma_constraint': keras.constraints.serialize(self.gamma_constraint),
            'beta_constraint': keras.constraints.serialize(self.beta_constraint),
        }
        # Base entries first; this layer's entries win on a key collision.
        merged = dict(super(LayerNormalization, self).get_config())
        merged.update(config)
        return merged

    def compute_output_shape(self, input_shape):
        """Return ``input_shape`` unchanged."""
        return input_shape

    def compute_mask(self, inputs, input_mask=None):
        """Return the incoming mask unchanged."""
        return input_mask

    def build(self, input_shape):
        """Create the ``gamma`` / ``beta`` weights for the given input shape.

        :param input_shape: Shape of the input; only its last entry determines
            the shape of the weights.
        """
        # Use the public `keras.layers.InputSpec` export: the `keras.engine`
        # attribute path is internal and is not exposed by every Keras release.
        self.input_spec = keras.layers.InputSpec(shape=input_shape)
        # One-element shape holding the size of the last axis.
        shape = input_shape[-1:]
        if self.scale:
            self.gamma = self.add_weight(
                shape=shape,
                initializer=self.gamma_initializer,
                regularizer=self.gamma_regularizer,
                constraint=self.gamma_constraint,
                name='gamma',
            )
        if self.center:
            self.beta = self.add_weight(
                shape=shape,
                initializer=self.beta_initializer,
                regularizer=self.beta_regularizer,
                constraint=self.beta_constraint,
                name='beta',
            )
        super(LayerNormalization, self).build(input_shape)

    def call(self, inputs, training=None):
        """Normalize ``inputs`` along the last axis.

        :param inputs: Input tensor.
        :param training: Not used; the computation does not depend on it.
        :return: Tensor of the same shape as ``inputs``.
        """
        mean = K.mean(inputs, axis=-1, keepdims=True)
        # Centre once and reuse for both the variance and the output.
        centered = inputs - mean
        variance = K.mean(K.square(centered), axis=-1, keepdims=True)
        # Epsilon keeps the divisor non-zero when a slice is constant.
        std = K.sqrt(variance + self.epsilon)
        outputs = centered / std
        if self.scale:
            outputs = outputs * self.gamma
        if self.center:
            outputs = outputs + self.beta
        return outputs