noisy_gradient.py
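
NoisySGD extends keras.optimizers.SGD by adding zero-mean Gaussian noise to every gradient. The noise variance is noise_eta / (1 + t) ** noise_gamma, where t is the iteration count, so the perturbation shrinks as training progresses.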
import tensorflow.keras as keras
from tensorflow.keras import backend as K


class NoisySGD(keras.optimizers.SGD):
    """SGD whose gradients are perturbed with annealed Gaussian noise.

    At iteration t, zero-mean noise with variance
    noise_eta / (1 + t) ** noise_gamma is added to each gradient, so the
    perturbation decays as training progresses.
    """

    def __init__(self, noise_eta=1, noise_gamma=4, **kwargs):
        super(NoisySGD, self).__init__(**kwargs)
        with K.name_scope(self.__class__.__name__):
            self.noise_eta = K.variable(noise_eta, name='noise_eta')
            self.noise_gamma = K.variable(noise_gamma, name='noise_gamma')

    def get_gradients(self, loss, params):
        grads = super(NoisySGD, self).get_gradients(loss, params)
        # Add Gaussian noise whose variance decays with the iteration count.
        t = K.cast(self.iterations, K.dtype(grads[0]))
        variance = self.noise_eta / ((1 + t) ** self.noise_gamma)
        grads = [
            grad + K.random_normal(
                grad.shape,
                mean=0.0,
                stddev=K.sqrt(variance),
                dtype=K.dtype(grads[0])
            )
            for grad in grads
        ]
        return grads

    def get_config(self):
        # Serialize the noise hyperparameters alongside the base SGD config.
        config = {'noise_eta': float(K.get_value(self.noise_eta)),
                  'noise_gamma': float(K.get_value(self.noise_gamma))}
        base_config = super(NoisySGD, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))
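
A minimal usage sketch, not part of the original file: the toy model, random data, and hyperparameter values below are made up for illustration. Note that get_gradients belongs to the older Keras optimizer interface; it is invoked by the legacy optimizer.get_updates training path, while newer TF2 releases (roughly 2.2+) apply gradients through a GradientTape in Model.fit and may bypass this hook.

import numpy as np
import tensorflow.keras as keras

# Hypothetical toy regression model; any Keras model is compiled the same way.
model = keras.Sequential([
    keras.layers.Dense(32, activation='relu', input_shape=(20,)),
    keras.layers.Dense(1),
])

# noise_eta and noise_gamma control the initial noise level and its decay rate;
# remaining keyword arguments (lr, momentum, ...) are passed through to SGD.
model.compile(
    optimizer=NoisySGD(noise_eta=1.0, noise_gamma=4.0, lr=0.01, momentum=0.9),
    loss='mse',
)

x = np.random.rand(256, 20)
y = np.random.rand(256, 1)
model.fit(x, y, epochs=2, batch_size=32)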