/
shampoo_optimizer.py
398 lines (346 loc) · 18.5 KB
/
shampoo_optimizer.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
"""The Shampoo Optimizer.
For details, see https://arxiv.org/abs/1802.09568
"""
from matrix_square_root_power import *
def GetParam(var, timestep):
  """Resolve a possibly time-varying hyperparameter.

  Args:
    var: either a constant value, or a callable taking the current step
      and returning the value for that step.
    timestep: the current (global) step, passed to `var` when callable.

  Returns:
    The parameter value to use at `timestep`.
  """
  return var(timestep) if callable(var) else var
class ShampooOptimizer(optimizer.Optimizer):
  """The Shampoo Optimizer
  Variant of Adagrad using one preconditioner matrix per variable dimension.
  For details, see https://arxiv.org/abs/1802.09568
  gbar is time-weighted accumulated gradient:
  gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
  mat_gbar is time-weighted accumulated gradient square:
  mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                  + mat_gbar_weight[t] * gg_j[t]
  where if g[t] = g_abcd then gg_a[t] = g_abcd g_a'bcd (Einstein notation)
  Update rule:
  w[t+1] = w[t] - learning_rate[t] * Prod_j mat_gbar_j[t]^(-alpha/n) gbar[t]
  Again, mat_gbar_j[t]^(-alpha) gbar[t] is a tensor contraction along the
  j'th dimension of gbar[t] with the first dimension of
  mat_gbar_j[t]^(-alpha/n), where alpha is a hyperparameter,
  and n = rank of the variable.
  Prod_j represents doing this contraction for all j in 0..n-1.
  Typically learning_rate is constant, but could be time dependent by passing
  a lambda function that depends on step.
  """

  def __init__(self,
               global_step=0,
               max_matrix_size=768,
               gbar_decay=0.0,
               gbar_weight=1.0,
               mat_gbar_decay=1.0,
               mat_gbar_weight=1.0,
               learning_rate=1.0,
               svd_interval=1,
               precond_update_interval=1,
               epsilon=1e-4,
               alpha=0.5,
               use_iterative_root=False,
               use_locking=False,
               name="Shampoo"):
    """Default values of the various hyper-parameters.
    gbar_decay, gbar_weight etc. can be a float or a time varying parameter.
    For time-varying parameters use e.g. "lambda T: T / (T + 1.0)"
    where the expression in the lambda is a tensorflow expression
    Args:
      global_step: tensorflow variable indicating the step.
      max_matrix_size: We do not perform SVD for matrices larger than this.
      gbar_decay:
      gbar_weight: Used to update gbar:
        gbar[t] = gbar_decay[t] * gbar[t-1] + gbar_weight[t] * g[t]
      mat_gbar_decay:
      mat_gbar_weight: Used to update mat_gbar:
        mat_gbar_j[t] = mat_gbar_decay[t] * mat_gbar_j[t-1]
                        + mat_gbar_weight[t] * gg_j[t]
      learning_rate: Similar to SGD
      svd_interval: We should do SVD after this many steps. Default = 1, i.e.
                    every step. Usually 20 leads to no loss of accuracy, and
                    50 or 100 is also OK. May also want more often early,
                    and less often later - set in caller as for example:
                    "svd_interval = lambda(T): tf.cond(
                        T < 2000, lambda: 20.0, lambda: 1000.0)"
      precond_update_interval: We should update the preconditioners after
                               this many steps. Default = 1. Usually less than
                               svd_interval.
      epsilon: epsilon * I_n is added to each mat_gbar_j for stability for
               non-diagonal version of shampoo.
      alpha: total power of the preconditioners.
      use_iterative_root: should the optimizer use SVD (faster) or the
                          iterative root method (for TPU) for finding the
                          roots of PSD matrices.
      use_locking:
      name: name of optimizer.
    """
    super(ShampooOptimizer, self).__init__(use_locking, name)
    # Cast once so step arithmetic and GetParam schedules work in float32.
    self._global_step = math_ops.cast(global_step, dtypes.float32)
    self._max_matrix_size = max_matrix_size
    self._gbar_decay = gbar_decay
    self._gbar_weight = gbar_weight
    self._mat_gbar_decay = mat_gbar_decay
    self._mat_gbar_weight = mat_gbar_weight
    self._learning_rate = learning_rate
    self._svd_interval = svd_interval
    self._precond_update_interval = precond_update_interval
    self._epsilon = epsilon
    self._alpha = alpha
    self._use_iterative_root = use_iterative_root
    self._name = name

  def _create_slots(self, var_list):
    """Creates optimizer slots for each variable.

    Per variable v this makes: a "gbar" zeros slot (momentum accumulator);
    for each dimension i of v, a "Gbar_i" statistics slot — a d_i x d_i zero
    matrix when d_i <= max_matrix_size, else a length-d_i zero vector
    (diagonal Shampoo); and, when SVDs are amortized (svd_interval > 1), an
    "H_i" slot initialized to the identity to cache the preconditioner root.
    """
    for v in var_list:
      with ops.colocate_with(v):
        _ = self._zeros_slot(v, "gbar", self._name)
        shape = np.array(v.get_shape())
        for i, d in enumerate(shape):
          d_tensor = ops.convert_to_tensor(d)
          if d <= self._max_matrix_size:
            mat_g_init = array_ops.zeros_like(linalg_ops.eye(d_tensor))
            # NOTE(review): this comparison assumes svd_interval is a plain
            # number; a callable (time-varying) schedule would make
            # `> 1` raise on Python 3 — confirm callers only pass numbers
            # here, or resolve via GetParam first.
            if self._svd_interval > 1:
              _ = self._get_or_make_slot(v, linalg_ops.eye(d_tensor),
                                         "H_" + str(i), self._name)
          else:
            mat_g_init = array_ops.zeros([d_tensor])
          _ = self._get_or_make_slot(v, mat_g_init, "Gbar_" + str(i),
                                     self._name)

  def _resource_apply_dense(self, grad, var):
    """Dense update for resource variables; same path as _apply_dense."""
    return self._apply_dense(grad, var)

  def _apply_dense(self, grad, var):
    """Dense update: full Shampoo preconditioning on the whole gradient."""
    return self._apply_gradient(grad, var)

  def _resource_apply_sparse(self, grad_values, var, grad_indices):
    """Sparse update for resource variables; same path as _apply_sparse."""
    return self._apply_sparse_shared(grad_values, grad_indices, var)

  def _apply_sparse(self, grad, var):
    """Sparse update: unpacks the IndexedSlices and delegates."""
    return self._apply_sparse_shared(grad.values, grad.indices, var)

  def _apply_sparse_shared(self, grad_values, grad_indices, var):
    """Routes a sparse gradient to a dense or truly-sparse update.

    If the first dimension is small enough for full Shampoo, or momentum is
    on (gbar_decay != 0, which requires touching all rows), the gradient is
    scattered into a dense tensor and the dense path is used. Otherwise the
    update stays sparse and only the given rows are touched.
    """
    if var.get_shape()[0] <= self._max_matrix_size or self._gbar_decay != 0.0:
      # The dimension is small enough, we can make the variable dense and
      # do a dense update
      dense_grad = array_ops.scatter_nd(
          array_ops.expand_dims(grad_indices, axis=1), grad_values,
          array_ops.shape(var, out_type=grad_indices.dtype))
      return self._apply_gradient(dense_grad, var)
    return self._apply_gradient(grad_values, var, grad_indices)

  def _weighted_average(self, var, weight, weight_t, rest):
    """Computes exponential weighted average: var = weight_t * var + rest.
    Important to ensure that var does not occur in rest, otherwise
    we can get race conditions in a distributed setting.
    Args:
      var: variable to be updated
      weight: parameter to be checked. If it is a constant, we can optimize.
      weight_t: current value of parameter, used for weighting
      rest: the remaining tensor to be added
    Returns:
      updated variable.
    """
    if weight == 0.0:
      return rest  # no need to update var, we will never use it.
    if weight == 1.0:  # common case
      return state_ops.assign_add(var, rest)
    # The op below can cause race conditions in a distributed setting,
    # since computing weight_t * var + rest can take some time, during
    # which var may be set by another worker. To prevent this, it should
    # be implemented as a C++ op.
    return var.assign_add((weight_t - 1) * var + rest)

  def _update_mat_g(self, mat_g, grad, axes, mat_gbar_decay,
                    mat_gbar_weight, i):
    """Updates the cumulative outer products of the gradients.
    Args:
      mat_g: the matrix to be updated
      grad: the gradient of the variable
      axes: a list of k-1 integers 0 to k-1, except i
      mat_gbar_decay: constant for weighted average:
          mat_g = mat_g * decay + grad * weight
      mat_gbar_weight: constant for weighted average
      i: index of dimension to be updated.
    Returns:
      updated mat_g = mat_g * mat_gbar_decay + grad_outer * mat_gbar_weight
      In Einstein notation if i = 0: grad_outer_aa'= g_abcd g_a'bcd
      thus grad_outer is a matrix d_i x d_i, where d_i is the size of the
      i'th dimension of g.
      Alternate view: If mat_i(grad) is the flattening of grad to a
      d_i x (d_1d_2...d_{i-1}d_{i+1}...d_k) matrix, then
      grad_outer = mat_i(grad) mat_i(grad).transpose
    """
    grad_outer = math_ops.tensordot(grad, grad, axes=(axes, axes),
                                    name="grad_outer_" + str(i))
    return self._weighted_average(mat_g, self._mat_gbar_decay, mat_gbar_decay,
                                  mat_gbar_weight * grad_outer)

  def _compute_power_svd(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name):
    """Computes mat_h = mat_g^alpha using svd. mat_g is a symmetric PSD matrix.
    Args:
      var: the variable we are updating.
      mat_g: the symmetric PSD matrix whose power it to be computed
      mat_g_size: size of mat_g
      alpha: a real number
      mat_h_slot_name: name of slot to store the power, if needed.
    Returns:
      mat_h = mat_g^alpha
    Stores mat_h in the appropriate slot, if it exists.
    Note that mat_g is PSD. So we could use linalg_ops.self_adjoint_eig.
    """
    if mat_g_size == 1:
      # Scalar case: the "matrix power" is just a scalar power, with epsilon
      # damping for stability.
      mat_h = math_ops.pow(mat_g + self._epsilon, alpha)
    else:
      damping = self._epsilon * linalg_ops.eye(
          math_ops.cast(mat_g_size, dtypes.int32))
      diag_d, mat_u, mat_v = linalg_ops.svd(mat_g + damping, full_matrices=True)
      # V * D^alpha * U^T; singular values are floored at epsilon so that
      # negative alpha cannot blow up on (near-)zero modes.
      mat_h = math_ops.matmul(
          mat_v * math_ops.pow(math_ops.maximum(diag_d, self._epsilon), alpha),
          array_ops.transpose(mat_u))
    if mat_h_slot_name is not None:
      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
    return mat_h

  def _compute_power_iter(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name,
                          iter_count=100, epsilon=1e-6):
    """Computes mat_g^alpha, where alpha = -1/p, p a positive integer."""
    # First take the square root, then the (2*p)'th inverse root of that:
    # (G^(1/2))^(2*alpha) = G^alpha. This iterative path avoids SVD (for TPU).
    mat_g_sqrt = matrix_functions.matrix_square_root(mat_g, mat_g_size,
                                                     iter_count, self._epsilon)
    mat_h = matrix_functions.matrix_inverse_pth_root(
        mat_g_sqrt,
        mat_g_size,
        2 * alpha,
        iter_count,
        epsilon,
        ridge_epsilon=0.0)
    if mat_h_slot_name is not None:
      return state_ops.assign(self.get_slot(var, mat_h_slot_name), mat_h)
    return mat_h

  def _compute_power(self, var, mat_g, mat_g_size, alpha, mat_h_slot_name=None):
    """Just a switch between the iterative power vs svd."""
    with ops.name_scope("matrix_iterative_power"):
      if self._use_iterative_root:
        return self._compute_power_iter(var, mat_g, mat_g_size, alpha,
                                        mat_h_slot_name)
      else:
        return self._compute_power_svd(var, mat_g, mat_g_size, alpha,
                                       mat_h_slot_name)

  def _apply_gradient(self, grad, var, indices=None):
    """The main function to update a variable.
    Args:
      grad: A Tensor containing gradient to apply.
      var: A Tensor containing the variable to update.
      indices: An array of integers, for sparse update.
    Returns:
      Updated variable var = var - learning_rate * preconditioner * grad
    If the gradient is dense, var and grad have the same shape.
    If the update is sparse, then the first dimension of the gradient and var
    may differ, others are all the same. In this case the indices array
    provides the set of indices of the variable which are to be updated with
    each row of the gradient.
    """
    global_step = self._global_step + 1

    # Update accumulated weighted average of gradients
    gbar = self.get_slot(var, "gbar")
    gbar_decay_t = GetParam(self._gbar_decay, global_step)
    gbar_weight_t = GetParam(self._gbar_weight, global_step)
    if indices is not None:
      # Note - the sparse update is not easily implemented, since the
      # algorithm needs all indices of gbar to be updated
      # if mat_gbar_decay != 1 or mat_gbar_decay != 0.
      # One way to make mat_gbar_decay = 1 is by rescaling.
      # If we want the update:
      #         G_{t+1} = a_{t+1} G_t + b_{t+1} w_t
      # define:
      #         r_{t+1} = a_{t+1} * r_t
      #         h_t = G_t / r_t
      # Then:
      #         h_{t+1} = h_t + (b_{t+1} / r_{t+1}) * w_t
      # So we get the mat_gbar_decay = 1 as desired.
      # We can implement this in a future version as needed.
      # However we still need gbar_decay = 0, otherwise all indices
      # of the variable will need to be updated.
      if self._gbar_decay != 0.0:
        tf_logging.warning("Not applying momentum for variable: %s" % var.name)
      gbar_updated = grad
    else:
      gbar_updated = self._weighted_average(gbar, self._gbar_decay,
                                            gbar_decay_t,
                                            gbar_weight_t * grad)

    # Update the preconditioners and compute the preconditioned gradient
    shape = var.get_shape()
    mat_g_list = []
    for i in range(len(shape)):
      mat_g_list.append(self.get_slot(var, "Gbar_" + str(i)))
    mat_gbar_decay_t = GetParam(self._mat_gbar_decay, global_step)
    mat_gbar_weight_t = GetParam(self._mat_gbar_weight, global_step)

    preconditioned_grad = gbar_updated
    v_rank = len(mat_g_list)
    # Per-dimension exponent: the total power alpha is split evenly over the
    # v_rank preconditioners (see class docstring update rule).
    neg_alpha = - GetParam(self._alpha, global_step) / v_rank
    svd_interval = GetParam(self._svd_interval, global_step)
    precond_update_interval = GetParam(self._precond_update_interval,
                                       global_step)
    for i, mat_g in enumerate(mat_g_list):
      # axes is the list of indices to reduce - everything but the current i.
      axes = list(range(i)) + list(range(i+1, v_rank))
      if shape[i] <= self._max_matrix_size:
        # If the tensor size is sufficiently small perform full Shampoo update
        # Note if precond_update_interval > 1 and mat_gbar_decay_t != 1, this
        # is not strictly correct. However we will use it for now, and
        # fix if needed. (G_1 = aG + bg ==> G_n = a^n G + (1+a+..+a^{n-1})bg)
        # pylint: disable=g-long-lambda,cell-var-from-loop
        mat_g_updated = control_flow_ops.cond(
            math_ops.mod(global_step, precond_update_interval) < 1,
            lambda: self._update_mat_g(mat_g, grad, axes, mat_gbar_decay_t,
                                       mat_gbar_weight_t *
                                       precond_update_interval, i),
            lambda: mat_g)
        # Normalize by the dimension size so the statistics scale does not
        # grow with d_i.
        mat_g_updated = mat_g_updated / float(shape[i])
        if self._svd_interval == 1:
          mat_h = self._compute_power(var, mat_g_updated, shape[i], neg_alpha)
        else:
          # Amortized root: recompute only every svd_interval steps, otherwise
          # reuse the cached "H_i" slot.
          mat_h = control_flow_ops.cond(
              math_ops.mod(global_step, svd_interval) < 1,
              lambda: self._compute_power(var, mat_g_updated, shape[i],
                                          neg_alpha, "H_" + str(i)),
              lambda: self.get_slot(var, "H_" + str(i)))

        # mat_h is a square matrix of size d_i x d_i
        # preconditioned_grad is a d_i x ... x d_n x d_0 x ... d_{i-1} tensor
        # After contraction with a d_i x d_i tensor
        # it becomes a d_{i+1} x ... x d_n x d_0 x ... d_i tensor
        # (the first dimension is contracted out, and the second dimension of
        # mat_h is appended).  After going through all the indices, it becomes
        # a d_0 x ... x d_n tensor again.
        preconditioned_grad = math_ops.tensordot(preconditioned_grad, mat_h,
                                                 axes=([0], [0]),
                                                 name="precond_" + str(i))
      else:
        # Tensor size is too large -- perform diagonal Shampoo update
        # Only normalize non-vector cases.
        if axes:
          normalizer = 1.0 if indices is not None else float(shape[i])
          grad_outer = math_ops.reduce_sum(grad * grad, axis=axes) / normalizer
        else:
          grad_outer = grad * grad

        if i == 0 and indices is not None:
          # Sparse diagonal path: only the touched rows accumulate, which is
          # why mat_gbar_decay must be exactly 1 here (no decay of untouched
          # rows).
          assert self._mat_gbar_decay == 1.0
          mat_g_updated = state_ops.scatter_add(mat_g, indices,
                                                mat_gbar_weight_t * grad_outer)
          mat_g_updated_slice = array_ops.gather(mat_g_updated, indices)
          # where() guards pow() against zero entries (pow(0, neg_alpha)
          # would be inf).
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated_slice, 0),
              math_ops.pow(mat_g_updated_slice, neg_alpha),
              array_ops.zeros_like(mat_g_updated_slice))
        else:
          mat_g_updated = self._weighted_average(mat_g,
                                                 self._mat_gbar_decay,
                                                 mat_gbar_decay_t,
                                                 mat_gbar_weight_t * grad_outer)
          mat_h = array_ops.where(
              math_ops.greater(mat_g_updated, 0),
              math_ops.pow(mat_g_updated, neg_alpha),
              array_ops.zeros_like(mat_g_updated))

        # Need to do the transpose to ensure that the tensor becomes
        # a d_{i+1} x ... x d_n x d_0 x ... d_i tensor as described above.
        preconditioned_grad = array_ops.transpose(
            preconditioned_grad, perm=list(range(1, v_rank)) + [0]) * mat_h

    # Update the variable based on the Shampoo update
    learning_rate_t = GetParam(self._learning_rate, global_step)
    if indices is not None:
      var_updated = state_ops.scatter_add(
          var, indices, -learning_rate_t * preconditioned_grad)
    else:
      var_updated = state_ops.assign_sub(var,
                                         learning_rate_t * preconditioned_grad)
    return var_updated