-
Notifications
You must be signed in to change notification settings - Fork 1
/
Copy pathshampoo.py
2831 lines (2456 loc) · 109 KB
/
shampoo.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# An implementation of distributed Shampoo optimizer from:
#
# Scalable Second Order Optimization for Deep Learning
# Rohan Anil, Vineet Gupta, Tomer Koren, Kevin Regan, Yoram Singer
# Preprint Paper: https://arxiv.org/abs/2002.09018
#
# This implementation moves computation of inverse pth root back to the
# accelerator (if higher precision is available).
#
# Authors: Rohan Anil (rohananil at google dot com)
# Vineet Gupta (vineet at google dot com)
# James Lottes (jlottes at google dot com)
# Anudhyan Boral (anudhyan at google dot com)
#
"""Distributed Shampoo Implementation."""
import enum
import functools
import itertools
import logging
from typing import (Any, Callable, cast, List, NamedTuple, Optional, Sequence,
Tuple, TypeVar, Union)
import chex
from flax import struct
import jax
from jax import lax
from jax.experimental.sparse import linalg
import jax.numpy as jnp
import numpy as np
import optax
from quant_utils import QuantizedValue
# Dtype for the inverse-pth root routines. `matrix_inverse_pth_root` and
# `matrix_inverse_pth_root_eigh` cast their input to this dtype and cast the
# result back to the input dtype; `mat_power` uses it for its identity
# accumulator.
# Switch to f64 if you have hardware that supports it. Enable the jax flag
# jax_enable_x64 for this to work, otherwise it will default to float32.
_MAT_INV_PTH_ROOT_DTYPE = jnp.float64
# Small epsilon to avoid divide by zero.
_EPSILON = 1e-25
def _default_zero_field():
    """Returns a dataclass field defaulting to a float32 zero scalar.

    The value is produced by a `default_factory`, so every dataclass instance
    gets its own freshly built array rather than a shared default object.
    """
    make_zero_scalar = functools.partial(jnp.array, 0, jnp.float32)
    return struct.field(default_factory=make_zero_scalar)


T = TypeVar("T")
def _maybe_ix(ls, ix):
"""Return ls[ix] if not None else None."""
if ls is None:
return None
return ls[ix]
def _maybe(f):
"""Lifts f to Maybe monad; ie return None if first arg is."""
def wrap_f(x, *args, **kwargs):
if x is None:
return None
return f(x, *args, **kwargs)
return wrap_f
InversePthRootDiagnosticsSubtype = TypeVar(
    "InversePthRootDiagnosticsSubtype", bound="InversePthRootDiagnostics")


@struct.dataclass
class InversePthRootDiagnostics:
    """Diagnostics for inverse p-th root iterative procedure.

    Given an inverse pth root B = A^(-1/p), contains the average and
    maximum diagonal and off diagonal absolute entrywise errors between
    (B^p A) and I.
    """

    max_diag_error: chex.Array = _default_zero_field()
    avg_diag_error: chex.Array = _default_zero_field()
    max_off_diag_error: chex.Array = _default_zero_field()
    avg_off_diag_error: chex.Array = _default_zero_field()
    # The exponent p the diagnostics were computed for (stored as float32).
    p: chex.Array = _default_zero_field()

    @classmethod
    def create(cls, pth_inverse_root, matrix, p):
        """Generates a diagnostics struct from a (-1/p) root result.

        Args:
          pth_inverse_root: Candidate B for A^(-1/p).
          matrix: The matrix A that B is meant to invert (to the p-th root).
          p: Exponent; B is raised to this power via `mat_power`.

        Returns:
          An instance of `cls` with float32 entrywise errors of B^p A vs. I.
        """
        mat_m = jnp.matmul(
            mat_power(pth_inverse_root, p),
            matrix,
            precision=jax.lax.Precision.HIGHEST)
        diagonal = jnp.diag(mat_m)
        # `.size` is a static Python int. Clamp to 1 so that a 1x1 input, which
        # has no off-diagonal entries (their sum is 0), reports 0 instead of
        # the NaN produced by 0 / 0.
        num_off_diag_entries = max(mat_m.size - diagonal.size, 1)
        diag_error = jnp.abs(diagonal - 1).astype(jnp.float32)
        off_diag_error = jnp.abs(mat_m - jnp.diag(diagonal)).astype(jnp.float32)
        return cls(
            max_diag_error=jnp.max(diag_error).astype(jnp.float32),
            avg_diag_error=jnp.mean(diag_error).astype(jnp.float32),
            max_off_diag_error=jnp.max(off_diag_error).astype(jnp.float32),
            avg_off_diag_error=(jnp.sum(off_diag_error) /
                                num_off_diag_entries).astype(jnp.float32),
            p=jnp.array(p, jnp.float32))
LOBPCGDiagnosticsSubtype = TypeVar(
    "LOBPCGDiagnosticsSubtype", bound="LOBPCGDiagnostics")


@struct.dataclass
class LOBPCGDiagnostics:
    """Diagnostics for iterative LOBPCG eigenvalue routine.

    Contains consistency error for LOBPCG eigenvalue routine, which
    refers to |A v - lambda v| / (lambda + |A v|) for a proposed eigenpair
    (v, lambda). This metrics dataclass retains consistency error
    and other useful LOBPCG values.
    """

    lobpcg_iters: chex.Array = _default_zero_field()
    max_consistency_error: chex.Array = _default_zero_field()
    avg_consistency_error: chex.Array = _default_zero_field()
    # Average of absolute value of off-diagonal of V^T V for eigenvectors V.
    avg_orthogonality_error: chex.Array = _default_zero_field()
    max_eigenvalue: chex.Array = _default_zero_field()
    min_eigenvalue: chex.Array = _default_zero_field()
    num_topk_eigenvectors: chex.Array = _default_zero_field()

    @classmethod
    def create(cls, matrix, eigvals, eigvecs, lobpcg_iters):
        """Generates LOBPCG diagnostics from the result of the routine.

        Args:
          matrix: The matrix A that LOBPCG was run on.
          eigvals: The k proposed eigenvalues.
          eigvecs: The proposed eigenvectors, one per column (assumed (n, k),
            as produced by `linalg.lobpcg_standard` in this file).
          lobpcg_iters: Iteration count reported by LOBPCG.

        Returns:
          An instance of `cls` with float32 diagnostic values.
        """
        num_topk = len(eigvals)
        num_off_diag = num_topk * (num_topk - 1)
        precision = jax.lax.Precision.HIGHEST
        mat_eigvecs = matrix.dot(eigvecs, precision=precision)
        consistency_error_unnormalized = jnp.linalg.norm(
            mat_eigvecs - eigvals * eigvecs, ord=2, axis=0)
        normalization = jnp.linalg.norm(mat_eigvecs, ord=2, axis=0) + eigvals
        consistency_error = consistency_error_unnormalized / normalization
        orthogonality_error = eigvecs.T.dot(eigvecs, precision=precision)
        orthogonality_error -= jnp.diag(jnp.diag(orthogonality_error))
        # Take absolute values so off-diagonal entries of opposite sign cannot
        # cancel, and clamp the count so k == 1 (no off-diagonal entries)
        # yields 0 rather than 0 / 0 = NaN.
        avg_orthogonality_error = (
            jnp.sum(jnp.abs(orthogonality_error)) / max(num_off_diag, 1))
        return cls(
            lobpcg_iters=jnp.array(lobpcg_iters, jnp.float32),
            max_consistency_error=jnp.max(consistency_error).astype(jnp.float32),
            avg_consistency_error=jnp.mean(consistency_error).astype(jnp.float32),
            avg_orthogonality_error=avg_orthogonality_error.astype(jnp.float32),
            max_eigenvalue=jnp.max(eigvals).astype(jnp.float32),
            min_eigenvalue=jnp.min(eigvals).astype(jnp.float32),
            num_topk_eigenvectors=jnp.array(num_topk, jnp.float32),
        )
@struct.dataclass
class TrainingMetrics:
    """Diagnostic metrics from training.

    Every field has a default (float32 zeros or default-constructed nested
    diagnostics), so a bare `TrainingMetrics()` is a valid placeholder.
    """
    # Error for inverse-pth roots.
    inverse_pth_root_errors: chex.Array = _default_zero_field()
    # Iteration count for inverse-pth roots.
    inverse_pth_root_iters: chex.Array = _default_zero_field()
    # If final iteration error increases sufficiently, iteration terminates early.
    # This field records the ratio of the final iteration error.
    final_error_ratio: chex.Array = _default_zero_field()
    # Max eigen value from either the power iteration or from LOBPCG.
    max_eigen_value: chex.Array = _default_zero_field()
    # Total retries of inverse pth root iterative method.
    total_retries: chex.Array = _default_zero_field()
    # Diagnostics of the LOBPCG top-eigenvector routine, when it is used.
    lobpcg_diagnostics: LOBPCGDiagnostics = struct.field(
        default_factory=LOBPCGDiagnostics)
    # Rich matrix entrywise error diagnostics, if enabled.
    inverse_pth_root_diagnostics: InversePthRootDiagnostics = struct.field(
        default_factory=InversePthRootDiagnostics)
    # Diagnostics applied to the conditioned p-th root problem, after top
    # eigenvectors are removed, if LOBPCG is being applied.
    conditioned_inverse_pth_root_diagnostics: InversePthRootDiagnostics = (
        struct.field(default_factory=InversePthRootDiagnostics))
    # TODO(rohananil): Add more important metrics to track during training.
# Per parameter optimizer state used in data-parallel training.
class ParameterStats(NamedTuple):
    """State associated to each parameter of the model being trained."""
    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    statistics: Optional[List[Any]]  # Statistics (QuantizedValue, chex.Array)
    preconditioners: List[Any]  # Preconditioners (QuantizedValue, chex.Array)
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    # Optional: optax.MaskedNode when training metrics are disabled.
    training_metrics: Union[TrainingMetrics, optax.MaskedNode]
# For training extremely large models, we keep a global state with the
# concatenated statistics and preconditioner states for all vars. This lets us
# annotate the leading axis as sharded to save memory, at the cost of
# communication.
@struct.dataclass
class GlobalShardedParameterStats:
    """Concatenated statistics/preconditioners shared by all variables."""
    statistics: chex.Array  # Statistics
    preconditioners: chex.Array  # Preconditioners
    exponents: chex.Array  # exponents
# These are per-parameter local states; all statistics here mirror the
# parameter, so their sharding is copied over from the param specification.
@struct.dataclass
class LocalShardedParameterStats:
    """State associated to each parameter of the model being trained."""
    diagonal_statistics: QuantizedValue  # Accumulator for diagonal preconditioner
    diagonal_momentum: QuantizedValue  # Momentum for the diagonal preconditioner
    momentum: QuantizedValue  # Momentum for the shampoo preconditioner
    # optax.MaskedNode when training metrics are disabled.
    training_metrics: Union[TrainingMetrics, optax.MaskedNode]
    # The two fields below use pytree_node=False, so they are static metadata
    # rather than array leaves of the pytree.
    index_start: Union[np.int32, int] = struct.field(
        pytree_node=False)  # Index into global statistics array
    sizes: Any = struct.field(pytree_node=False)  # Sizes of the statistics.
def default_training_metrics():
    """Returns a `TrainingMetrics` with every field at its default value."""
    metrics = TrainingMetrics()
    return metrics
def init_training_metrics(
    num_statistics,
    generate_training_metrics,
):
    """Builds initial training metrics, or a MaskedNode if they are disabled.

    Args:
      num_statistics: How many times each leaf of a default
        `TrainingMetrics` is repeated (via `jnp.repeat`).
      generate_training_metrics: Whether metrics are collected at all.

    Returns:
      A `TrainingMetrics` pytree whose scalar default leaves are repeated
      `num_statistics` times, or `optax.MaskedNode()` when disabled.
    """
    if generate_training_metrics:
        tile = functools.partial(jnp.repeat, repeats=num_statistics)
        return jax.tree.map(tile, default_training_metrics())
    return optax.MaskedNode()
def init_training_metrics_shapes(
    num_statistics,
    generate_training_metrics,
):
    """Describes the initial training metrics as [shape, dtype] per leaf.

    Args:
      num_statistics: Forwarded to `init_training_metrics`.
      generate_training_metrics: Forwarded to `init_training_metrics`.

    Returns:
      A pytree matching `init_training_metrics(...)` where each array leaf is
      replaced by `[list(shape), dtype]`.
    """
    def _shape_and_dtype(leaf):
        return [list(leaf.shape), leaf.dtype]

    initial_metrics = init_training_metrics(
        num_statistics,
        generate_training_metrics,
    )
    return jax.tree.map(_shape_and_dtype, initial_metrics)
def init_training_metrics_pspec(
    generate_training_metrics,
):
    """Builds the partition specification for training metrics.

    Args:
      generate_training_metrics: Whether metrics are collected at all.

    Returns:
      A `TrainingMetrics`-shaped pytree of empty `PartitionSpec()`s (no
      sharded axes), or `optax.MaskedNode()` when metrics are disabled.
    """
    if generate_training_metrics:
        def _unsharded(unused_leaf):
            return jax.sharding.PartitionSpec()

        return jax.tree.map(_unsharded, default_training_metrics())
    return optax.MaskedNode()
class ShardedShampooStats(NamedTuple):
    """Shampoo state in sharded mode."""
    # NOTE(review): presumably a GlobalShardedParameterStats holding the
    # concatenated statistics of all variables -- confirm at the init site.
    global_stats: Any
    # NOTE(review): presumably the per-parameter LocalShardedParameterStats --
    # confirm at the init site.
    local_stats: Any
class ShampooState(NamedTuple):
    """Top-level optimizer state: a counter plus the parameter statistics."""
    # NOTE(review): presumably the optimizer step count -- confirm.
    count: chex.Array
    # Per-parameter (or sharded) statistics; exact type depends on the caller.
    stats: Any
class InitFnState(NamedTuple):
    """Bundle of three initialization-related callables/values.

    NOTE(review): by name these are the state initializer, its partition-spec
    builder, and its shape/dtype builder -- confirm at the construction site.
    """
    init_fn: Any
    pspec_fn: Any
    shape_and_dtype_fn: Any
class GraftingType(enum.IntEnum):
    """Integer-valued selector for the grafting method.

    NOTE(review): member semantics (e.g. which first-order update provides the
    step size) are implemented outside this section -- confirm there.
    """
    NONE = 0
    SGD = 1
    ADAGRAD = 2
    RMSPROP = 3
    RMSPROP_NORMALIZED = 4
    SQRT_N = 5
    ADAGRAD_NORMALIZED = 6
class PreconditionerType(enum.IntEnum):
    """Selects which tensor dimensions receive a Shampoo preconditioner."""
    # Default, computes preconditioner for each dim
    ALL = 1
    # One sided Shampoo, in this case only on input dim.
    # Assumes last dim is always the output dim and everything else input dim.
    INPUT = 2
    # One sided Shampoo, in this case only on output dim.
    # Assumes last dim is always the output dim and everything else input dim.
    OUTPUT = 3
def power_iteration(
    matrix,
    num_iters=100,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
    padding_start=None,
):
    """Estimates the dominant eigenpair of a symmetric PSD matrix.

    Repeatedly multiplies a normalized vector by `matrix` and tracks the
    Rayleigh quotient, stopping after `num_iters` steps or once the quotient
    changes by no more than `error_tolerance` between steps.

    References:
      [Wikipedia, 2021](https://en.wikipedia.org/wiki/Power_iteration)

    Args:
      matrix: the symmetric PSD matrix.
      num_iters: Maximum number of iterations.
      error_tolerance: Exit once the eigenvalue estimate moves by at most this.
      precision: XLA precision flag for the matmuls: lax.Precision.DEFAULT
        (fastest, least precise), lax.Precision.HIGH, or
        lax.Precision.HIGHEST (slowest, most precise).
      padding_start: if set, rows and columns from this index on are assumed
        to be zero, and the start vector is zeroed there too.

    Returns:
      (eigenvector, eigenvalue): the unit-norm vector and its eigenvalue
      estimate.
    """
    dim = matrix.shape[-1]

    def keep_going(carry):
        step, _, _, _, not_converged = carry
        return jnp.logical_and(step < num_iters, not_converged)

    def advance(carry):
        step, vec, prev_estimate, _, _ = carry
        unit = vec / jnp.linalg.norm(vec)
        image = jnp.einsum("ij,j->i", matrix, unit, precision=precision)
        estimate = jnp.einsum("i,i->", unit, image, precision=precision)
        moved = jnp.greater(jnp.abs(estimate - prev_estimate), error_tolerance)
        return (step + 1, image, estimate, image, moved)

    # Deterministic start vector from a fixed seed (seeding from the training
    # step is a possible future improvement).
    rng = np.random.RandomState(1729)
    start = jnp.array(rng.uniform(-1.0, 1.0, dim).astype(matrix.dtype))
    if padding_start is not None:
        start *= (jnp.arange(len(start), dtype=jnp.int32) < padding_start)

    carry0 = (0, start, jnp.zeros([], dtype=matrix.dtype), start, True)
    _, vec_final, estimate_final, _, _ = lax.while_loop(keep_going, advance,
                                                        carry0)
    return vec_final / jnp.linalg.norm(vec_final), estimate_final
def mat_power(
    mat_m,
    p,
    precision=lax.Precision.HIGHEST,
):
    """Computes mat_m ** p by binary (square-and-multiply) exponentiation.

    `p` may be a traced integer, so the loop is written with
    `lax.while_loop`/`lax.cond` instead of Python control flow. The
    accumulator starts as an identity in `_MAT_INV_PTH_ROOT_DTYPE`.

    Args:
      mat_m: Square matrix to exponentiate.
      p: Non-negative integer exponent (may be traced).
      precision: XLA precision flag for the matmuls.

    Returns:
      The matrix power mat_m^p.
    """
    def _bits_remaining(carry):
        remaining, _, _ = carry
        return remaining > 0

    def _square_and_multiply(carry):
        remaining, acc, base = carry
        # Multiply in the current base when the low bit of the exponent is set.
        acc = jax.lax.cond(
            remaining % 2 == 1,
            lambda: jnp.matmul(base, acc, precision=precision),
            lambda: acc,
        )
        base = jnp.matmul(base, base, precision=precision)
        return remaining // 2, acc, base

    acc0 = jnp.eye(mat_m.shape[0], dtype=_MAT_INV_PTH_ROOT_DTYPE)
    _, result, _ = lax.while_loop(_bits_remaining, _square_and_multiply,
                                  (p, acc0, mat_m))
    return result
def _pth_root_difference(w, alpha, beta,
p):
"""Computes (w+alpha)^(-1/p)-(w+beta)^(-1/p)."""
a = w + alpha
b = w + beta
a_minus_b = alpha - beta
exp = -1 / p
def _stable_subtract(b, a_minus_b):
# Mathematically identical to the target expression, with (w+beta)^(-1/p)
# term factored out and w cancellation in the subtraction.
return (b**exp) * jnp.expm1(exp * jnp.log1p(a_minus_b / b))
return jnp.where(
# Choose the branch with the best log1p approximation.
jnp.abs(a_minus_b / b) < jnp.abs(a_minus_b / a),
-_stable_subtract(a, -a_minus_b),
_stable_subtract(b, a_minus_b))
def matrix_inverse_pth_root(
    matrix,
    p,
    num_iters=100,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
    relative_matrix_epsilon=True,
    lobpcg_topk_precondition=0,
    lobpcg_max_iter=0,
    padding_start=None,
    prev=None,
    eigh=False,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses the Eigh or Coupled newton iterations algorithm for
    the computation of a matrix's inverse pth root.

    References:
      [Functions of Matrices, Theory and Computation,
       Nicholas J Higham, Pg 184, Eq 7.18](
       https://epubs.siam.org/doi/book/10.1137/1.9780898717778)

    Args:
      matrix: the symmetric PSD matrix whose power is to be computed
      p: exponent, for p a positive integer.
      num_iters: Maximum number of iterations.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Error indicator, useful for early termination.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c)
        lax.Precision.HIGHEST (best possible precision, slowest)
      relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
        value when computing inverse-pth root.
      lobpcg_topk_precondition: If nonzero, specifies the number of top
        eigenvectors to subtract out before performing LOBPCG. Note this makes
        relative_matrix_epsilon essentially free.
      lobpcg_max_iter: Maximum iteration count for LOBPCG, defaults to
        `lobpcg_topk_precondition`.
      padding_start: If the input matrix was padded, then zeros out columns and
        rows at the padding start.
      prev: previous iteration's solution, zero-padded (unused)
      eigh: If True, uses eigh for inverse-pth root computation.

    Returns:
      `(matrix + eps)^(-1/p)` and error metrics.

      Note `eps` is not added to zeroed out padding rows and
      columns. `eps` is just `ridge_epsilon` if
      `relative_matrix_epsilon` is set to `False`, otherwise, it is the
      ridge epsilon value scaled by the derived maximum eigenvalue of
      the input matrix.
    """
    if eigh:
        return matrix_inverse_pth_root_eigh(matrix, p, ridge_epsilon,
                                            error_tolerance, precision,
                                            relative_matrix_epsilon,
                                            padding_start, prev)
    del prev
    assert matrix.shape[0] == matrix.shape[1]

    # We use _MAT_INV_PTH_ROOT_DTYPE for the matrix inverse pth root.
    # Switch to f64 if you have hardware that supports it. Enable the jax flag
    # jax_enable_x64 for this to work.
    matrix_size = matrix.shape[0]
    orig_dtype = matrix.dtype
    matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
    alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
    identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
    if padding_start is not None:
        # Zero out padding in identity as well for convergence checks.
        ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
            matrix.dtype)
        matrix *= ix[jnp.newaxis, :]
        matrix *= ix[:, jnp.newaxis]
        identity *= ix

    original_matrix = matrix

    # Only used in lobpcg branches, but required by pytype.
    eigvals, eigvecs, lobpcg_diagnostics = None, None, None
    if lobpcg_topk_precondition > 0:
        # TODO(vladf): reuse previous top-k as the initial search directions
        pad_shape = (matrix_size - lobpcg_topk_precondition,
                     lobpcg_topk_precondition)
        search_dirs = jnp.concatenate(
            (jnp.eye(lobpcg_topk_precondition), jnp.zeros(pad_shape)), axis=0)
        eigvals, eigvecs, lobpcg_iters = linalg.lobpcg_standard(
            matrix, search_dirs,
            lobpcg_topk_precondition if lobpcg_max_iter == 0 else lobpcg_max_iter)
        lobpcg_diagnostics = LOBPCGDiagnostics.create(
            matrix,
            eigvals,
            eigvecs,
            lobpcg_iters,
        )

        # The minimal eigenvalue among top-k becomes the maximal one in the
        # whole matrix after deflation.
        deflation = eigvals - jnp.min(eigvals)
        scaled_vecs = eigvecs * jnp.sqrt(deflation)

        # Deflate out top eigenvectors to reduce matrix condition number.
        matrix -= scaled_vecs.dot(
            scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)

    if relative_matrix_epsilon:
        if eigvals is not None:
            max_ev = jnp.max(eigvals)
        else:
            # Only use power iteration if lobpcg wasn't already used to derive
            # the top eigenvalue.
            _, max_ev = power_iteration(
                matrix=matrix,
                num_iters=100,
                error_tolerance=1e-6,
                precision=precision,
                padding_start=padding_start)
    else:
        # Use absolute matrix epsilon scaling otherwise.
        max_ev = 1.0

    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, error_tolerance)

    # Sometimes error increases after an iteration before decreasing and
    # converging. 1.2 factor is used to bound the maximal allowed increase.
    max_error_ratio = 1.2

    def _iter_condition(state):
        i, unused_mat_m, unused_mat_h, unused_old_mat_h, error, error_ratio = state
        error_above_threshold = jnp.logical_and(error > error_tolerance,
                                                error_ratio < max_error_ratio)
        return jnp.logical_and(i < num_iters, error_above_threshold)

    def _iter_body(state):
        # One coupled Newton step: M <- M_i^p M and H <- H M_i.
        (i, mat_m, mat_h, unused_old_mat_h, error, unused_error_ratio) = state
        mat_m_i = (1 - alpha) * identity + alpha * mat_m
        new_mat_m = jnp.matmul(mat_power(mat_m_i, p), mat_m, precision=precision)
        new_mat_h = jnp.matmul(mat_h, mat_m_i, precision=precision)
        new_error = jnp.max(jnp.abs(new_mat_m - identity))
        return (i + 1, new_mat_m, new_mat_h, mat_h, new_error, new_error / error)

    if matrix_size == 1:
        # Scalar case: closed form, no iteration and therefore no retries.
        damped_matrix = matrix + ridge_epsilon
        resultant_mat_h = damped_matrix**alpha
        error = jnp.array(0, jnp.float32)
        iters = 0
        error_ratio = 0.0
        # Must be bound on this path too: it is read when building the
        # TrainingMetrics below.
        total_retries = 0
    else:
        retry_loop_error_threshold = 0.05
        num_tries = 6
        init_outer_state = tuple([0, identity, 1000.0, 100, 1.0, True])

        def _outer_iter_condition_fn(state):
            i, _, _, _, _, iter_failed = state
            return jnp.logical_and(iter_failed, i < num_tries)

        def _outer_body_fn(state):
            i, _, _, _, _, _ = state
            # Update the epsilon based on the loop iteration: attempt i damps
            # with ridge_epsilon * 10**i.
            damped_matrix = matrix + (ridge_epsilon * (10**i) * identity)
            z = (1 + p) / (2 * jnp.linalg.norm(damped_matrix))
            new_mat_m_0 = damped_matrix * z
            new_error = jnp.max(jnp.abs(new_mat_m_0 - identity))
            new_mat_h_0 = identity * jnp.power(z, 1.0 / p)
            init_state = tuple(
                [0, new_mat_m_0, new_mat_h_0, new_mat_h_0, new_error, 1.0])
            iters, mat_m, mat_h, old_mat_h, error, error_ratio = lax.while_loop(
                _iter_condition, _iter_body, init_state)
            error = jnp.max(jnp.abs(mat_m - identity)).astype(jnp.float32)
            # If the last step blew up the error, fall back to the previous H.
            is_converged = jnp.asarray(error_ratio < max_error_ratio,
                                       old_mat_h.dtype)
            resultant_mat_h = (is_converged * mat_h +
                               (1 - is_converged) * old_mat_h)
            return (i + 1, resultant_mat_h, error, iters, error_ratio,
                    error > retry_loop_error_threshold)

        # The first carry element ends as the number of attempts made (>= 1).
        total_retries, resultant_mat_h, error, iters, error_ratio, _ = (
            jax.lax.while_loop(_outer_iter_condition_fn, _outer_body_fn,
                               init_outer_state))

    conditioned_resultant_mat = resultant_mat_h

    if lobpcg_topk_precondition > 0:
        # Since we deflated the top eigenvectors prior to p-th root inverse,
        # the resultant matrix has larger eigenvalues associated with those
        # same eigenvectors, which we need to now re-deflate.
        #
        # Note that _pth_root_difference returns positive values for this
        # particular argument ordering as min(eigvals) <= eigvals for the
        # jnp.sqrt below.
        pth_diff = _pth_root_difference(ridge_epsilon, jnp.min(eigvals),
                                        eigvals, p)
        scaled_vecs = eigvecs * jnp.sqrt(pth_diff)
        resultant_mat_h = conditioned_resultant_mat - scaled_vecs.dot(
            scaled_vecs.T, precision=jax.lax.Precision.HIGHEST)

    error_metrics = TrainingMetrics(
        inverse_pth_root_errors=jnp.array(error, jnp.float32),
        inverse_pth_root_iters=jnp.array(iters, jnp.float32),
        final_error_ratio=jnp.array(error_ratio, jnp.float32),
        max_eigen_value=jnp.array(max_ev, jnp.float32),
        total_retries=jnp.array(total_retries, jnp.float32))

    if lobpcg_topk_precondition > 0:
        # Rebuild the damped matrix exactly as the final attempt did: attempt
        # k (0-based) used 10**k and `total_retries` counts attempts, so the
        # last multiplier is 10**(total_retries - 1). The closed-form scalar
        # path (total_retries == 0) used multiplier 1.
        final_damping = ridge_epsilon * (10**jnp.maximum(total_retries - 1, 0))
        damped_matrix = matrix + final_damping * identity
        conditioned_diagnostics = InversePthRootDiagnostics.create(
            conditioned_resultant_mat, damped_matrix, p)
        unconditioned_damped_matrix = original_matrix + ridge_epsilon * identity
        unconditioned_diagnostics = InversePthRootDiagnostics.create(
            resultant_mat_h, unconditioned_damped_matrix, p)
        # The max entrywise error in error_metrics.inverse_pth_root_errors
        # refers to what was derived from the inverse pth root iteration, which
        # with LOBPCG refers to the conditioned problem. Make sure to use the
        # error from the unconditioned problem.
        unconditional_errors = jnp.maximum(
            unconditioned_diagnostics.max_diag_error,
            unconditioned_diagnostics.max_off_diag_error)
        error_metrics = error_metrics.replace(
            inverse_pth_root_errors=unconditional_errors,
            lobpcg_diagnostics=lobpcg_diagnostics,
            conditioned_inverse_pth_root_diagnostics=conditioned_diagnostics,
            inverse_pth_root_diagnostics=unconditioned_diagnostics,
        )

    if padding_start is not None:
        # Occasionally, pure-padding matrices are handed to the inversion
        # routine due to some TPU hosts not having the same number of
        # preconditioning matrices.
        resultant_mat_h = jnp.where(padding_start == 0, 0.0, resultant_mat_h)
        error = jnp.where(padding_start == 0, 0.0,
                          error_metrics.inverse_pth_root_errors)
        error_metrics = error_metrics.replace(inverse_pth_root_errors=error)

    resultant_mat_h = jnp.asarray(resultant_mat_h, orig_dtype)
    return resultant_mat_h, error_metrics
def matrix_inverse_pth_root_eigh(
    matrix,
    p,
    ridge_epsilon=1e-6,
    error_tolerance=1e-6,
    precision=lax.Precision.HIGHEST,
    relative_matrix_epsilon=True,
    padding_start=None,
    prev=None,
):
    """Computes `matrix^(-1/p)`, where `p` is a positive integer.

    This function uses eigh for the computation of a matrix's inverse pth
    root.

    Args:
      matrix: the symmetric PSD matrix whose power is to be computed
      p: exponent, for p a positive integer.
      ridge_epsilon: Ridge epsilon added to make the matrix positive definite.
      error_tolerance: Tolerance for the power iteration that estimates the
        max eigenvalue, and lower bound on that estimate when scaling
        `ridge_epsilon`.
      precision: precision XLA related flag, the available options are: a)
        lax.Precision.DEFAULT (better step time, but not precise) b)
        lax.Precision.HIGH (increased precision, slower) c)
        lax.Precision.HIGHEST (best possible precision, slowest)
      relative_matrix_epsilon: Whether to use relative epsilon to the max eigen
        value when computing inverse-pth root.
      padding_start: If the input matrix was padded, then zeros out columns and
        rows at the padding start.
      prev: previous iteration's solution, zero-padded (unused)

    Returns:
      `(matrix + eps)^(-1/p)` and error metrics.

      Note `eps` is not added to zeroed out padding rows and
      columns. `eps` is just `ridge_epsilon` if
      `relative_matrix_epsilon` is set to `False`, otherwise, it is the
      ridge epsilon value scaled by the derived maximum eigenvalue of
      the input matrix. The reported error is the max deviation of
      U^T (matrix + eps) U from diag(eigenvalues), i.e. the eigh accuracy.
    """
    del prev
    assert matrix.shape[0] == matrix.shape[1]
    matrix_size = matrix.shape[0]
    orig_dtype = matrix.dtype
    matrix = matrix.astype(_MAT_INV_PTH_ROOT_DTYPE)
    alpha = jnp.asarray(-1.0 / p, _MAT_INV_PTH_ROOT_DTYPE)
    identity = jnp.eye(matrix_size, dtype=_MAT_INV_PTH_ROOT_DTYPE)
    if padding_start is not None:
        # Zero padded rows/columns of the input, and of the identity so the
        # ridge term is not added to them.
        ix = (jnp.arange(matrix_size, dtype=jnp.int32) < padding_start).astype(
            matrix.dtype)
        matrix *= ix[jnp.newaxis, :]
        matrix *= ix[:, jnp.newaxis]
        identity *= ix
    if relative_matrix_epsilon:
        _, max_ev = power_iteration(
            matrix=matrix,
            num_iters=100,
            error_tolerance=error_tolerance,
            precision=precision,
            padding_start=padding_start)
    else:
        # Use absolute matrix epsilon scaling otherwise.
        max_ev = 1.0
    ridge_epsilon = ridge_epsilon * jnp.maximum(max_ev, error_tolerance)
    regularized_input = matrix + ridge_epsilon * identity
    e, u = jnp.linalg.eigh(regularized_input)
    # Due to padding, we may have to zero out eigenvalues. eigh returns
    # eigenvalues in ascending order; assuming a PSD input and a positive
    # ridge, the padded directions (eigenvalue 0) sort before the real ones
    # and so occupy the first `matrix_size - padding_start` slots, hence the
    # flipped mask.
    if padding_start is not None:
        e *= jnp.flip(ix)
    mm = functools.partial(jnp.matmul, precision=precision)
    # Per-eigenvalue inverse p-th root, clamped below at ridge_epsilon; zeroed
    # (padded) eigenvalues map to 0.
    inv_e = jnp.where(e == 0.0, 0.0,
                      jnp.power(jnp.maximum(e, ridge_epsilon), alpha))
    # U diag(inv_e) U^T, formed as R R^T with R = U diag(sqrt(inv_e)).
    root = u * jnp.sqrt(inv_e)
    val = mm(root, root.T)
    # Check eigh accuracy: U^T A U should equal diag(e).
    recovered_e = mm(u.T, mm(regularized_input, u))
    eig_error = recovered_e - jnp.diag(e)
    if padding_start is not None:
        eig_error *= jnp.flip(ix)
    error = jnp.max(jnp.abs(eig_error))
    error_metrics = TrainingMetrics(
        inverse_pth_root_errors=jnp.array(error, jnp.float32))
    if padding_start is not None:
        # A pure-padding matrix yields an all-zero result and zero error.
        val = jnp.where(padding_start == 0, 0.0, val)
        error = jnp.where(padding_start == 0, 0.0,
                          error_metrics.inverse_pth_root_errors)
        error_metrics = error_metrics.replace(inverse_pth_root_errors=error)
    val = jnp.asarray(val, orig_dtype)
    return val, error_metrics
def merge_small_dims(shape_to_merge, max_dim):
    """Greedily merges adjacent dimensions while their product fits in max_dim.

    Size-1 dimensions disappear unless every dimension is 1, e.g.:
      [1, 2, 512, 1, 2048, 1, 3, 4] --> [1024, 2048, 12] if max_dim = 1024
      [1, 2, 768, 1, 2048] --> [2, 768, 2048] if max_dim = 1024

    Args:
      shape_to_merge: Shape to merge small dimensions of.
      max_dim: Largest product allowed when merging neighbours.

    Returns:
      Merged shape as a list ([1] for an all-ones shape).
    """
    if shape_to_merge and np.all(np.array(shape_to_merge) == 1):
        return [1]

    merged = []
    running = 1
    for dim in shape_to_merge:
        candidate = running * dim
        if candidate <= max_dim:
            running = candidate
            continue
        # Current group is full: flush it (dropping trivial 1s) and restart.
        if running > 1:
            merged.append(running)
        running = dim
    if running > 1:
        merged.append(running)
    return merged
def pad_square_matrix(mat, max_size):
    """Embeds a square matrix in the top-left of a max_size x max_size matrix.

    Args:
      mat: Square matrix to pad.
      max_size: Requested output size.

    Returns:
      `mat` itself when already max_size; otherwise [[M, 0], [0, I]].

    Raises:
      ValueError: If `mat` is not square or is larger than `max_size`.
    """
    rows, cols = mat.shape
    if rows != cols:
        raise ValueError("Must have rows == cols, instead got "
                         f"rows={rows}, cols={cols}")
    if cols > max_size:
        raise ValueError("Must have cols <= max_size. Instead got "
                         f"cols={cols}, max_size={max_size}.")
    if rows == max_size:
        return mat
    extra = max_size - rows
    dtype = mat.dtype
    upper = jnp.concatenate([mat, jnp.zeros([rows, extra], dtype=dtype)], 1)
    lower = jnp.concatenate(
        [jnp.zeros([extra, rows], dtype=dtype), jnp.eye(extra, dtype=dtype)], 1)
    return jnp.concatenate([upper, lower], 0)
def pad_vector(vec, max_size):
    """Appends zeros to a 1-D vector until it has length `max_size`.

    Args:
      vec: Vector to pad.
      max_size: Requested output length (must be >= len(vec)).

    Returns:
      `vec` itself if already `max_size` long, otherwise [V, 0].
    """
    size = vec.shape[0]
    assert size <= max_size
    shortfall = max_size - size
    if shortfall == 0:
        return vec
    zeros = jnp.zeros([shortfall], dtype=vec.dtype)
    return jnp.concatenate([vec, zeros], 0)
def efficient_cond(predicate, compute_fn, init_state, *args, **kwargs):
    """Conditionally runs `compute_fn`, expressed as an at-most-once while loop.

    Avoids wasteful buffer allocation with XLA compared to a plain cond.

    Args:
      predicate: Whether to run `compute_fn`.
      compute_fn: Callable returning a sequence of results shaped like
        `init_state`.
      init_state: List of values returned when `predicate` is false.
      *args: Positional arguments for `compute_fn`.
      **kwargs: Keyword arguments for `compute_fn`.

    Returns:
      Tuple of `compute_fn(*args, **kwargs)` results, or of `init_state`.
    """
    def _pending(carry):
        return carry[0]

    def _run_once(unused_carry):
        outputs = compute_fn(*args, **kwargs)
        # Clearing the flag guarantees the loop body runs at most once.
        return (False, *outputs)

    final = jax.lax.while_loop(_pending, _run_once,
                               tuple([predicate] + init_state))
    return tuple(final[1:])
class BlockPartitioner:
    """Partitions a tensor into smaller tensors (blocks) along each axis.

    Each axis longer than `block_size` is cut into pieces of `block_size`,
    with the final piece holding the remainder. Blocks are produced in
    row-major order over the per-axis pieces.
    """

    def __init__(self, param, block_size):
        self._shape = param.shape
        self._splits = []
        self._split_sizes = []
        # Record, per axis, where to cut and how large each piece ends up.
        for axis, dim in enumerate(param.shape):
            if not (0 < block_size < dim):
                self._split_sizes.append(np.array([dim], dtype=np.int32))
                continue
            # Use dim - 1 so an exact multiple does not yield a 0-size piece.
            num_cuts = (dim - 1) // block_size
            cut_points = np.arange(1, num_cuts + 1, dtype=np.int32) * block_size
            piece_sizes = np.full(num_cuts + 1, block_size, dtype=np.int32)
            piece_sizes[-1] = dim - cut_points[-1]
            self._splits.append((axis, cut_points))
            self._split_sizes.append(piece_sizes)

    def split_sizes(self):
        """Returns, per axis, an int32 array with the size of each piece."""
        return self._split_sizes

    def partition(self, tensor):
        """Splits `tensor` (which must have the param's shape) into blocks."""
        assert tensor.shape == self._shape
        blocks = [tensor]
        for axis, cut_points in self._splits:
            blocks = [
                piece
                for block in blocks
                for piece in jnp.split(
                    block, indices_or_sections=cut_points, axis=axis)
            ]
        return blocks

    def merge_partitions(self, partitions):
        """Reassembles blocks from `partition` into the original tensor."""
        for axis, cut_points in reversed(self._splits):
            group = len(cut_points) + 1
            partitions = [
                jnp.concatenate(partitions[start:start + group], axis=axis)
                for start in range(0, len(partitions), group)
            ]
        assert len(partitions) == 1
        return partitions[0]
def gram_weighted_update(old_stats, g, axis, w1, w2, precision=None):
    """Blends old statistics with the Gram matrix of `g` along `axis`.

    Computes w1 * R + w2 * G^T G, where R is `old_stats` and G is the matrix
    whose columns are the flattened slices of `g` along `axis`. Both R and the
    result are n x n with n = g.shape[axis].

    Args:
      old_stats: Old statistics.
      g: Gradient tensor.
      axis: Axis along which to slice `g`.
      w1: Scalar weight for old statistics.
      w2: Scalar weight for new Gram matrix.
      precision: Optional XLA precision flag for the contraction:
        lax.Precision.DEFAULT (fastest), lax.Precision.HIGH, or
        lax.Precision.HIGHEST (most precise).

    Returns:
      Weighted average of old and new statistics.
    """
    # Contract over every axis except `axis`, leaving an n x n result.
    contracted = [d for d in range(g.ndim) if d != axis]
    gram = jnp.tensordot(g, g, axes=(contracted, contracted),
                         precision=precision)
    return w1 * old_stats + w2 * gram
class Preconditioner:
"""Compute statistics/shape from gradients for preconditioning."""
def __init__(
self,
param,
block_size,
merge_small_dims_block_size,
best_effort_shape_interpretation,
preconditioner_type=PreconditionerType.ALL,
):
"""Initializes the preconditioner.
Args:
param: parameter to precondition.
block_size: Block size used to split param.
merge_small_dims_block_size: Block size for merging dims.
best_effort_shape_interpretation: Whether to collapse/merge dims together.
preconditioner_type: Type of preconditioner to use.
"""
self._original_shape = param.shape
self._transformed_shape = param.shape
if best_effort_shape_interpretation:
self._transformed_shape = merge_small_dims(self._original_shape,
merge_small_dims_block_size)
reshaped_param = jnp.reshape(param, self._transformed_shape)
self._partitioner = BlockPartitioner(reshaped_param, block_size)
self._preconditioner_type = preconditioner_type
def updated_statistics_from_grad(
self,
stats,
grad,
w1,
w2,
to_float = None,
from_float = None,
precision = None,
):
"""Update statistics from gradients.
Args:
stats: Old statistics or its Cholesky factor if `cholesky` is True.
grad: Gradient to compute statistics from.
w1: Weight for old statistics.
w2: Weight for new statistics.
to_float: Optional function for converting stats to floating point.
from_float: Optional function for converting from floating point.
precision: Optional precision XLA related flag, the available options are:
a) lax.Precision.DEFAULT (better step time, but not precise) b)
lax.Precision.HIGH (increased precision, slower) c)
lax.Precision.HIGHEST (best possible precision, slowest)
Returns:
A list of updated gradient statistics for each partition.
"""