-
Notifications
You must be signed in to change notification settings - Fork 8
/
grda_plaidml.py
43 lines (38 loc) · 1.7 KB
/
grda_plaidml.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
from keras.optimizers import Optimizer
from keras.legacy import interfaces
from keras import backend as K
class GRDA(Optimizer):
    """GRDA (generalized regularized dual averaging) optimizer.

    Every parameter gets an accumulator that is initialised with the
    parameter's value at the time `get_updates` is called and then follows
    plain gradient steps (`a <- a - lr * g`). The parameter itself is
    overwritten on each step with the soft-thresholded accumulator, using
    the threshold

        c * lr ** (0.5 + mu) * (iterations + 1) ** mu

    NOTE(review): `K.softthreshold` is not part of the stock Keras backend
    API; presumably it is provided by the PlaidML backend build this file
    targets -- confirm before using with another backend.

    # Arguments
        lr: float. Learning rate (step size of the accumulator update).
        c: float, must be non-zero. Multiplier of the threshold.
        mu: float. Exponent controlling how the threshold scales with the
            learning rate and the iteration count.
        **kwargs: forwarded unchanged to `Optimizer.__init__`.

    # Raises
        ValueError: if `c` equals 0.
    """

    def __init__(self, lr=0.01, c=0., mu=0.7, **kwargs):
        super(GRDA, self).__init__(**kwargs)
        # Validate before touching the backend so a rejected configuration
        # does not leave orphaned variables behind.
        if c == 0:
            raise ValueError("c = 0 is equivalent to SGD. Please use SGD.")
        with K.name_scope(self.__class__.__name__):
            self.iterations = K.variable(0, dtype='int64', name='iterations')
            self.lr = K.variable(lr, name='lr')
            self.mu = K.variable(mu, name='mu')
            self.c = K.variable(c, name='c')
        print("lr = ", lr, ", c=", c, ", mu = ", mu)

    @interfaces.legacy_get_updates_support
    def get_updates(self, loss, params):
        """Build the list of update ops for one training step.

        # Arguments
            loss: loss tensor to differentiate.
            params: list of trainable variables.

        # Returns
            The list of update ops (also stored in `self.updates`).
        """
        grads = self.get_gradients(loss, params)
        # Accumulators start from the current parameter values and share
        # each parameter's dtype, so no implicit casts are needed when the
        # backend float type is not float32.
        accumulators = [
            K.variable(value=K.get_value(p), dtype=K.dtype(p))
            for p in params
        ]
        # Register the optimizer state so that get_weights()/set_weights()
        # (and therefore saving a model together with its optimizer) cover
        # the accumulators. The thresholded parameters alone cannot be
        # used to reconstruct them.
        self.weights = [self.iterations] + accumulators
        self.updates = [K.update_add(self.iterations, 1)]

        lr = self.lr
        mu = self.mu
        c = self.c
        # The threshold is shared by all parameters, so it is built once
        # outside the loop.
        step = K.cast(self.iterations, K.floatx()) + 1
        l1 = c * K.pow(lr, 0.5 + mu) * K.pow(step, mu)

        for p, g, a in zip(params, grads, accumulators):
            new_a = a - lr * g
            self.updates.append(K.update(a, new_a))
            # The parameter is the soft-thresholded accumulator.
            new_p = K.softthreshold(new_a, l1)
            self.updates.append(K.update(p, new_p))
        return self.updates

    def get_config(self):
        """Return the optimizer configuration as a plain dict.

        The base-class configuration is extended with the current values of
        `lr`, `mu` and `c`; on a key clash the values from this class win.
        """
        config = {
            'lr': float(K.get_value(self.lr)),
            'mu': float(K.get_value(self.mu)),
            'c': float(K.get_value(self.c)),
        }
        base_config = super(GRDA, self).get_config()
        return dict(list(base_config.items()) + list(config.items()))