/
hlo_instruction.h
1821 lines (1530 loc) · 74.3 KB
/
hlo_instruction.h
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
// HLO instructions are in DAG form and represent the computations that the user
// has built up via the XLA service interface. They are ultimately lowered
// in a platform-aware way by traversing the HLO DAG and emitting a lowered
// form; e.g. see DfsHloVisitor.
#ifndef TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#define TENSORFLOW_COMPILER_XLA_SERVICE_HLO_INSTRUCTION_H_
#include <functional>
#include <iosfwd>
#include <list>
#include <memory>
#include <set>
#include <string>
#include <tuple>
#include <unordered_map>
#include <unordered_set>
#include <vector>
#include "tensorflow/compiler/xla/iterator_util.h"
#include "tensorflow/compiler/xla/literal_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/dfs_hlo_visitor.h"
#include "tensorflow/compiler/xla/service/hlo.pb.h"
#include "tensorflow/compiler/xla/service/hlo_clone_context.h"
#include "tensorflow/compiler/xla/service/hlo_domain_metadata.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/hlo_sharding.h"
#include "tensorflow/compiler/xla/service/name_uniquer.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/xla_data.pb.h"
#include "tensorflow/core/lib/core/status.h"
#include "tensorflow/core/lib/core/stringpiece.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/gtl/flatmap.h"
#include "tensorflow/core/lib/gtl/inlined_vector.h"
#include "tensorflow/core/lib/gtl/iterator_range.h"
#include "tensorflow/core/platform/logging.h"
#include "tensorflow/core/platform/macros.h"
#include "tensorflow/core/platform/protobuf.h"
#include "tensorflow/core/platform/types.h"
namespace xla {
class HloComputation;
class HloModule;
// A bunch of switches that control how the hlo text should be printed.
class HloPrintOptions {
 public:
  // How much of each called subcomputation is printed.
  enum class PrintSubcomputationMode {
    kOff,         // Do not print anything about subcomputations.
    kNameOnly,    // Only print the name of subcomputations.
    kFullBodies,  // Print the full bodies of subcomputations.
  };

  // Constructs the default print options: don't print large constants, print
  // only the names of subcomputations, print metadata and backend_config,
  // don't compact operands, print operand and program shapes, prefix names
  // with '%', don't canonicalize instruction names, no indentation, and not
  // inside a nested computation.
  HloPrintOptions()
      : print_large_constants_(false),
        print_subcomputation_mode_(PrintSubcomputationMode::kNameOnly),
        print_metadata_(true),
        print_backend_config_(true),
        compact_operands_(false),
        print_operand_shape_(true),
        print_program_shape_(true),
        print_percent_(true),
        canonicalize_instruction_names_(false),
        indent_amount_(0),
        is_in_nested_computation_(false) {}

  // Options for a short text form: large constants are printed in full and
  // subcomputations by name only, while metadata, backend_config, operand
  // shapes, program shapes and the '%' name prefix are omitted.
  static HloPrintOptions ShortParsable() {
    return HloPrintOptions()
        .set_print_large_constants(true)
        .set_print_subcomputation_mode(PrintSubcomputationMode::kNameOnly)
        .set_print_metadata(false)
        .set_print_backend_config(false)
        .set_print_operand_shape(false)
        .set_print_program_shape(false)
        .set_print_percent(false);
  }

  // Options to produce the canonical string representing an isomorphic
  // computation graph: full subcomputation bodies, compact operands with
  // their shapes, canonicalized instruction names, and no metadata,
  // backend_config, program shape or '%' prefix.
  static HloPrintOptions Canonical() {
    return HloPrintOptions()
        .set_print_subcomputation_mode(PrintSubcomputationMode::kFullBodies)
        .set_print_metadata(false)
        // Set explicitly. Previously print_backend_config() read the metadata
        // flag, so disabling metadata implicitly disabled backend_config in
        // this preset; keep the canonical output unchanged.
        .set_print_backend_config(false)
        .set_compact_operands(true)
        .set_print_operand_shape(true)
        .set_print_program_shape(false)
        .set_print_percent(false)
        .set_canonicalize_instruction_names(true);
  }

  // If true, large constants will be printed out.
  HloPrintOptions& set_print_large_constants(bool value) {
    print_large_constants_ = value;
    return *this;
  }

  // Controls how much of each called subcomputation is printed.
  HloPrintOptions& set_print_subcomputation_mode(
      PrintSubcomputationMode value) {
    print_subcomputation_mode_ = value;
    return *this;
  }

  // If true, metadata will be printed.
  HloPrintOptions& set_print_metadata(bool value) {
    print_metadata_ = value;
    return *this;
  }

  // If true, backend_config will be printed.
  HloPrintOptions& set_print_backend_config(bool value) {
    print_backend_config_ = value;
    return *this;
  }

  // If true, operands' shapes will be printed.
  HloPrintOptions& set_print_operand_shape(bool value) {
    print_operand_shape_ = value;
    return *this;
  }

  // If true, program shape of hlo computations will be printed.
  HloPrintOptions& set_print_program_shape(bool value) {
    print_program_shape_ = value;
    return *this;
  }

  // If true, names will be printed with prefix '%'.
  HloPrintOptions& set_print_percent(bool value) {
    print_percent_ = value;
    return *this;
  }

  // If true, only a part of operands will be printed out, and their names will
  // be omitted (note that in this case the text will not be parsable).
  HloPrintOptions& set_compact_operands(bool value) {
    compact_operands_ = value;
    return *this;
  }

  // If true, canonicalizes instructions' name. Instead of using "%foo.1" as
  // the name of an instruction, we use "%tmp_1", "%tmp_2" etc.
  HloPrintOptions& set_canonicalize_instruction_names(bool value) {
    canonicalize_instruction_names_ = value;
    return *this;
  }

  // The indent of the hlo text block.
  HloPrintOptions& set_indent_amount(int value) {
    indent_amount_ = value;
    return *this;
  }

  // If true, indicates the instruction being printed is inside a nested
  // computation.
  HloPrintOptions& set_is_in_nested_computation(bool value) {
    is_in_nested_computation_ = value;
    return *this;
  }

  bool print_large_constants() const { return print_large_constants_; }
  PrintSubcomputationMode print_subcomputation_mode() const {
    return print_subcomputation_mode_;
  }
  bool print_metadata() const { return print_metadata_; }
  // Reads its own flag, so metadata and backend_config can be toggled
  // independently.
  bool print_backend_config() const { return print_backend_config_; }
  bool compact_operands() const { return compact_operands_; }
  bool print_operand_shape() const { return print_operand_shape_; }
  bool print_program_shape() const { return print_program_shape_; }
  bool print_percent() const { return print_percent_; }
  bool canonicalize_instruction_names() const {
    return canonicalize_instruction_names_;
  }
  int indent_amount() const { return indent_amount_; }
  // Returns bool to match the type of the stored flag and its setter.
  bool is_in_nested_computation() const { return is_in_nested_computation_; }

 private:
  bool print_large_constants_;
  PrintSubcomputationMode print_subcomputation_mode_;
  bool print_metadata_;
  bool print_backend_config_;
  bool compact_operands_;
  bool print_operand_shape_;
  bool print_program_shape_;
  bool print_percent_;
  bool canonicalize_instruction_names_;
  int indent_amount_;
  bool is_in_nested_computation_;
};
// For canonical string output, we need to have a canonical way to rename
// each instruction and its operands. Each operand is renamed as "tmp_<xxx>",
// where <xxx> is an index starting from 0.
class CanonicalNameMap {
 public:
  CanonicalNameMap() : index(0) {}

  // Returns the canonical name already recorded for `old_name`. When no name
  // has been recorded yet, builds "tmp_<index>" from the current counter,
  // advances the counter, remembers the new name for `old_name`, and
  // returns it.
  string LookupOrInsert(const string& old_name) {
    const auto found = canonical_name_map.find(old_name);
    if (found == canonical_name_map.end()) {
      string fresh_name = tensorflow::strings::StrCat("tmp_", index);
      ++index;
      canonical_name_map[old_name] = fresh_name;
      return fresh_name;
    }
    return found->second;
  }

  // Forgets every recorded mapping and restarts numbering at "tmp_0".
  void Clear() {
    index = 0;
    canonical_name_map.clear();
  }

 private:
  // Next numeric suffix to hand out.
  int64 index;
  // old name -> canonical "tmp_<n>" name.
  tensorflow::gtl::FlatMap<string, string> canonical_name_map;
};
// HLO instructions are the atomic unit of the high-level compiler's IR.
//
// HloInstructions live inside of an HloComputation, which is analogous to a
// function in other programming languages. Nodes have no total order within
// their computation. Instead, they have a partial ordering determined by their
// data and control dependencies.
//
// HLO does not have basic blocks or explicit "branch" instructions. Instead,
// certain HloInstructions -- namely, kWhile, kConditional, and kCall -- encode
// control flow. For example, the kConditional HLO executes one of two possible
// computations, depending on the runtime value of a predicate.
//
// HLO is pure (mostly). It has no concept of mutable state. Instead, data
// values are produced by one HLO and flow into consumers across dependency
// edges.
class HloInstruction {
public:
// A fusion node computes the same value a call to its fusion computation
// would compute. However, the choice of fusion kind dictates codegen
// strategy for the backend.
//
// To generate code for a kFusion HloInstruction, most backends do something
// like the following:
//
// 1) Identify the "primary" HloInstruction of the fused computation.
// 2) Emit code that does the work of the primary node, creating its inputs
// and transforming its outputs as specified by the fused computation.
//
// In step (2), the code emitted is usually similar to the code that would be
// emitted for an *unfused* version of the primary node, except that
//
// - when the primary node reads an element of one of its operands, instead
// of loading the value from memory, it *computes* the value based on the
// contents of the fused computation.
// - when the primary node outputs a value, instead of storing it to memory,
// it forwards the value to its users, which then perform additional
// computations before the value is finally stored to memory at the root of
// the fusion node.
//
// An HloInstruction's FusionKind helps us find the kFusion instruction's
// primary node, and can also affect how we generate code in step (2).
//
// - kInput: The primary node is the root of the fused instruction.
//
// - kOutput: The primary node is not the root of the fused instruction.
// This fusion kind requires that one operand buffer of the fusion
// instruction be able to alias the output buffer. This constraint is
// usually enough to let backends find the primary node unambiguously.
//
// - kLoop: The primary node is the root of the fused computation, but,
// unlike in input fusion, we prescribe a specific implementation for
// codegen. Rather than generating code that looks like the code we'd emit
// for an unfused version of the primary/root node, we emit code that
// generates one element of the root at a time.
//
// - kCustom: Custom category for backend-specific fusions that don't fit
// into the above patterns.
//
// Not all backends support all fusion kinds, and given a particular fused
// computation, it's not in general safe to change its fusion kind. Creation
// of fusion nodes is always backend-specific.
//
// For elementwise ops (e.g. kAdd), most backends would emit a
// one-element-at-a-time implementation for the unfused version, so loop
// fusion and input fusion are probably equivalent if the root node is
// elementwise. They're not necessarily equivalent e.g. for kReduce, where an
// implementation might emit something more sophisticated for an unfused or
// input-fusion reduce, but will emit the naive code that reduces one element
// at a time for loop fusion with a reduce as the root.
//
// Another way to think of loop fusion is that it's equivalent to input
// fusion, but where the root node is an implicit identity node, whose
// unfused implementation is "read one element, write one element".
//
// TODO(b/79869434): This categorization scheme is not great. For one thing,
// input and loop fusion are basically the same thing: There is no reason for
// the HLO to encode backend-specific decisions about how e.g. a reduce that's
// the root of a fusion should be lowered. In addition, this scheme as
// written doesn't work for multi-output fusion, where the primary node is
// never actually the root (which is a kTuple instruction that gathers the
// multiple outputs of the fusion).
enum class FusionKind {
kLoop,
kInput,
kOutput,
kCustom,
};
~HloInstruction();
// Creates an instruction from the given proto. Arguments:
//
// proto: the proto to convert from.
// instruction_map: a map from instruction id to HloInstruction*. This map
// must contain all operands of the newly constructed instruction.
// computation_map: a map from computation id to HloComputation*. This map
// must contain all computations which the newly constructed instruction
// calls.
static StatusOr<std::unique_ptr<HloInstruction>> CreateFromProto(
const HloInstructionProto& proto,
const tensorflow::gtl::FlatMap<int64, HloInstruction*>& instruction_map,
const tensorflow::gtl::FlatMap<int64, HloComputation*>& computation_map);
// Creates a parameter-retrieving instruction.
static std::unique_ptr<HloInstruction> CreateParameter(int64 parameter_number,
const Shape& shape,
const string& name);
// Creates a literal constant instruction.
static std::unique_ptr<HloInstruction> CreateConstant(
std::unique_ptr<Literal> literal);
// Creates a get tuple element instruction.
static std::unique_ptr<HloInstruction> CreateGetTupleElement(
const Shape& shape, HloInstruction* operand, int64 index);
// Creates a trace instruction that logs the input operand in the computation.
static std::unique_ptr<HloInstruction> CreateTrace(const string& tag,
HloInstruction* operand);
// Creates a random number generation instruction that fills a shape with
// random numbers from a given distribution.
static std::unique_ptr<HloInstruction> CreateRng(
const Shape& shape, RandomDistribution distribution,
tensorflow::gtl::ArraySlice<HloInstruction*> parameters);
// Creates a unary instruction (one operand).
// Precondition: opcode must be a legitimate unary operation.
static std::unique_ptr<HloInstruction> CreateUnary(const Shape& shape,
HloOpcode opcode,
HloInstruction* operand);
// Creates a binary instruction (two operands).
// Precondition: opcode must be a legitimate binary operation.
static std::unique_ptr<HloInstruction> CreateBinary(const Shape& shape,
HloOpcode opcode,
HloInstruction* lhs,
HloInstruction* rhs);
// Creates a ternary instruction (three operands).
// Precondition: opcode must be a legitimate ternary operation.
static std::unique_ptr<HloInstruction> CreateTernary(const Shape& shape,
HloOpcode opcode,
HloInstruction* lhs,
HloInstruction* rhs,
HloInstruction* ehs);
// Creates a variadic instruction (variable number of operands).
// Precondition: opcode must be a legitimate variadic operation.
static std::unique_ptr<HloInstruction> CreateVariadic(
const Shape& shape, HloOpcode opcode,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Creates a map instruction, where the computation (given by the handle) is
// applied element-wise to every element in operands (across the operands,
// at a given index) with the same `static_operands`.
static std::unique_ptr<HloInstruction> CreateMap(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* map_computation,
tensorflow::gtl::ArraySlice<HloInstruction*> static_operands = {});
// Creates a convolution op, where rhs is the convolutional filter
// and window describes how the filter is applied to lhs.
static std::unique_ptr<HloInstruction> CreateConvolve(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const Window& window,
const ConvolutionDimensionNumbers& dimension_numbers);
// Creates an FFT op, of the type indicated by fft_type.
static std::unique_ptr<HloInstruction> CreateFft(
const Shape& shape, HloInstruction* operand, FftType fft_type,
tensorflow::gtl::ArraySlice<int64> fft_length);
// Creates a dot op with operands 'lhs' and 'rhs' with contracting and batch
// dimensions specified in 'dimension_numbers'.
static std::unique_ptr<HloInstruction> CreateDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs,
const DotDimensionNumbers& dimension_numbers);
// Creates a dot op with operands 'lhs' and 'rhs' that contracts dimension 1
// of the LHS with dimension 0 of the RHS with no batch dimensions. Both LHS
// and the RHS must be of rank 2.
static std::unique_ptr<HloInstruction> CreateCanonicalDot(
const Shape& shape, HloInstruction* lhs, HloInstruction* rhs);
// Creates a reduce-precision op, where operand is the data to reduce in
// precision, and exponent_bits and mantissa_bits describe the precision to
// reduce it to.
static std::unique_ptr<HloInstruction> CreateReducePrecision(
const Shape& shape, HloInstruction* operand, const int exponent_bits,
const int mantissa_bits);
// Creates a cross replica sum op.
static std::unique_ptr<HloInstruction> CreateCrossReplicaSum(
const Shape& shape,
tensorflow::gtl::ArraySlice<HloInstruction*> operands);
// Creates a conversion instruction, where operand is the data to convert and
// shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateConvert(const Shape& shape,
HloInstruction* operand);
// Creates a bitcast conversion instruction, where operand is the data to
// convert and shape is the target shape for the conversion.
static std::unique_ptr<HloInstruction> CreateBitcastConvert(
const Shape& shape, HloInstruction* operand);
// Creates an infeed instruction, which reads data of the given shape from the
// Infeed interface of the device.
static std::unique_ptr<HloInstruction> CreateInfeed(const Shape& shape,
const string& config);
// Creates an outfeed instruction, which outputs data.
static std::unique_ptr<HloInstruction> CreateOutfeed(
const Shape& shape, HloInstruction* operand,
tensorflow::StringPiece outfeed_config);
// Creates an asynchronous send instruction with the given channel id, which
// initiates sending the operand data to a unique receive instruction in
// another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateSend(HloInstruction* operand,
int64 channel_id);
// Blocks until data transfer for the Send instruction (operand) is complete.
// The operand must be kSend.
static std::unique_ptr<HloInstruction> CreateSendDone(
HloInstruction* operand);
// Creates an asynchronous receive instruction with the given channel id,
// which allocates resources to receive data of the given shape from a unique
// send instruction in another computation that has the same channel id.
static std::unique_ptr<HloInstruction> CreateRecv(const Shape& shape,
int64 channel_id);
// Blocks until data transfer for the Recv instruction (operand) is complete
// and returns the receive buffer. The operand must be kRecv.
static std::unique_ptr<HloInstruction> CreateRecvDone(
HloInstruction* operand);
// Creates a slice instruction, where the operand is sliced by the given
// start/limit indices.
static std::unique_ptr<HloInstruction> CreateSlice(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> start_indices,
tensorflow::gtl::ArraySlice<int64> limit_indices,
tensorflow::gtl::ArraySlice<int64> strides);
// Creates a slice instruction, where the first operand is sliced by
// start indices specified in the second operand, and by size specified in
// 'slice_sizes'.
static std::unique_ptr<HloInstruction> CreateDynamicSlice(
const Shape& shape, HloInstruction* operand,
HloInstruction* start_indices,
tensorflow::gtl::ArraySlice<int64> slice_sizes);
// Creates a dynamic update slice instruction, which updates a slice
// of 'operand' with 'update' and 'start_indices'.
static std::unique_ptr<HloInstruction> CreateDynamicUpdateSlice(
const Shape& shape, HloInstruction* operand, HloInstruction* update,
HloInstruction* start_indices);
// Creates a concatenate instruction, where the operands are concatenated on
// the provided dimension.
static std::unique_ptr<HloInstruction> CreateConcatenate(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
int64 dimension);
// Creates a reduce instruction, where the computation (given by the handle)
// is applied successively to every element in operand. That is, if f is the
// function to apply (which either takes 2 [accumulator, value] or 3
// [accumulator, index, value] arguments) and init is a reduction operator
// specified initial value (for example, 0 for addition), then this operation
// will compute:
// f(...f(f(init, [index0], value0), [index1], value1)..., [indexN], valueN)
static std::unique_ptr<HloInstruction> CreateReduce(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
tensorflow::gtl::ArraySlice<int64> dimensions_to_reduce,
HloComputation* reduce_computation);
// Creates a reduce-window instruction, where the computation (given
// by the handle) is applied window-wise at each valid window
// position in the operand.
static std::unique_ptr<HloInstruction> CreateReduceWindow(
const Shape& shape, HloInstruction* operand, HloInstruction* init_value,
const Window& window, HloComputation* reduce_computation);
// Creates a batch-norm-training instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormTraining(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, float epsilon, int64 feature_index);
// Creates a batch-norm-inference instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormInference(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* offset, HloInstruction* mean, HloInstruction* variance,
float epsilon, int64 feature_index);
// Creates a batch-norm-grad instruction.
static std::unique_ptr<HloInstruction> CreateBatchNormGrad(
const Shape& shape, HloInstruction* operand, HloInstruction* scale,
HloInstruction* mean, HloInstruction* variance,
HloInstruction* grad_output, float epsilon, int64 feature_index);
// Creates a scatter computation that scatters the `source` array to the
// selected indices of each window.
static std::unique_ptr<HloInstruction> CreateSelectAndScatter(
const Shape& shape, HloInstruction* operand, HloComputation* select,
const Window& window, HloInstruction* source, HloInstruction* init_value,
HloComputation* scatter);
// Creates a broadcast instruction.
static std::unique_ptr<HloInstruction> CreateBroadcast(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> broadcast_dimensions);
// Creates a sequence of instructions that performs an explicit broadcast of
// the operand to the target shape.
//
// Interior HLOs are passed to "adder", but the "root" HLO of the sequence is
// returned as a unique_ptr for API consistency with other factory methods in
// this interface.
//
// TODO(b/72173833) Ideally HloComputations would always be present, and so
// the adder being passed by the caller would not be necessary.
static std::unique_ptr<HloInstruction> CreateBroadcastSequence(
const Shape& output_shape, HloInstruction* operand,
const std::function<HloInstruction*(std::unique_ptr<HloInstruction>)>&
adder);
// Creates a pad instruction, where the operand is padded on the edges and
// between the elements with the given padding value.
static std::unique_ptr<HloInstruction> CreatePad(
const Shape& shape, HloInstruction* operand,
HloInstruction* padding_value, const PaddingConfig& padding_config);
// Creates a reshape instruction, where the operand is flattened row-major
// order and then reshaped to the given result shape.
static std::unique_ptr<HloInstruction> CreateReshape(const Shape& shape,
HloInstruction* operand);
// Creates a transpose instruction which permutes the operand dimensions.
static std::unique_ptr<HloInstruction> CreateTranspose(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
// Creates a while instruction, given a condition computation, a body
// computation, and the initial value for the input of the computations. For
// example, shape: S32, condition: i -> i < 1000, body: i -> i * 2, init: 1
// corresponds to the C code below.
// int32 i = 1; while (i < 1000) { i = i * 2; } int32 result = i;
static std::unique_ptr<HloInstruction> CreateWhile(const Shape& shape,
HloComputation* condition,
HloComputation* body,
HloInstruction* init);
static std::unique_ptr<HloInstruction> CreateConditional(
const Shape& shape, HloInstruction* pred,
HloInstruction* true_computation_arg, HloComputation* true_computation,
HloInstruction* false_computation_arg, HloComputation* false_computation);
static std::unique_ptr<HloInstruction> CreateGather(
const Shape& shape, HloInstruction* operand,
HloInstruction* gather_indices,
const GatherDimensionNumbers& gather_dim_numbers,
tensorflow::gtl::ArraySlice<int64> window_bounds);
// Creates a kDomain instruction which delimits an HLO domain that has
// the provided user-side and operand-side metadata.
static std::unique_ptr<HloInstruction> CreateDomain(
const Shape& shape, HloInstruction* operand,
std::unique_ptr<DomainMetadata> operand_side_metadata,
std::unique_ptr<DomainMetadata> user_side_metadata);
// Creates a fusion instruction. A fusion instruction contains one or more
// fused instructions forming an expression with a single root
// "fused_root". Additional instructions can be added to the fusion
// instruction with the method FuseInstruction.
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind, HloInstruction* fused_root);
static std::unique_ptr<HloInstruction> CreateFusion(
const Shape& shape, FusionKind fusion_kind,
tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* fusion_computation);
// Creates a call instruction that applies the given computation on the given
// operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
HloComputation* computation);
// Creates a custom call instruction that applies the given custom call target
// to the given operands. "shape" is the resultant shape.
static std::unique_ptr<HloInstruction> CreateCustomCall(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece custom_call_target);
// Creates a HostCompute instruction, which records host-side control and
// data dependencies for use in instruction scheduling.
static std::unique_ptr<HloInstruction> CreateHostCompute(
const Shape& shape, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
tensorflow::StringPiece channel_name, const int64 cost_estimate_ns);
// Creates a tuple instruction with the given elements. This is a convenience
// wrapper around CreateVariadic.
static std::unique_ptr<HloInstruction> CreateTuple(
tensorflow::gtl::ArraySlice<HloInstruction*> elements);
// Creates a reverse instruction, which reverses the order of the elements
// in the specified dimensions.
static std::unique_ptr<HloInstruction> CreateReverse(
const Shape& shape, HloInstruction* operand,
tensorflow::gtl::ArraySlice<int64> dimensions);
// Creates an instance of GatherDimensionNumbers.
static GatherDimensionNumbers MakeGatherDimNumbers(
tensorflow::gtl::ArraySlice<int64> output_window_dims,
tensorflow::gtl::ArraySlice<int64> elided_window_dims,
tensorflow::gtl::ArraySlice<int64> gather_dims_to_operand_dims,
int64 index_vector_dim);
// Returns the opcode for this instruction.
HloOpcode opcode() const { return opcode_; }
// Returns true if this instruction has a side effect, irrespective of whether
// any called computations may contain an instruction with side effects.
bool HasSideEffectNoRecurse() const;
// Returns true if this instruction has a side effect. An instruction has a
// side effect if it uses certain opcodes or calls a computation with a side
// effect. Unlike HasSideEffectNoRecurse(), called computations are taken into
// account.
bool HasSideEffect() const;
// Returns the result shape of this instruction.
const Shape& shape() const;
// Returns a mutable pointer to the result shape of this instruction. The
// pointer refers to this instruction's own shape_ member, so writes through
// it change this instruction's shape in place.
Shape* mutable_shape() { return &shape_; }
// Returns the ith operand to this instruction (read-only).
const HloInstruction* operand(int64 i) const;
// Returns the ith operand to this instruction as a mutable pointer.
HloInstruction* mutable_operand(int64 i);
// Returns the number of operands to this instruction.
int64 operand_count() const { return operands_.size(); }
// Container type used for operand lists: an InlinedVector with an inline
// capacity of two HloInstruction pointers.
using InstructionVector = tensorflow::gtl::InlinedVector<HloInstruction*, 2>;
// Returns the vector of operands of this instruction. This is a reference to
// the internal operands_ member, not a copy.
const InstructionVector& operands() const { return operands_; }
// Returns the vector of unique operands, in the same order they are found
// within the operand vector.
InstructionVector unique_operands() const;
// Returns the index of 'target' in the operands sequence.
// Precondition: target must be an operand (or a fatal error will occur).
int64 operand_index(const HloInstruction* target) const;
// Returns the number of users of this instruction (the size of users()).
int64 user_count() const { return users_.size(); }
// Returns the users of this instruction. This is a reference to the internal
// users_ member, not a copy.
const std::vector<HloInstruction*>& users() const { return users_; }
// Returns true if this instruction is a user of 'instruction', i.e. if `this`
// is present in `instruction`'s user set. `instruction` is dereferenced, so it
// must not be null.
bool IsUserOf(const HloInstruction* instruction) const {
return ContainsKey(instruction->user_set_, this);
}
// Adds a control dependency from this instruction to the given
// instruction. This instruction becomes a control predecessor of
// 'instruction', and 'instruction' becomes a control successor of this
// instruction. Returns an error status if this instruction and 'instruction'
// do not belong to the same computation.
//
// This is used to enforce an additional ordering requirement that is not
// captured by normal data dependencies, such as ordering among Send or Recv
// operations to avoid deadlock.
Status AddControlDependencyTo(HloInstruction* instruction);
// Removes a previously added control dependency from this instruction to
// 'instruction'.
Status RemoveControlDependencyTo(HloInstruction* instruction);
// Drops all control predecessors and successors from this HLO instruction.
Status DropAllControlDeps();
// Copies the control predecessors and successors of `inst` onto this HLO
// instruction; `inst` is the source and is not modified. Does not do a deep
// copy, so this makes sense only if `inst` and this HLO are in the same
// module.
//
// Depending on the use cases we see in practice, in the future we may
// consider folding the logic here into Clone, CloneWithNewOperands and
// ReplaceAllUsesWith by treating control dependencies like data dependencies.
Status CopyAllControlDepsFrom(const HloInstruction* inst);
// Returns the control predecessors (successors) of this instruction. Control
// predecessors (successors) must execute before (after) the current
// instruction. Both accessors return references to internal vectors, not
// copies.
const std::vector<HloInstruction*>& control_predecessors() const {
return control_predecessors_;
}
const std::vector<HloInstruction*>& control_successors() const {
return control_successors_;
}
// Returns true if "other" performs the same computation as this instruction.
//
// The checks run from cheapest to most expensive and stop at the first
// mismatch:
//   1. `other` is this very instruction               -> true
//   2. the opcodes differ                             -> false
//   3. the shapes differ (ShapeUtil::Equal when `layout_sensitive`,
//      otherwise ShapeUtil::Compatible)               -> false
//   4. the operand counts differ, or `eq_operands` rejects any pair of
//      operands compared position by position         -> false
//   5. the backend configs differ                     -> false
//   6. otherwise, the result of IdenticalSlowPath(other, eq_computations).
//
// With the default arguments, operands and computations are compared by
// pointer identity.
bool Identical(
    const HloInstruction& other,
    const std::function<bool(const HloInstruction*, const HloInstruction*)>&
        eq_operands = std::equal_to<const HloInstruction*>(),
    const std::function<bool(const HloComputation*, const HloComputation*)>&
        eq_computations = std::equal_to<const HloComputation*>(),
    bool layout_sensitive = true) const {
  // Every instruction is identical to itself; skip all further work.
  if (&other == this) {
    return true;
  }
  // Identical instructions share the same opcode...
  if (other.opcode() != opcode()) {
    return false;
  }
  // ...the same shape (layout only matters when layout_sensitive)...
  const bool same_shape =
      layout_sensitive ? ShapeUtil::Equal(shape(), other.shape())
                       : ShapeUtil::Compatible(shape(), other.shape());
  if (!same_shape) {
    return false;
  }
  // ...and pairwise-equivalent operands. An index loop is used instead of
  // ContainerEquals so that the std::function is never copied.
  const size_t num_operands = operands().size();
  if (other.operands().size() != num_operands) {
    return false;
  }
  for (size_t pos = 0; pos < num_operands; ++pos) {
    if (!eq_operands(operand(pos), other.operand(pos))) {
      return false;
    }
  }
  // Backend-specific configuration must match as well.
  if (backend_config_ != other.backend_config_) {
    return false;
  }
  // Everything else is delegated to the slow path, which is handed
  // eq_computations.
  return IdenticalSlowPath(other, eq_computations);
}
// Returns whether the instruction has a constant operand.
bool HasConstantOperand() const;
// Returns whether this instruction does a rank-2 transposition.
bool IsRank2Transpose() const;
// Replaces the use of this instruction in "user" with "new_producer". Note
// that there might be multiple uses of this instruction in "user"; all will
// be replaced.
Status ReplaceUseWith(HloInstruction* user, HloInstruction* new_producer);
// Replaces the operand at index `operand_no` with `new_operand`.
Status ReplaceOperandWith(int64 operand_no, HloInstruction* new_operand);
// Replaces all uses of this instruction with the new producer. If
// new_producer is a user of this instruction, then new_producer remains a
// user of this instruction to avoid introducing cycles into the graph.
//
// If this instruction is the root of its computation, sets the computation's
// root to new_producer.
Status ReplaceAllUsesWith(HloInstruction* new_producer);
// Detaches this instruction from its operands. That is, removes this
// instruction from each operand's user set. This should only be called prior
// to deallocating the instruction.
//
// TODO(b/78305363): Make this automatic when deleting an instruction.
void DetachFromOperands();
// Performs a postorder DFS visit using this node as the root. If
// call_finish_visit is true, then DfsHloVisitor::FinishVisit is called when
// complete. If ignore_control_predecessors is true, instructions only
// reachable via control dependencies will not be visited, and the postorder
// will not take control dependencies into account. It is as if the control
// dependencies didn't exist in the graph at all.
template <typename HloInstructionPtr>
Status Accept(DfsHloVisitorBase<HloInstructionPtr>* visitor,
bool call_finish_visit = true,
bool ignore_control_predecessors = false);
// Const overload for ConstDfsHloVisitor. It forwards to the templated Accept
// above through a const_cast of `this`. NOTE(review): this relies on a
// ConstDfsHloVisitor only being handed const instruction pointers; confirm
// against its definition.
Status Accept(ConstDfsHloVisitor* visitor, bool call_finish_visit = true,
bool ignore_control_predecessors = false) const {
return const_cast<HloInstruction*>(this)->Accept(
visitor, call_finish_visit, ignore_control_predecessors);
}
// Same as Accept() above, but the order of operand and control predecessor
// visitation is determined by the given `operand_order` comparator; if
// operand_order(A, B) == true, A is visited before B.
using CompareFunction =
std::function<bool(const HloInstruction*, const HloInstruction*)>;
Status AcceptWithOperandOrder(DfsHloVisitor* visitor,
const CompareFunction& operand_order,
bool call_finish_visit = true);
// Performs a postorder DFS visit using this node as the root. Calls the given
// visitor function at each instruction. The second overload is the read-only
// variant that passes const instruction pointers.
Status Accept(const std::function<Status(HloInstruction*)>& visitor_func);
Status Accept(
const std::function<Status(const HloInstruction*)>& visitor_func) const;
// Visits all instructions rooted at this instruction using the given visitor
// in the given order. 'order' must contain at least the set of instructions
// rooted at this node (i.e., those accessible from a DFS traversal from this
// instruction). Instructions contained in 'order' which are not in the set of
// instructions rooted at this node are ignored. 'order' must also be a valid
// topological sort of these instructions (defs appear before uses), though it
// need not be a DFS post-order.
Status AcceptOrdered(DfsHloVisitor* visitor,
const std::vector<const HloInstruction*>& order);
// Visits this instruction and only this instruction with the given visitor.
template <typename HloInstructionPtr>
Status Visit(DfsHloVisitorBase<HloInstructionPtr>* visitor);
// Returns the literal associated with this instruction.
//
// Note: only constant and parameter opcodes have an associated literal.
// HasLiteral() reports whether one is present. NOTE(review): behavior when no
// literal is present is not visible here; confirm before calling
// unconditionally.
const Literal& literal() const;
// Returns whether there is a literal associated with this instruction.
bool HasLiteral() const;
// Returns the parameter number associated with this instruction.
//
// Note: only parameter opcodes have an associated parameter number. Calling
// this on any other opcode fails the CHECK_EQ below, which is fatal.
int64 parameter_number() const {
CHECK_EQ(HloOpcode::kParameter, opcode_);
return parameter_number_;
}
// Returns the dimension sizes or numbers associated with this instruction.
// The int64 overload presumably returns the element at position `index` of
// that vector -- TODO confirm.
//
// Precondition: opcode() is one of: concatenate, reduce, broadcast, reshape,
// and reverse.
const std::vector<int64>& dimensions() const;
int64 dimensions(int64 index) const;
// Returns the dimension along which a concatenate HLO concatenates its
// operands.
// Precondition: opcode() == HloOpcode::kConcatenate
int64 concatenate_dimension() const;
// Returns the tuple index associated with this instruction.
//
// Precondition: opcode() == HloOpcode::kGetTupleElement
int64 tuple_index() const;
// Returns the first non-GetTupleElement ancestor instruction of 'hlo'.
// If the first non-GTE ancestor is tuple-shaped, populates 'index' with the
// (possibly nested) tuple indices used on the path from ancestor to 'hlo'.
std::pair<const HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex()
const;
std::pair<HloInstruction*, ShapeIndex> LatestNonGteAncestorAndIndex() {
auto rv =
const_cast<const HloInstruction*>(this)->LatestNonGteAncestorAndIndex();
return {const_cast<HloInstruction*>(rv.first), rv.second};
}
// Same as LatestNonGteAncestorAndIndex, but just returns the HloInstruction.
const HloInstruction* LatestNonGteAncestor() const;
HloInstruction* LatestNonGteAncestor() {
return const_cast<HloInstruction*>(
const_cast<const HloInstruction*>(this)->LatestNonGteAncestor());
}
// Gets/sets the to_apply HloComputation for Call, Map, Reduce, etc.
// The setter should only be called by HloModule or HloComputation methods.
// The computation is passed and returned as a raw pointer; NOTE(review):
// presumably it is owned by the enclosing module -- confirm.
//
// Precondition: The instruction has a valid to_apply_ field.
HloComputation* to_apply() const;
void set_to_apply(HloComputation* to_apply);
// Returns the custom_call_target for CustomCall.
// Precondition: opcode() == HloOpcode::kCustomCall
const string& custom_call_target() const;
// Returns the config string for the Outfeed instruction.
// Precondition: opcode() == HloOpcode::kOutfeed
const string& outfeed_config() const;
// Returns the shape for the Outfeed instruction. NOTE(review): this is a
// separate accessor from shape(); presumably it describes the outfed data
// rather than the instruction's result -- confirm.
// Precondition: opcode() == HloOpcode::kOutfeed
const Shape& outfeed_shape() const;
// Gets/sets the while_condition or while_body HloComputation for While. The
// setters should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a While instruction.
HloComputation* while_condition() const;
HloComputation* while_body() const;
void set_while_condition(HloComputation* while_condition);
void set_while_body(HloComputation* while_body);
// Gets/sets the select or scatter HloComputation for SelectAndScatter. The
// setters should only be called by HloModule or HloComputation methods.
//
// Precondition: opcode() == HloOpcode::kSelectAndScatter.
HloComputation* select() const;
HloComputation* scatter() const;
void set_select(HloComputation* select);
void set_scatter(HloComputation* scatter);
// Gets/sets the true and false HloComputations for Conditional. The setters
// should only be called by HloModule or HloComputation methods.
//
// Precondition: The instruction is a Conditional instruction.
HloComputation* true_computation() const;
HloComputation* false_computation() const;
void set_true_computation(HloComputation* true_computation);
void set_false_computation(HloComputation* false_computation);
// Returns a string for the signature of this instruction if considered as a
// function, e.g. the signature of an F32 add is (F32, F32) -> F32.
string SignatureString() const;
// Returns a debugging string that represents this instruction. The
// no-argument overload is equivalent to ToString(HloPrintOptions()).
//
// (We express the default options using an overload rather than a default
// param because gdb ignores default params, but does resolve overloads.)
//
// TODO(b/73348663): Make ToString() adaptive to the size of the string by
// default, backing off on providing full information for very large strings,
// or provide a different name for a ToString-like function that does that.
string ToString() const { return ToString(HloPrintOptions()); }
string ToString(const HloPrintOptions& options) const;
// Components of the ToString() representation:
// Returns a string representation of the operand list.
string OperandsToString(const HloPrintOptions& options) const;
// Returns string representations of the op-specific attributes, presumably
// one entry per attribute -- TODO confirm.
std::vector<string> ExtraAttributesToString(
const HloPrintOptions& options) const;
// As ToString, but returns a shorter string.
string ToShortString() const;
// Returns a serialized (protobuf) representation of this instruction.
HloInstructionProto ToProto() const;
// Returns a category for the HLO. This could be something like "convolution"
// or "elementwise".
string ToCategory() const;
// Returns the logging instruction if the output of this instruction is
// logged, and nullptr otherwise.
//
// Postcondition: retval == nullptr || retval->opcode() == HloOpcode::kTrace
HloInstruction* tracing() const;
// Sets the logging instruction returned by tracing(). NOTE(review):
// presumably `trace_instruction` should satisfy the kTrace postcondition of
// tracing() -- confirm.
void set_tracing(HloInstruction* trace_instruction);
// Returns the channel id associated with the instruction. The id is
// shared between each Send/Recv pair and is globally unique to identify each
// channel.