/
special_math_ops.py
1341 lines (1066 loc) · 46.4 KB
/
special_math_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Arithmetic Operations that don't fit into math_ops due to dependencies.
To avoid circular dependencies, some math_ops should go here.
"""
import collections
import functools
import re
import string
import numpy as np
import opt_einsum
from tensorflow.compiler.tf2xla.ops import gen_xla_ops
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor as tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_linalg_ops
from tensorflow.python.ops import gen_special_math_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.util import deprecation
from tensorflow.python.util import dispatch
from tensorflow.python.util.tf_export import tf_export
# TODO(b/27419586) Change docstring for required dtype of x once int allowed
@tf_export('math.lbeta', v1=['math.lbeta', 'lbeta'])
@dispatch.add_dispatch_support
@deprecation.deprecated_endpoints('lbeta')
def lbeta(x, name=None):
  r"""Computes \\(ln(|Beta(x)|)\\), reducing along the last dimension.

  For a one-dimensional $z = [z_1,...,z_K]$, the multivariate beta function is

  $$Beta(z) = \frac{\prod_j \Gamma(z_j)}{\Gamma(\sum_j z_j)},$$

  where $\Gamma$ is the gamma function.

  For an $n + 1$ dimensional $x$ with shape $[N_1, ..., N_n, K]$, the result is

  $$lbeta(x)[i_1, ..., i_n] = \log{|Beta(x[i_1, ..., i_n, :])|}.$$

  In other words, the last dimension of `x` plays the role of the vector $z$.

  In the two-element case $z = [u, v]$ this is the traditional bivariate beta
  function:

  $$Beta(z) = \frac{\Gamma(u)\Gamma(v)}{\Gamma(u + v)}
    = \int_0^1 t^{u-1} (1 - t)^{v-1} \mathrm{d}t.$$

  If the last dimension is empty, the convention that the sum over the empty
  set is zero and the product over the empty set is one is followed.

  Args:
    x: A rank `n + 1` `Tensor`, `n >= 0` with type `float`, or `double`.
    name: A name for the operation (optional).

  Returns:
    The logarithm of \\(|Beta(x)|\\) reducing along the last dimension.
  """
  # Empty last dimension: the sum of lgamma over no elements is 0 (the empty
  # product is 1), while lgamma of the empty sum is lgamma(0) = infinity, so
  # the difference computed below is -infinity.
  # See https://en.wikipedia.org/wiki/Empty_set.
  with ops.name_scope(name, 'lbeta', [x]):
    x = ops.convert_to_tensor(x, name='x')
    # log(prod_j Gamma(x_j)) expressed as sum_j lgamma(x_j).
    # Note reduce_sum([]) = 0.
    sum_of_lgammas = math_ops.reduce_sum(math_ops.lgamma(x), axis=[-1])
    # log(Gamma(sum_j x_j)).
    lgamma_of_sum = math_ops.lgamma(math_ops.reduce_sum(x, axis=[-1]))
    return sum_of_lgammas - lgamma_of_sum
@tf_export('math.special.dawsn')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def dawsn(x, name=None):
  """Computes Dawson's integral of `x` element-wise.

  Dawson's integral is defined as `exp(-x**2)` times the integral of
  `exp(t**2)` from `0` to `x`, with the domain of definition all real numbers.

  Dawson's function is odd.

  >>> tf.math.special.dawsn([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5380795, -0.4244364, 0.4244364, 0.5380795], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.dawsn
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'dawsn').
  with ops.name_scope(name, 'dawsn', [x]):
    return gen_special_math_ops.dawsn(x)
@tf_export('math.special.expint')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def expint(x, name=None):
  """Computes the Exponential integral of `x` element-wise.

  The Exponential integral is defined as the integral of `exp(t) / t` from
  `-inf` to `x`, with the domain of definition all positive real numbers.

  >>> tf.math.special.expint([1., 1.1, 2.1, 4.1]).numpy()
  array([ 1.8951179, 2.1673784, 5.3332353, 21.048464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.expi
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'expint').
  with ops.name_scope(name, 'expint', [x]):
    return gen_special_math_ops.expint(x)
@tf_export('math.special.fresnel_cos')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_cos(x, name=None):
  """Computes Fresnel's cosine integral of `x` element-wise.

  The Fresnel cosine integral is defined as the integral of
  `cos(pi * t^2 / 2)` from `0` to `x`, with the domain of definition all real
  numbers.

  The Fresnel cosine integral is odd.

  >>> tf.math.special.fresnel_cos([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.7798934 , -0.09999753, 0.09999753, 0.7798934 ], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel second output.
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'fresnel_cos').
  with ops.name_scope(name, 'fresnel_cos', [x]):
    return gen_special_math_ops.fresnel_cos(x)
@tf_export('math.special.fresnel_sin')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def fresnel_sin(x, name=None):
  """Computes Fresnel's sine integral of `x` element-wise.

  The Fresnel sine integral is defined as the integral of
  `sin(pi * t^2 / 2)` from `0` to `x`, with the domain of definition all real
  numbers.

  >>> tf.math.special.fresnel_sin([-1., -0.1, 0.1, 1.]).numpy()
  array([-0.43825912, -0.00052359, 0.00052359, 0.43825912], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.fresnel first output.
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'fresnel_sin').
  with ops.name_scope(name, 'fresnel_sin', [x]):
    return gen_special_math_ops.fresnel_sin(x)
@tf_export('math.special.spence')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def spence(x, name=None):
  """Computes Spence's integral of `x` element-wise.

  Spence's integral is defined as the integral of `log(t) / (1 - t)` from
  `1` to `x`, with the domain of definition all non-negative real numbers.

  >>> tf.math.special.spence([0.5, 1., 2., 3.]).numpy()
  array([ 0.58224034, 0. , -0.82246685, -1.4367464], dtype=float32)

  This implementation is based on the Cephes math library.

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types:
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.spence
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'spence').
  with ops.name_scope(name, 'spence', [x]):
    return gen_special_math_ops.spence(x)
@tf_export('math.bessel_i0', 'math.special.bessel_i0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0(x, name=None):
  """Computes the Bessel i0 function of `x` element-wise.

  Modified Bessel function of the first kind of order 0.

  It is preferable to use the numerically stabler function `bessel_i0e(x)`
  instead.

  >>> tf.math.special.bessel_i0([-1., -0.5, 0.5, 1.]).numpy()
  array([1.26606588, 1.06348337, 1.06348337, 1.26606588], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_i0').
  with ops.name_scope(name, 'bessel_i0', [x]):
    return gen_special_math_ops.bessel_i0(x)
@tf_export('math.bessel_i0e', 'math.special.bessel_i0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i0e(x, name=None):
  """Computes the Bessel i0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 0,
  i.e. `bessel_i0e(x) = exp(-abs(x)) * bessel_i0(x)`.

  >>> tf.math.special.bessel_i0e([-1., -0.5, 0.5, 1.]).numpy()
  array([0.46575961, 0.64503527, 0.64503527, 0.46575961], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i0e
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_i0e').
  with ops.name_scope(name, 'bessel_i0e', [x]):
    return gen_special_math_ops.bessel_i0e(x)
@tf_export('math.bessel_i1', 'math.special.bessel_i1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1(x, name=None):
  """Computes the Bessel i1 function of `x` element-wise.

  Modified Bessel function of the first kind of order 1.

  It is preferable to use the numerically stabler function `bessel_i1e(x)`
  instead.

  >>> tf.math.special.bessel_i1([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.5651591 , -0.25789431, 0.25789431, 0.5651591 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_i1').
  with ops.name_scope(name, 'bessel_i1', [x]):
    return gen_special_math_ops.bessel_i1(x)
@tf_export('math.bessel_i1e', 'math.special.bessel_i1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_i1e(x, name=None):
  """Computes the Bessel i1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the first kind of order 1,
  i.e. `bessel_i1e(x) = exp(-abs(x)) * bessel_i1(x)`.

  >>> tf.math.special.bessel_i1e([-1., -0.5, 0.5, 1.]).numpy()
  array([-0.20791042, -0.15642083, 0.15642083, 0.20791042], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.i1e
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_i1e').
  with ops.name_scope(name, 'bessel_i1e', [x]):
    return gen_special_math_ops.bessel_i1e(x)
@tf_export('math.special.bessel_k0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0(x, name=None):
  """Computes the Bessel k0 function of `x` element-wise.

  Modified Bessel function of the second kind of order 0.

  It is preferable to use the numerically stabler function `bessel_k0e(x)`
  instead.

  >>> tf.math.special.bessel_k0([0.5, 1., 2., 4.]).numpy()
  array([0.92441907, 0.42102444, 0.11389387, 0.01115968], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_k0').
  with ops.name_scope(name, 'bessel_k0', [x]):
    return gen_special_math_ops.bessel_k0(x)
@tf_export('math.special.bessel_k0e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k0e(x, name=None):
  """Computes the Bessel k0e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 0,
  i.e. `bessel_k0e(x) = exp(x) * bessel_k0(x)`.

  >>> tf.math.special.bessel_k0e([0.5, 1., 2., 4.]).numpy()
  array([1.52410939, 1.14446308, 0.84156822, 0.60929767], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k0e
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_k0e').
  with ops.name_scope(name, 'bessel_k0e', [x]):
    return gen_special_math_ops.bessel_k0e(x)
@tf_export('math.special.bessel_k1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1(x, name=None):
  """Computes the Bessel k1 function of `x` element-wise.

  Modified Bessel function of the second kind of order 1.

  It is preferable to use the numerically stabler function `bessel_k1e(x)`
  instead.

  >>> tf.math.special.bessel_k1([0.5, 1., 2., 4.]).numpy()
  array([1.65644112, 0.60190723, 0.13986588, 0.0124835 ], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_k1').
  with ops.name_scope(name, 'bessel_k1', [x]):
    return gen_special_math_ops.bessel_k1(x)
@tf_export('math.special.bessel_k1e')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_k1e(x, name=None):
  """Computes the Bessel k1e function of `x` element-wise.

  Exponentially scaled modified Bessel function of the second kind of order 1,
  i.e. `bessel_k1e(x) = exp(x) * bessel_k1(x)`.

  >>> tf.math.special.bessel_k1e([0.5, 1., 2., 4.]).numpy()
  array([2.73100971, 1.63615349, 1.03347685, 0.68157595], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.k1e
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_k1e').
  with ops.name_scope(name, 'bessel_k1e', [x]):
    return gen_special_math_ops.bessel_k1e(x)
@tf_export('math.special.bessel_j0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j0(x, name=None):
  """Computes the Bessel j0 function of `x` element-wise.

  Bessel function of the first kind of order 0.

  >>> tf.math.special.bessel_j0([0.5, 1., 2., 4.]).numpy()
  array([ 0.93846981, 0.76519769, 0.22389078, -0.39714981], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j0
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_j0').
  with ops.name_scope(name, 'bessel_j0', [x]):
    return gen_special_math_ops.bessel_j0(x)
@tf_export('math.special.bessel_j1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_j1(x, name=None):
  """Computes the Bessel j1 function of `x` element-wise.

  Bessel function of the first kind of order 1.

  >>> tf.math.special.bessel_j1([0.5, 1., 2., 4.]).numpy()
  array([ 0.24226846, 0.44005059, 0.57672481, -0.06604333], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.j1
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_j1').
  with ops.name_scope(name, 'bessel_j1', [x]):
    return gen_special_math_ops.bessel_j1(x)
@tf_export('math.special.bessel_y0')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y0(x, name=None):
  """Computes the Bessel y0 function of `x` element-wise.

  Bessel function of the second kind of order 0.

  >>> tf.math.special.bessel_y0([0.5, 1., 2., 4.]).numpy()
  array([-0.44451873, 0.08825696, 0.51037567, -0.01694074], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y0
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_y0').
  with ops.name_scope(name, 'bessel_y0', [x]):
    return gen_special_math_ops.bessel_y0(x)
@tf_export('math.special.bessel_y1')
@dispatch.register_unary_elementwise_api
@dispatch.add_dispatch_support
def bessel_y1(x, name=None):
  """Computes the Bessel y1 function of `x` element-wise.

  Bessel function of the second kind of order 1.

  >>> tf.math.special.bessel_y1([0.5, 1., 2., 4.]).numpy()
  array([-1.47147239, -0.78121282, -0.10703243, 0.39792571], dtype=float32)

  Args:
    x: A `Tensor` or `SparseTensor`. Must be one of the following types: `half`,
      `float32`, `float64`.
    name: A name for the operation (optional).

  Returns:
    A `Tensor` or `SparseTensor`, respectively. Has the same type as `x`.

  @compatibility(scipy)
  Equivalent to scipy.special.y1
  @end_compatibility
  """
  # Delegate to the op wrapper in `gen_special_math_ops`, created inside a
  # name scope taken from `name` (falling back to 'bessel_y1').
  with ops.name_scope(name, 'bessel_y1', [x]):
    return gen_special_math_ops.bessel_y1(x)
@ops.RegisterGradient('XlaEinsum')
def _einsum_grad(op, grad):
  """Gradient function registered for the two-operand `XlaEinsum` op.

  The op's `equation` attribute is expected to look like `left,right->output`.
  The gradient with respect to each operand is computed as another
  `xla_einsum` of the incoming gradient with the *other* operand, using the
  equation rearranged so that the operand's own labels become the output.

  Args:
    op: The forward `XlaEinsum` operation; its `equation` attribute and its two
      inputs are read.
    grad: Gradient with respect to the output of `op`.

  Returns:
    A two-element list: the gradient for `op.inputs[0]` followed by the
    gradient for `op.inputs[1]`.
  """
  spec = op.get_attr('equation')
  if isinstance(spec, bytes):
    # The attribute may come back as bytes; normalise to `str` before parsing.
    spec = spec.decode()

  operand_labels, result_labels = spec.split('->')
  lhs_labels, rhs_labels = operand_labels.split(',')

  # d(left): contract the upstream gradient with the right operand.
  lhs_grad = gen_xla_ops.xla_einsum(
      grad,
      op.inputs[1],
      equation=f'{result_labels},{rhs_labels}->{lhs_labels}',
      name=None)
  # d(right): contract the upstream gradient with the left operand.
  rhs_grad = gen_xla_ops.xla_einsum(
      grad,
      op.inputs[0],
      equation=f'{result_labels},{lhs_labels}->{rhs_labels}',
      name=None)
  return [lhs_grad, rhs_grad]
def _enclosing_tpu_context():
  """Finds the nearest enclosing `XLAControlFlowContext`, if any.

  Starts at the default graph's current control flow context and follows the
  `outer_context` links outward.

  Returns:
    The first context on that chain that is an instance of
    `control_flow_ops.XLAControlFlowContext`, or `None` when the chain ends
    without finding one.
  """
  graph = ops.get_default_graph()
  ctx = graph._get_control_flow_context()  # pylint: disable=protected-access
  while ctx is not None:
    if isinstance(ctx, control_flow_ops.XLAControlFlowContext):
      return ctx
    ctx = ctx.outer_context
  return None
@tf_export('einsum', 'linalg.einsum')
@dispatch.add_dispatch_support
def einsum(equation, *inputs, **kwargs):
  r"""Tensor contraction over specified indices and outer product.

  Einsum allows defining Tensors by defining their element-wise computation.
  This computation is defined by `equation`, a shorthand form based on Einstein
  summation. As an example, consider multiplying two matrices A and B to form a
  matrix C. The elements of C are given by:

  $$ C_{i,k} = \sum_j A_{i,j} B_{j,k} $$

  or

  ```
  C[i,k] = sum_j A[i,j] * B[j,k]
  ```

  The corresponding einsum `equation` is:

  ```
  ij,jk->ik
  ```

  In general, to convert the element-wise equation into the `equation` string,
  use the following procedure (intermediate strings for matrix multiplication
  example provided in parentheses):

  1. remove variable names, brackets, and commas, (`ik = sum_j ij * jk`)
  2. replace "*" with ",", (`ik = sum_j ij , jk`)
  3. drop summation signs, and (`ik = ij, jk`)
  4. move the output to the right, while replacing "=" with "->". (`ij,jk->ik`)

  Note: If the output indices are not specified, repeated indices are summed.
  So `ij,jk->ik` can be simplified to `ij,jk`.

  Many common operations can be expressed in this way. For example:

  **Matrix multiplication**

  >>> m0 = tf.random.normal(shape=[2, 3])
  >>> m1 = tf.random.normal(shape=[3, 5])
  >>> e = tf.einsum('ij,jk->ik', m0, m1)
  >>> # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  Repeated indices are summed if the output indices are not specified.

  >>> e = tf.einsum('ij,jk', m0, m1)  # output[i,k] = sum_j m0[i,j] * m1[j, k]
  >>> print(e.shape)
  (2, 5)

  **Dot product**

  >>> u = tf.random.normal(shape=[5])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,i->', u, v)  # output = sum_i u[i]*v[i]
  >>> print(e.shape)
  ()

  **Outer product**

  >>> u = tf.random.normal(shape=[3])
  >>> v = tf.random.normal(shape=[5])
  >>> e = tf.einsum('i,j->ij', u, v)  # output[i,j] = u[i]*v[j]
  >>> print(e.shape)
  (3, 5)

  **Transpose**

  >>> m = tf.ones([2, 3])
  >>> e = tf.einsum('ij->ji', m)  # output[j,i] = m[i,j]
  >>> print(e.shape)
  (3, 2)

  **Diag**

  >>> m = tf.reshape(tf.range(9), [3,3])
  >>> diag = tf.einsum('ii->i', m)
  >>> print(diag.shape)
  (3,)

  **Trace**

  >>> # Repeated indices are summed.
  >>> trace = tf.einsum('ii', m)  # output = trace(m) = sum_i m[i, i]
  >>> assert trace == sum(diag)
  >>> print(trace.shape)
  ()

  **Batch matrix multiplication**

  >>> s = tf.random.normal(shape=[7,5,3])
  >>> t = tf.random.normal(shape=[7,3,2])
  >>> e = tf.einsum('bij,bjk->bik', s, t)
  >>> # output[b,i,k] = sum_j s[b,i,j] * t[b, j, k]
  >>> print(e.shape)
  (7, 5, 2)

  This method does not support broadcasting on named-axes. All axes with
  matching labels should have the same length. If you have length-1 axes,
  use `tf.squeeze` or `tf.reshape` to eliminate them.

  To write code that is agnostic to the number of indices in the input
  use an ellipsis. The ellipsis is a placeholder for "whatever other indices
  fit here".

  For example, to perform a NumPy-style broadcasting-batch-matrix multiplication
  where the matrix multiply acts on the last two axes of the input, use:

  >>> s = tf.random.normal(shape=[11, 7, 5, 3])
  >>> t = tf.random.normal(shape=[11, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Einsum **will** broadcast over axes covered by the ellipsis.

  >>> s = tf.random.normal(shape=[11, 1, 5, 3])
  >>> t = tf.random.normal(shape=[1, 7, 3, 2])
  >>> e = tf.einsum('...ij,...jk->...ik', s, t)
  >>> print(e.shape)
  (11, 7, 5, 2)

  Args:
    equation: a `str` describing the contraction, in the same format as
      `numpy.einsum`.
    *inputs: the inputs to contract (each one a `Tensor`), whose shapes should
      be consistent with `equation`.
    **kwargs:
      - optimize: Optimization strategy to use to find contraction path using
        opt_einsum. Must be 'greedy', 'optimal', 'branch-2', 'branch-all' or
        'auto'. (optional, default: 'greedy').
      - name: A name for the operation (optional).

  Returns:
    The contracted `Tensor`, with shape determined by `equation`.

  Raises:
    ValueError: If
      - the format of `equation` is incorrect,
      - number of inputs or their shapes are inconsistent with `equation`.
  """
  # Thin public entry point: every argument is forwarded unchanged to
  # `_einsum_v2`, which does the actual work.
  return _einsum_v2(equation, *inputs, **kwargs)
def _einsum_v1(equation, *inputs, **kwargs):
  """Legacy implementation of einsum without using EinsumOp.

  The equation is resolved into per-operand axis labels, then the operands are
  folded left to right with pairwise `_einsum_v1_reduction` calls. Any axes
  that are still present but absent from the output are summed out, and the
  result is transposed into the requested output order.

  Args:
    equation: Equation string describing the contraction.
    *inputs: The operands to contract; each must expose a `shape` attribute.
    **kwargs: Only `name` (optional name for the operation) is accepted.

  Returns:
    The contracted tensor.

  Raises:
    TypeError: If any keyword argument other than `name` is passed.
    ValueError: If an axis label is repeated within a single operand (other
      than the single-operand implicit-output trace shortcut), or if the labels
      left after reduction do not match the requested output labels.
  """
  name = kwargs.pop('name', None)
  if kwargs:
    unexpected = ', '.join(format(key) for key in sorted(kwargs))
    raise TypeError(
        f'Invalid keyword arguments for this function: {unexpected}.'
        f' Expected: name.')

  with ops.name_scope(name, 'einsum', [equation, inputs]) as name:
    inputs = list(inputs)
    input_axis_labels, output_axis_labels = (
        _einsum_v1_parse_and_resolve_equation(
            equation, [tensor.shape for tensor in inputs]))
    axis_labels = set(''.join(input_axis_labels) + output_axis_labels)

    # Neither of these conditions depends on the label being inspected, so
    # they are evaluated once up front.
    single_implicit_operand = (
        len(input_axis_labels) == 1 and '->' not in equation)
    for label in axis_labels:
      for operand_labels in input_axis_labels:
        occurrences = operand_labels.count(label)
        # Shortcut: one operand, no explicit output, a label appearing exactly
        # twice in palindromic subscripts (e.g. 'ii') is handled as a trace.
        if (single_implicit_operand and occurrences == 2 and
            operand_labels == operand_labels[::-1]):
          return math_ops.trace(inputs[0])
        if occurrences > 1:
          raise ValueError(
              f'Subscript not supported: the axis {label} appears more than '
              f'once in {operand_labels}.')

    # A label summed over more than two operands cannot be handled by the
    # pairwise reduction below; use the exponential-space fallback instead.
    for label in axis_labels:
      if label in output_axis_labels:
        continue
      operands_with_label = sum(
          1 for operand_labels in input_axis_labels if label in operand_labels)
      if operands_with_label > 2:
        logging.warn(
            f'Falling back to exponential-space implementation of einsum()'
            f' because index {label} is summed over more than two inputs.')
        return _exponential_space_einsum_v1(equation, *inputs)

    # Use xla_einsum if executing on TPU and if the operation is a 2 input
    # einsum supported by XlaEinsumOp.
    if _enclosing_tpu_context() is not None and len(inputs) == 2:
      lhs_labels, rhs_labels = input_axis_labels
      return gen_xla_ops.xla_einsum(
          inputs[0], inputs[1],
          f'{lhs_labels},{rhs_labels}->{output_axis_labels}')

    # Fold the operands from left to right. At each step the axes contracted
    # are those shared by the running result and the next operand that do not
    # appear in the output.
    result, result_labels = inputs[0], input_axis_labels[0]
    for operand, operand_labels in zip(inputs[1:], input_axis_labels[1:]):
      contracted = (
          set(result_labels) &
          (set(operand_labels) - set(output_axis_labels)))
      result, result_labels = _einsum_v1_reduction(
          result, result_labels, operand, operand_labels, contracted)

    # Sum out whatever is left over that the output does not ask for.
    leftover_positions = [
        position for position, label in enumerate(result_labels)
        if label not in output_axis_labels
    ]
    if leftover_positions:
      result = math_ops.reduce_sum(result, axis=leftover_positions)
      result_labels = ''.join(
          label for label in result_labels if label in output_axis_labels)

    if sorted(result_labels) != sorted(output_axis_labels):
      raise ValueError(
          f'Invalid equation: {equation}. The computed and specified output '
          f'labels do not match: {result_labels} vs {output_axis_labels}.')

    perm = [result_labels.index(label) for label in output_axis_labels]
    return _transpose_if_necessary(result, perm)
def _einsum_v1_parse_and_resolve_equation(equation, input_shapes):
  """Helper for einsum() that splits/resolves inputs & outputs.

  Args:
    equation: Equation string given as argument to einsum().
    input_shapes: List of the shapes of all inputs given to einsum(). Only the
      list length and, for operands whose subscripts contain "...", the
      `ndims` attribute of the matching shape are used.

  Returns:
    input_axis_labels, output_axis_labels where:
      input_axis_labels: List of length len(input_shapes) of strings
      representing the character label for each dimension of each given input,
      resolving any broadcast (...) axes,
      output_axis_labels: A string of character labels for each axes of output
      tensor, filling in missing output subscripts and broadcast axes.

  Raises:
    ValueError: If equation is in the incorrect format, incorrect number of
      inputs given or broadcast axes "..." or output axes could not be resolved.
  """
  equation = equation.replace(' ', '')
  match = re.match(r'^([a-zA-Z,.]+)(->[a-zA-Z.]*)?$', equation)
  if not match:
    raise ValueError(f'Indices have incorrect format. Received: {equation}.')

  input_axis_labels = match.group(1).split(',')
  # `None` means "no explicit output given"; it is inferred further below.
  output_axis_labels = match.group(2)[2:] if match.group(2) else None

  if len(input_shapes) != len(input_axis_labels):
    raise ValueError(
        f'Got {len(input_shapes)} arguments for equation "{equation}", '
        f'expecting {len(input_axis_labels)}.')

  # Resolve Ellipsis
  # Assign axes labels for unspecified dimensions in inputs. Labels taken
  # from unused labels. Follow numpy einsum broadcasting conventions for
  # tensors of different length and unlabeled output.
  ellipsis_axes = ''
  if '...' in equation:
    used = set(''.join(input_axis_labels))
    unused = ''.join(c for c in string.ascii_letters if c not in used)
    for i, ax in enumerate(input_axis_labels):
      if '...' not in ax:
        continue
      parts = ax.split('...')
      if len(parts) != 2:
        raise ValueError(f'Unable to resolve ellipsis. '
                         f'Excess number found: {len(parts)-1} vs 1.')
      if input_shapes[i].ndims is None:
        raise ValueError('Unable to statically infer ellipsis axes. The '
                         'input shapes has a dynamic dimensionality.')
      # Number of dimensions the ellipsis has to stand for in this operand.
      n = input_shapes[i].ndims - len(''.join(parts))
      if n < 0:
        raise ValueError('Ellipses lengths do not match.')
      if len(unused) < n:
        raise ValueError(
            'Unable to resolve ellipsis, too many distinct labels.')
      # Take labels from the end of `unused` so that operands with fewer
      # broadcast dimensions line up with the trailing ones of larger operands.
      replace_axes = unused[-n:] if n > 0 else ''
      input_axis_labels[i] = ax.replace('...', replace_axes)
      if len(replace_axes) > len(ellipsis_axes):
        ellipsis_axes = replace_axes

  # The regex admits any "." so that "..." can be written. Once ellipses have
  # been expanded, any period left over is a malformed subscript. This is
  # checked for every equation, not only those containing an ellipsis, so that
  # e.g. "i.j->ij" is rejected instead of "." being treated as an axis label.
  if any('.' in ax for ax in input_axis_labels):
    raise ValueError(
        f'Period "." found outside of ellipsis in input {input_axis_labels}.')

  if output_axis_labels is not None:
    output_axis_labels = output_axis_labels.replace('...', ellipsis_axes)
    if '.' in output_axis_labels:
      raise ValueError(f'Period "." found outside of ellipsis in output '
                       f'{output_axis_labels}.')
  else:
    # infer the output subscripts if not given, assume alphabetical order,
    # but always place ellipsis axes before given.
    counts = collections.Counter(
        ax for labels in input_axis_labels for ax in labels
        if ax not in ellipsis_axes)
    output_axis_labels = ellipsis_axes + ''.join(
        sorted(ax for ax, count in counts.items() if count == 1))

  return input_axis_labels, output_axis_labels
def _einsum_v1_reduction(t0, t0_axis_labels, t1, t1_axis_labels, axes_to_sum):
"""Helper for einsum() that computes the result of a two-argument einsum().
Args:
t0: a `Tensor`
t0_axis_labels: a string of axis labels. This string's length must equal
the rank of t0.
t1: a `Tensor`
t1_axis_labels: a string to axis labels. This string's length must equal
the rank of t1.
axes_to_sum: set of labels of axes to be summed over
Returns:
A `Tensor` whose elements are obtained by summing, over all axes in
`axes_to_sum`, the corresponding elements of `t0` and `t1`.
For example, if t0_axis_labels == 'abijk', t1_axis_labels == 'acjkl', and
axes_to_sum == {j,k}, this will return a tensor x where
out[a,b,c,i,l] = sum_j sum_k t0[a,b,i,j,k] * t1[a,c,j,k,l]
Raises:
ValueError: if the rank of `t0` does not match the length of
`t0_axis_labels`, or that of `t1` does not match the length of
`t1_axis_labels`.
"""
if len(t0_axis_labels) != len(t0.shape):
raise ValueError(
f'Tensor `t0` of rank {len(t0.shape)} does not match einsum reduction '
f'of length {len(t0_axis_labels)}.')
if len(t1_axis_labels) != len(t1.shape):
raise ValueError(
f'Tensor `t1` of rank {len(t1.shape)} does not match einsum reduction '
f'of length {len(t1_axis_labels)}')
# This function computes the result of a two-argument einsum() using batch
# matrix multiplication. This involves
# 1. transposing t0 and t1 so that axes are in the correct order for
# batch matrix multiplication, and
# 2. reshaping t0 and t1 so that they are both of rank 3.
# First, we divide axes into three groups:
# * "preserved" axes are present in both inputs and the output
# * "summed" axes are present in both inputs but not the output
# * "broadcast" axes are present in exactly one input and the output
#
# As an example, if the einsum is abijk,acjkl->abcil, then "a" is a
# preserved axis, "b" and "c" are broadcast axes, and "j" and "k" are
# summed axes.
assert all(a in t0_axis_labels and a in t1_axis_labels for a in axes_to_sum)
preserved_axes = (set(t0_axis_labels) & set(t1_axis_labels)) - axes_to_sum
broadcast_axes = {}
for i, sym_list in enumerate([t0_axis_labels, t1_axis_labels]):
broadcast_axes[i] = set(sym_list) - preserved_axes - axes_to_sum
# Reorder the axes so that:
# 1. preserved axes come first in both inputs
# 2. in input 0, broadcast axes come next, followed by summed axes
# 3. in input 1, summed axes come next, followed by broadcast axes
def sort_key(input_index, a):
if a in preserved_axes:
return (-1, a)
elif ((input_index == 0 and a in broadcast_axes[0]) or
(input_index == 1 and a in axes_to_sum)):
return (0, a)
else:
return (1, a)
axis_labels = [t0_axis_labels, t1_axis_labels]
sorted_axes = [
sorted(sym_list, key=lambda a: sort_key(i, a))
for i, sym_list in enumerate(axis_labels)
]
inputs = [t0, t1]
for i, axes_str in enumerate(axis_labels):
perm = [axes_str.find(a) for a in sorted_axes[i]]
inputs[i] = _transpose_if_necessary(inputs[i], perm)
t0, t1 = inputs