/
xla_builder.cc
1974 lines (1709 loc) · 75.1 KB
/
xla_builder.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2018 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/client/xla_client/xla_builder.h"
#include <functional>
#include <numeric>
#include <queue>
#include <string>
#include <utility>
#include "tensorflow/compiler/xla/execution_options_util.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/shape_inference.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/core/lib/gtl/flatset.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/platform/mutex.h"
namespace xla {
using tensorflow::strings::StrCat;
namespace {

// Returns a fresh id from a process-wide counter. Each call yields the next
// value; the counter is guarded by a function-local mutex so concurrent
// callers never receive the same id.
int64 GetUniqueId() {
  static tensorflow::mutex counter_mu(tensorflow::LINKER_INITIALIZED);
  static int64 next_id = 0;
  tensorflow::mutex_lock guard(counter_mu);
  return next_id++;
}

// Returns true if an instruction with the given opcode can be the root of the
// computation. kSend, kSendDone, kOutfeed and kTrace are rejected; every other
// opcode is accepted.
bool CanBeRoot(HloOpcode opcode) {
  const bool excluded = opcode == HloOpcode::kSend ||
                        opcode == HloOpcode::kSendDone ||
                        opcode == HloOpcode::kOutfeed ||
                        opcode == HloOpcode::kTrace;
  return !excluded;
}

}  // namespace
// Returns the shape of the instruction referenced by `op`. If an error was
// already noted on this builder, that error is returned instead.
StatusOr<Shape> XlaBuilder::GetShape(const XlaOp& op) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(auto found, LookUpInstruction(op));
  const Shape& result = found->shape();
  return result;
}
// Returns the shapes of `operands`, in order. Fails with the first error
// produced by GetShape (including any previously noted builder error).
StatusOr<std::vector<Shape>> XlaBuilder::GetOperandShapes(
    tensorflow::gtl::ArraySlice<XlaOp> operands) const {
  std::vector<Shape> operand_shapes;
  // The result size is known up front; avoid repeated reallocation.
  operand_shapes.reserve(operands.size());
  for (const XlaOp& operand : operands) {
    TF_ASSIGN_OR_RETURN(Shape shape, GetShape(operand));
    // Move the freshly obtained shape rather than copying it.
    operand_shapes.push_back(std::move(shape));
  }
  return operand_shapes;
}
// Creates a builder whose computation name is `computation_name`. Build()
// appends a unique id to this name when producing the computation.
XlaBuilder::XlaBuilder(const string& computation_name)
    : name_(computation_name) {}

XlaBuilder::~XlaBuilder() = default;
void XlaBuilder::NoteError(const Status& error) {
CHECK(!error.ok());
if (die_immediately_on_error_) {
LOG(FATAL) << "error building computation: " << error;
}
if (first_error_.ok()) {
first_error_ = error;
first_error_backtrace_.CreateCurrent(/*skip_count=*/1);
}
}
// Runs `op_creator` unless the builder already holds an error. On success the
// created op is returned. On failure the error is recorded via NoteError and
// an empty XlaOp is returned. A builder already in the error state returns an
// empty XlaOp without invoking `op_creator`.
XlaOp XlaBuilder::NoteErrorOrReturn(
    const std::function<StatusOr<XlaOp>()>& op_creator) {
  if (first_error_.ok()) {
    StatusOr<XlaOp> created = op_creator();
    if (created.ok()) {
      return created.ConsumeValueOrDie();
    }
    NoteError(created.status());
  }
  return {};
}
// Computes the program shape of the computation built so far.
//
// The root is the most recently added instruction whose opcode passes
// CanBeRoot; its id is written to `*root_id` and its shape becomes the result
// shape. Each parameter's shape and name are placed at its parameter number.
// Fails if a builder error was noted, `root_id` is null, no valid root exists,
// or a parameter number falls outside [0, number of registered parameters).
StatusOr<ProgramShape> XlaBuilder::GetProgramShape(int64* root_id) const {
  TF_RETURN_IF_ERROR(first_error_);
  TF_RET_CHECK(root_id != nullptr);
  ProgramShape program_shape;

  // Not all instructions can be roots. Walk backwards from the last added
  // instruction until a valid root is found. Convert the size to a signed
  // value first so an empty builder starts at -1 explicitly rather than via
  // unsigned wrap-around.
  int64 root_index = static_cast<int64>(instructions_.size()) - 1;
  for (; root_index >= 0; --root_index) {
    TF_ASSIGN_OR_RETURN(HloOpcode opcode,
                        StringToHloOpcode(instructions_[root_index].opcode()));
    if (CanBeRoot(opcode)) {
      break;
    }
  }
  if (root_index < 0) {
    return FailedPrecondition("no root instruction was found");
  }
  *root_id = instructions_[root_index].id();
  *program_shape.mutable_result() = instructions_[root_index].shape();

  // Check that the parameter numbers are continuous from 0, and add parameter
  // shapes and names to the program shape.
  const int64 param_count = static_cast<int64>(parameter_numbers_.size());
  for (int64 i = 0; i < param_count; i++) {
    program_shape.add_parameters();
    program_shape.add_parameter_names();
  }
  // Loop-invariant: compute the parameter opcode string once.
  const string parameter_opcode = HloOpcodeString(HloOpcode::kParameter);
  for (const HloInstructionProto& instr : instructions_) {
    // Parameter number uniqueness is guaranteed in XlaBuilder::Parameter(). So
    // to verify continuity, we just need to verify that every parameter is in
    // the right range.
    if (instr.opcode() == parameter_opcode) {
      const int64 param_no = instr.parameter_number();
      TF_RET_CHECK(param_no >= 0 && param_no < param_count)
          << "invalid parameter number: " << param_no;
      *program_shape.mutable_parameters(param_no) = instr.shape();
      *program_shape.mutable_parameter_names(param_no) = instr.name();
    }
  }
  return program_shape;
}
// Convenience overload of GetProgramShape(int64*) for callers that do not
// need the root instruction id.
StatusOr<ProgramShape> XlaBuilder::GetProgramShape() const {
  int64 ignored_root_id;
  return GetProgramShape(&ignored_root_id);
}
// Depth-first walk from `op_handle` over operand edges. It clears
// `*is_constant` when it reaches an instruction whose opcode is treated as
// non-constant (listed in the switch below). Already-visited handles are
// skipped, and so is all further work once `*is_constant` is false.
void XlaBuilder::IsConstantVisitor(const int64 op_handle,
                                   std::set<int64>* visited,
                                   bool* is_constant) const {
  if (visited->count(op_handle) != 0 || !*is_constant) {
    return;
  }
  // Check the lower bound first and compare in the signed domain so an int64
  // handle is never implicitly converted to the unsigned size type.
  CHECK(op_handle >= 0 &&
        op_handle < static_cast<int64>(instructions_.size()));
  const HloInstructionProto& instr = instructions_[op_handle];
  const HloOpcode opcode = StringToHloOpcode(instr.opcode()).ValueOrDie();
  switch (opcode) {
    default:
      for (const int64 operand_id : instr.operand_ids()) {
        IsConstantVisitor(operand_id, visited, is_constant);
      }
      // TODO(b/32495713): We aren't checking the called computations.
      break;

    // Non functional ops.
    case HloOpcode::kRng:
    case HloOpcode::kCrossReplicaSum:
      // TODO(b/33009255): Implement constant folding for cross replica sum.
    case HloOpcode::kInfeed:
    case HloOpcode::kOutfeed:
    case HloOpcode::kHostCompute:
    case HloOpcode::kCall:
      // TODO(b/32495713): We aren't checking the to_apply computation itself,
      // so we conservatively say that computations containing the Call op
      // cannot be constant. We cannot set is_functional=false in other similar
      // cases since we're already relying on IsConstant to return true.
    case HloOpcode::kCustomCall:
    case HloOpcode::kWhile:
      // TODO(b/32495713): We aren't checking the condition and body
      // computations themselves.
    case HloOpcode::kSend:
    case HloOpcode::kRecv:
    case HloOpcode::kParameter:
      *is_constant = false;
      break;
  }
  if (!*is_constant) {
    VLOG(1) << "Non-constant: " << instr.name();
  }
  visited->insert(op_handle);
}
// Builds this (sub-)builder's computation. On failure the error, annotated
// with this builder's name, is recorded on parent_builder_ and an empty
// computation is returned.
XlaComputation XlaBuilder::BuildAndNoteError() {
  DCHECK(parent_builder_ != nullptr);
  StatusOr<XlaComputation> built = Build();
  if (built.ok()) {
    return built.ConsumeValueOrDie();
  }
  parent_builder_->NoteError(
      AddStatus(built.status(), StrCat("error from: ", name_)));
  return {};
}
// Finalizes the builder into an XlaComputation.
//
// If an error was noted earlier, that error is returned with the backtrace
// captured when it was recorded appended. Otherwise, all instructions added
// so far are moved into a new entry computation with a globally unique id.
// That entry is packaged, together with the embedded computations, into the
// returned module. Afterwards the builder's instructions, embedded
// computations and parameter numbers are cleared.
StatusOr<XlaComputation> XlaBuilder::Build() {
  if (!first_error_.ok()) {
    string backtrace;
    first_error_backtrace_.Dump(tensorflow::DebugWriteToString, &backtrace);
    return AppendStatus(first_error_, backtrace);
  }

  HloComputationProto entry;
  // Give the computation a global unique id and fold it into the name so the
  // name is unique as well.
  entry.set_id(GetUniqueId());
  entry.set_name(StrCat(name_, entry.id()));

  // The program shape must be computed while instructions_ still holds the
  // instructions, i.e. before they are swapped into `entry` below.
  {
    int64 root_id;
    TF_ASSIGN_OR_RETURN(*entry.mutable_program_shape(),
                        GetProgramShape(&root_id));
    entry.set_root_id(root_id);
  }

  for (HloInstructionProto& instruction : instructions_) {
    // Suffix the computation id and instruction id so that instruction names
    // are unique among the whole graph.
    instruction.set_name(
        StrCat(instruction.name(), ".", entry.id(), ".", instruction.id()));
    entry.add_instructions()->Swap(&instruction);
  }

  XlaComputation computation(entry.id());
  HloModuleProto* module = computation.mutable_proto();
  module->set_name(entry.name());
  module->set_id(entry.id());
  module->set_entry_computation_name(entry.name());
  module->set_entry_computation_id(entry.id());
  *module->mutable_program_shape() = entry.program_shape();
  for (auto& embedded : embedded_) {
    module->add_computations()->Swap(&embedded.second);
  }
  // The entry computation is appended after all embedded computations.
  module->add_computations()->Swap(&entry);

  // Clear data held by this builder.
  instructions_.clear();
  embedded_.clear();
  parameter_numbers_.clear();
  return std::move(computation);
}
// Adds a kBroadcast of `operand` producing `shape`. `broadcast_dimensions` is
// stored as the instruction's dimensions field.
StatusOr<XlaOp> XlaBuilder::InDimBroadcast(
    const Shape& shape, const XlaOp& operand,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  TF_RETURN_IF_ERROR(first_error_);
  HloInstructionProto broadcast;
  *broadcast.mutable_shape() = shape;
  for (auto it = broadcast_dimensions.begin(); it != broadcast_dimensions.end();
       ++it) {
    broadcast.add_dimensions(*it);
  }
  return AddInstruction(std::move(broadcast), HloOpcode::kBroadcast,
                        {operand});
}
// Makes an implicit broadcast of `operand` to `output_shape` explicit.
//
// A scalar operand is broadcast directly. Otherwise the operand must have the
// same rank as `output_shape`, and each dimension must either match the output
// or be 1. The size-1 dimensions are first removed with a reshape, and the
// result is then broadcast back up to the output dimensions. The element type
// of the operand is preserved.
StatusOr<XlaOp> XlaBuilder::AddBroadcastSequence(const Shape& output_shape,
                                                 const XlaOp& operand) {
  TF_RETURN_IF_ERROR(first_error_);
  TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
  // A malformed request is reported as an error status rather than aborting
  // the process; this function already returns StatusOr.
  TF_RET_CHECK(ShapeUtil::IsScalar(operand_shape) ||
               ShapeUtil::Rank(operand_shape) == ShapeUtil::Rank(output_shape))
      << "operand shape: " << operand_shape
      << "; output_shape: " << output_shape;
  Shape broadcast_shape =
      ShapeUtil::ChangeElementType(output_shape, operand_shape.element_type());

  // Do explicit broadcast for scalar.
  if (ShapeUtil::IsScalar(operand_shape)) {
    return InDimBroadcast(broadcast_shape, operand, {});
  }

  // Do explicit broadcast for degenerate broadcast.
  std::vector<int64> broadcast_dimensions;
  std::vector<int64> reshaped_dimensions;
  const int64 operand_rank = ShapeUtil::Rank(operand_shape);
  for (int64 i = 0; i < operand_rank; i++) {
    if (operand_shape.dimensions(i) == output_shape.dimensions(i)) {
      broadcast_dimensions.push_back(i);
      reshaped_dimensions.push_back(operand_shape.dimensions(i));
    } else {
      TF_RET_CHECK(operand_shape.dimensions(i) == 1)
          << "An explicit broadcast sequence requires the broadcasted "
             "dimensions to be trivial; operand shape: "
          << operand_shape << "; output_shape: " << output_shape;
    }
  }
  // Eliminate the size one dimensions.
  TF_ASSIGN_OR_RETURN(XlaOp reshaped_operand,
                      Reshape(ShapeUtil::MakeShape(operand_shape.element_type(),
                                                   reshaped_dimensions),
                              operand));
  // Broadcast 'reshape' up to the larger size.
  return InDimBroadcast(broadcast_shape, reshaped_operand,
                        broadcast_dimensions);
}
// Adds an element-wise unary instruction `unop` applied to `operand`. The
// result shape comes from ShapeInference::InferUnaryOpShape.
XlaOp XlaBuilder::UnaryOp(HloOpcode unop, const XlaOp& operand) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto unary;
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(*unary.mutable_shape(),
                        ShapeInference::InferUnaryOpShape(unop, input_shape));
    return AddInstruction(std::move(unary), unop, {operand});
  });
}
// Adds a binary instruction `binop` over `lhs` and `rhs`.
//
// When `broadcast_dimensions` is non-empty and the operand ranks differ, the
// lower-rank operand is first broadcast in-dimension into the result's rank.
// Afterwards, any operand whose dimensions still differ from the result shape
// is expanded with AddBroadcastSequence, so that the emitted instruction only
// sees operands of the result's dimensions.
XlaOp XlaBuilder::BinaryOp(
    HloOpcode binop, const XlaOp& lhs, const XlaOp& rhs,
    tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                        ShapeInference::InferBinaryOpShape(
                            binop, lhs_shape, rhs_shape, broadcast_dimensions));
    const Shape& result_shape = instr.shape();

    XlaOp new_lhs = lhs;
    XlaOp new_rhs = rhs;

    const int64 lhs_rank = ShapeUtil::Rank(lhs_shape);
    const int64 rhs_rank = ShapeUtil::Rank(rhs_shape);
    if (!broadcast_dimensions.empty() && lhs_rank != rhs_rank) {
      const bool lhs_is_smaller = lhs_rank < rhs_rank;
      const Shape& small_shape = lhs_is_smaller ? lhs_shape : rhs_shape;
      // Start from the result dimensions and overwrite the positions that the
      // smaller operand's dimensions map onto.
      std::vector<int64> target_dims(result_shape.dimensions().begin(),
                                     result_shape.dimensions().end());
      for (int64 dim = 0; dim < ShapeUtil::Rank(small_shape); ++dim) {
        target_dims[broadcast_dimensions[dim]] = small_shape.dimensions(dim);
      }
      const Shape target_shape =
          ShapeUtil::MakeShape(small_shape.element_type(), target_dims);
      TF_ASSIGN_OR_RETURN(
          XlaOp broadcasted,
          InDimBroadcast(target_shape, lhs_is_smaller ? lhs : rhs,
                         broadcast_dimensions));
      if (lhs_is_smaller) {
        new_lhs = broadcasted;
      } else {
        new_rhs = broadcasted;
      }
    }

    // Turn any remaining implicit broadcast into an explicit one, lhs first.
    for (XlaOp* side : {&new_lhs, &new_rhs}) {
      TF_ASSIGN_OR_RETURN(Shape side_shape, GetShape(*side));
      if (!ShapeUtil::SameDimensions(result_shape, side_shape)) {
        TF_ASSIGN_OR_RETURN(*side, AddBroadcastSequence(result_shape, *side));
      }
    }
    return AddInstruction(std::move(instr), binop, {new_lhs, new_rhs});
  });
}
// Adds a ternary instruction `triop` over `lhs`, `rhs` and `ehs`.
//
// When the inferred result is not a tuple, every non-tuple operand whose
// dimensions differ from the result (i.e. is implicitly broadcast) is
// replaced by an explicit AddBroadcastSequence.
XlaOp XlaBuilder::TernaryOp(HloOpcode triop, const XlaOp& lhs, const XlaOp& rhs,
                            const XlaOp& ehs) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_ASSIGN_OR_RETURN(const Shape& ehs_shape, GetShape(ehs));
    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                        ShapeInference::InferTernaryOpShape(
                            triop, lhs_shape, rhs_shape, ehs_shape));
    const Shape& result_shape = instr.shape();

    // Writes an explicit broadcast of `operand` into `*updated` when needed;
    // leaves `*updated` untouched otherwise.
    auto make_explicit = [&](const Shape& operand_shape, const XlaOp& operand,
                             XlaOp* updated) -> Status {
      if (ShapeUtil::IsTuple(result_shape) ||
          ShapeUtil::IsTuple(operand_shape) ||
          ShapeUtil::SameDimensions(result_shape, operand_shape)) {
        return Status::OK();
      }
      TF_ASSIGN_OR_RETURN(*updated,
                          AddBroadcastSequence(result_shape, operand));
      return Status::OK();
    };

    XlaOp updated_lhs = lhs;
    XlaOp updated_rhs = rhs;
    XlaOp updated_ehs = ehs;
    TF_RETURN_IF_ERROR(make_explicit(lhs_shape, lhs, &updated_lhs));
    TF_RETURN_IF_ERROR(make_explicit(rhs_shape, rhs, &updated_rhs));
    TF_RETURN_IF_ERROR(make_explicit(ehs_shape, ehs, &updated_ehs));
    return AddInstruction(std::move(instr), triop,
                          {updated_lhs, updated_rhs, updated_ehs});
  });
}
// Element-wise addition of `lhs` and `rhs` (see BinaryOp for broadcasting).
XlaOp XlaBuilder::Add(const XlaOp& lhs, const XlaOp& rhs,
                      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kAdd;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}
// Element-wise multiplication of `lhs` and `rhs` (see BinaryOp for
// broadcasting).
XlaOp XlaBuilder::Mul(const XlaOp& lhs, const XlaOp& rhs,
                      tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kMultiply;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}
// Adds a kConstant instruction holding `literal`; the instruction's shape is
// the literal's shape.
XlaOp XlaBuilder::ConstantLiteral(const LiteralSlice& literal) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto constant;
    *constant.mutable_shape() = literal.shape();
    *constant.mutable_literal() = literal.ToProto();
    return AddInstruction(std::move(constant), HloOpcode::kConstant);
  });
}
// Adds a kCall of `computation` with `operands`. The result shape is inferred
// from the operand shapes and the callee's program shape, and the callee is
// registered on the instruction via AddCalledComputation.
XlaOp XlaBuilder::Call(const XlaComputation& computation,
                       tensorflow::gtl::ArraySlice<XlaOp> operands) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto call;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    std::vector<const Shape*> operand_shape_ptrs;
    for (const Shape& operand_shape : operand_shapes) {
      operand_shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(const ProgramShape& callee_shape,
                        computation.GetProgramShape());
    TF_ASSIGN_OR_RETURN(
        *call.mutable_shape(),
        ShapeInference::InferCallShape(operand_shape_ptrs,
                                       /*to_apply=*/callee_shape));
    AddCalledComputation(computation, &call);
    return AddInstruction(std::move(call), HloOpcode::kCall, operands);
  });
}
// Adds a kParameter instruction with the given number, shape and name. A
// parameter number may be registered only once per builder; a repeat yields
// an InvalidArgument error.
XlaOp XlaBuilder::Parameter(int64 parameter_number, const Shape& shape,
                            const string& name) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto param;
    const bool newly_registered =
        parameter_numbers_.insert(parameter_number).second;
    if (!newly_registered) {
      return InvalidArgument("parameter %lld already registered",
                             parameter_number);
    }
    param.set_parameter_number(parameter_number);
    param.set_name(name);
    *param.mutable_shape() = shape;
    return AddInstruction(std::move(param), HloOpcode::kParameter);
  });
}
// Broadcasts `operand` by prepending `broadcast_sizes` as new leading
// dimensions.
//
// The HLO broadcast maps each operand dimension to an output dimension. Since
// the new dimensions are all on the left, operand dimension i maps to output
// dimension (output_rank - operand_rank) + i, i.e. the operand occupies the
// highest-numbered output dimensions.
XlaOp XlaBuilder::Broadcast(
    const XlaOp& operand, tensorflow::gtl::ArraySlice<int64> broadcast_sizes) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        const Shape& shape,
        ShapeInference::InferBroadcastShape(operand_shape, broadcast_sizes));
    const int64 operand_rank = ShapeUtil::Rank(operand_shape);
    const int64 first_mapped_dim = ShapeUtil::Rank(shape) - operand_rank;
    std::vector<int64> dimensions(operand_rank);
    std::iota(dimensions.begin(), dimensions.end(), first_mapped_dim);
    return InDimBroadcast(shape, operand, dimensions);
  });
}
// Adds a kReshape of `operand` to the explicitly given `shape`. No shape
// inference or validation is performed here.
StatusOr<XlaOp> XlaBuilder::Reshape(const Shape& shape, const XlaOp& operand) {
  TF_RETURN_IF_ERROR(first_error_);
  HloInstructionProto reshape;
  *reshape.mutable_shape() = shape;
  return AddInstruction(std::move(reshape), HloOpcode::kReshape, {operand});
}
// Adds a kSlice of `operand`. Each dimension records its start, limit and
// stride. The result shape comes from ShapeInference::InferSliceShape.
XlaOp XlaBuilder::Slice(const XlaOp& operand,
                        tensorflow::gtl::ArraySlice<int64> start_indices,
                        tensorflow::gtl::ArraySlice<int64> limit_indices,
                        tensorflow::gtl::ArraySlice<int64> strides) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto slice;
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(
        *slice.mutable_shape(),
        ShapeInference::InferSliceShape(operand_shape, start_indices,
                                        limit_indices, strides));
    // The three lists are walked in lockstep; presumably InferSliceShape has
    // already checked that their lengths agree.
    for (size_t dim = 0; dim < start_indices.size(); ++dim) {
      auto* dim_config = slice.add_slice_dimensions();
      dim_config->set_start(start_indices[dim]);
      dim_config->set_limit(limit_indices[dim]);
      dim_config->set_stride(strides[dim]);
    }
    return AddInstruction(std::move(slice), HloOpcode::kSlice, {operand});
  });
}
// Slices `operand` along the single dimension `dimno` using
// [start_index, limit_index) with `stride`. Every other dimension is kept
// whole (start 0, full limit, stride 1).
//
// Returns an InvalidArgument error if `dimno` is not a valid dimension of
// `operand`. Without this check the index vectors below would be written out
// of bounds.
XlaOp XlaBuilder::SliceInDim(const XlaOp& operand, int64 start_index,
                             int64 limit_index, int64 stride, int64 dimno) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& shape, GetShape(operand));
    const int64 rank = ShapeUtil::Rank(shape);
    if (dimno < 0 || dimno >= rank) {
      return InvalidArgument(
          "SliceInDim dimension %lld is out of range for operand of rank %lld",
          dimno, rank);
    }
    std::vector<int64> starts(rank, 0);
    std::vector<int64> limits(shape.dimensions().begin(),
                              shape.dimensions().end());
    std::vector<int64> strides(rank, 1);
    starts[dimno] = start_index;
    limits[dimno] = limit_index;
    strides[dimno] = stride;
    return Slice(operand, starts, limits, strides);
  });
}
// Adds a kDynamicSlice of `operand`, starting at the runtime `start_indices`
// and with the static `slice_sizes`, which are recorded on the instruction.
XlaOp XlaBuilder::DynamicSlice(const XlaOp& operand, const XlaOp& start_indices,
                               tensorflow::gtl::ArraySlice<int64> slice_sizes) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& indices_shape, GetShape(start_indices));
    HloInstructionProto dynamic_slice;
    TF_ASSIGN_OR_RETURN(*dynamic_slice.mutable_shape(),
                        ShapeInference::InferDynamicSliceShape(
                            operand_shape, indices_shape, slice_sizes));
    for (auto it = slice_sizes.begin(); it != slice_sizes.end(); ++it) {
      dynamic_slice.add_dynamic_slice_sizes(*it);
    }
    return AddInstruction(std::move(dynamic_slice), HloOpcode::kDynamicSlice,
                          {operand, start_indices});
  });
}
// Adds a kDynamicUpdateSlice that writes `update` into `operand` at the
// runtime `start_indices`. The result shape is inferred from all three
// operand shapes.
XlaOp XlaBuilder::DynamicUpdateSlice(const XlaOp& operand, const XlaOp& update,
                                     const XlaOp& start_indices) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& update_shape, GetShape(update));
    TF_ASSIGN_OR_RETURN(const Shape& indices_shape, GetShape(start_indices));
    HloInstructionProto dus;
    TF_ASSIGN_OR_RETURN(*dus.mutable_shape(),
                        ShapeInference::InferDynamicUpdateSliceShape(
                            operand_shape, update_shape, indices_shape));
    return AddInstruction(std::move(dus), HloOpcode::kDynamicUpdateSlice,
                          {operand, update, start_indices});
  });
}
// Adds a kConcatenate of `operands` along `dimension`. The dimension is stored
// as the instruction's single dimensions entry.
XlaOp XlaBuilder::ConcatInDim(tensorflow::gtl::ArraySlice<XlaOp> operands,
                              int64 dimension) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto concat;
    TF_ASSIGN_OR_RETURN(const auto& operand_shapes, GetOperandShapes(operands));
    std::vector<const Shape*> shape_ptrs;
    for (const Shape& operand_shape : operand_shapes) {
      shape_ptrs.push_back(&operand_shape);
    }
    TF_ASSIGN_OR_RETURN(
        *concat.mutable_shape(),
        ShapeInference::InferConcatOpShape(shape_ptrs, dimension));
    concat.add_dimensions(dimension);
    return AddInstruction(std::move(concat), HloOpcode::kConcatenate,
                          operands);
  });
}
// Adds a kPad of `operand`, filling with `padding_value` as described by
// `padding_config`, which is also copied onto the instruction.
XlaOp XlaBuilder::Pad(const XlaOp& operand, const XlaOp& padding_value,
                      const PaddingConfig& padding_config) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& input_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& fill_shape, GetShape(padding_value));
    HloInstructionProto pad;
    TF_ASSIGN_OR_RETURN(
        *pad.mutable_shape(),
        ShapeInference::InferPadShape(input_shape, fill_shape, padding_config));
    *pad.mutable_padding_config() = padding_config;
    return AddInstruction(std::move(pad), HloOpcode::kPad,
                          {operand, padding_value});
  });
}
// Reshapes `operand` to `new_sizes` after permuting its dimensions by
// `dimensions`. The transpose is skipped when `dimensions` is the identity
// permutation.
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          tensorflow::gtl::ArraySlice<int64> dimensions,
                          tensorflow::gtl::ArraySlice<int64> new_sizes) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
    TF_ASSIGN_OR_RETURN(const Shape& target_shape,
                        ShapeInference::InferReshapeShape(
                            operand_shape, dimensions, new_sizes));
    XlaOp source = operand;
    if (!IsIdentityPermutation(dimensions)) {
      source = Transpose(operand, dimensions);
    }
    return Reshape(target_shape, source);
  });
}
// Reshapes `operand` to `new_sizes` without permuting its dimensions (an
// identity permutation 0..rank-1 is passed on).
XlaOp XlaBuilder::Reshape(const XlaOp& operand,
                          tensorflow::gtl::ArraySlice<int64> new_sizes) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(auto operand_shape, GetShape(operand));
    std::vector<int64> identity_order;
    for (int64 d = 0; d < operand_shape.dimensions_size(); ++d) {
      identity_order.push_back(d);
    }
    return Reshape(operand, identity_order, new_sizes);
  });
}
// Collapses the consecutive, ascending dimensions listed in `dimensions` into
// a single dimension whose size is the product of their sizes.
//
// Zero or one listed dimension is a no-op and returns `operand` unchanged.
// Returns InvalidArgument if the dimensions are not consecutive and
// ascending, or if they fall outside [0, rank).
XlaOp XlaBuilder::Collapse(const XlaOp& operand,
                           tensorflow::gtl::ArraySlice<int64> dimensions) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    if (dimensions.size() <= 1) {
      // Not collapsing anything, trivially we can return the operand versus
      // enqueueing a trivial reshape.
      return operand;
    }

    // Out-of-order collapse is not supported.
    // Checks that the collapsed dimensions are in order and consecutive.
    for (tensorflow::gtl::ArraySlice<int64>::size_type i = 1;
         i < dimensions.size(); ++i) {
      if (dimensions[i] - 1 != dimensions[i - 1]) {
        return InvalidArgument(
            "Collapsed dimensions are not in consecutive order.");
      }
    }

    // Create a new sizes vector from the old shape, replacing the collapsed
    // dimensions by the product of their sizes.
    TF_ASSIGN_OR_RETURN(const Shape& original_shape, GetShape(operand));
    const int64 rank = ShapeUtil::Rank(original_shape);
    // After the check above, front() is the smallest and back() the largest
    // collapsed dimension. A negative front() would make the loop below call
    // back() on an empty vector, and an out-of-range back() would silently
    // collapse nothing, so reject both.
    if (dimensions.front() < 0 || dimensions.back() >= rank) {
      return InvalidArgument(
          "Collapsed dimensions [%lld, %lld] are out of range for operand of "
          "rank %lld.",
          dimensions.front(), dimensions.back(), rank);
    }

    VLOG(3) << "original shape: " << ShapeUtil::HumanString(original_shape);
    VLOG(3) << "dims to collapse: "
            << tensorflow::str_util::Join(dimensions, ",");

    std::vector<int64> new_sizes;
    for (int64 i = 0; i < rank; ++i) {
      if (i <= dimensions.front() || i > dimensions.back()) {
        new_sizes.push_back(original_shape.dimensions(i));
      } else {
        new_sizes.back() *= original_shape.dimensions(i);
      }
    }

    VLOG(3) << "new sizes: [" << tensorflow::str_util::Join(new_sizes, ",")
            << "]";

    return Reshape(operand, new_sizes);
  });
}
// Adds a kTrace instruction consuming `operand`. The tag is stored as a rank-1
// u8 literal and the instruction's shape is nil. The resulting XlaOp is not
// returned to the caller.
void XlaBuilder::Trace(const string& tag, const XlaOp& operand) {
  NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto trace;
    *trace.mutable_literal() = Literal::CreateR1U8(tag)->ToProto();
    *trace.mutable_shape() = ShapeUtil::MakeNil();
    return AddInstruction(std::move(trace), HloOpcode::kTrace, {operand});
  });
}
// Element-wise select: `on_true` where `pred` holds, `on_false` elsewhere.
// Implicit broadcasting is made explicit by TernaryOp.
XlaOp XlaBuilder::Select(const XlaOp& pred, const XlaOp& on_true,
                         const XlaOp& on_false) {
  constexpr HloOpcode kOpcode = HloOpcode::kSelect;
  return TernaryOp(kOpcode, pred, on_true, on_false);
}
// Adds a kTuple instruction grouping `elements`. The tuple shape is inferred
// from the element shapes.
XlaOp XlaBuilder::Tuple(tensorflow::gtl::ArraySlice<XlaOp> elements) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto tuple;
    TF_ASSIGN_OR_RETURN(const auto& element_shapes, GetOperandShapes(elements));
    std::vector<const Shape*> shape_ptrs;
    for (const Shape& element_shape : element_shapes) {
      shape_ptrs.push_back(&element_shape);
    }
    TF_ASSIGN_OR_RETURN(
        *tuple.mutable_shape(),
        ShapeInference::InferVariadicOpShape(HloOpcode::kTuple, shape_ptrs));
    return AddInstruction(std::move(tuple), HloOpcode::kTuple, elements);
  });
}
// Adds a kGetTupleElement extracting element `index` from `tuple_data`.
//
// Returns InvalidArgument if the operand is not a tuple or if `index` is not
// a valid element index. The bounds check keeps a bad index from reaching
// ShapeUtil::GetTupleElementShape.
XlaOp XlaBuilder::GetTupleElement(const XlaOp& tuple_data, int64 index) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(const Shape& tuple_shape, GetShape(tuple_data));
    if (!ShapeUtil::IsTuple(tuple_shape)) {
      return InvalidArgument(
          "Operand to GetTupleElement() is not a tuple; got %s",
          ShapeUtil::HumanString(tuple_shape).c_str());
    }
    const int64 element_count = tuple_shape.tuple_shapes_size();
    if (index < 0 || index >= element_count) {
      return InvalidArgument(
          "GetTupleElement() index (%lld) out of range for tuple shape %s",
          index, ShapeUtil::HumanString(tuple_shape).c_str());
    }
    *instr.mutable_shape() =
        ShapeUtil::GetTupleElementShape(tuple_shape, index);
    instr.set_tuple_index(index);
    return AddInstruction(std::move(instr), HloOpcode::kGetTupleElement,
                          {tuple_data});
  });
}
// Element-wise comparisons. Each one forwards to BinaryOp with the matching
// comparison opcode; broadcasting follows BinaryOp's rules.

XlaOp XlaBuilder::Eq(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kEq;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Ne(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kNe;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Ge(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kGe;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Gt(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kGt;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Le(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kLe;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}

XlaOp XlaBuilder::Lt(const XlaOp& lhs, const XlaOp& rhs,
                     tensorflow::gtl::ArraySlice<int64> broadcast_dimensions) {
  constexpr HloOpcode kOpcode = HloOpcode::kLt;
  return BinaryOp(kOpcode, lhs, rhs, broadcast_dimensions);
}
// Dot product expressed as a DotGeneral. A rank-1 lhs contracts over its only
// dimension (0); any other lhs rank contracts over dimension 1. The rhs always
// contracts over dimension 0.
XlaOp XlaBuilder::Dot(const XlaOp& lhs, const XlaOp& rhs) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    const int64 lhs_contracting_dim =
        lhs_shape.dimensions_size() == 1 ? 0 : 1;
    DotDimensionNumbers dnums;
    dnums.add_lhs_contracting_dimensions(lhs_contracting_dim);
    dnums.add_rhs_contracting_dimensions(0);
    return DotGeneral(lhs, rhs, dnums);
  });
}
// Adds a kDot of `lhs` and `rhs` using `dimension_numbers`, which are also
// copied onto the instruction.
XlaOp XlaBuilder::DotGeneral(const XlaOp& lhs, const XlaOp& rhs,
                             const DotDimensionNumbers& dimension_numbers) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    HloInstructionProto dot;
    TF_ASSIGN_OR_RETURN(
        *dot.mutable_shape(),
        ShapeInference::InferDotOpShape(lhs_shape, rhs_shape,
                                        dimension_numbers));
    *dot.mutable_dot_dimension_numbers() = dimension_numbers;
    return AddInstruction(std::move(dot), HloOpcode::kDot, {lhs, rhs});
  });
}
// Validates convolution arguments before shape inference.
//
// Requires:
// - `lhs_shape` and `rhs_shape` have the same rank, and that rank is >= 2.
// - Each of the input/kernel/output spatial dimension lists has exactly
//   rank - 2 entries, each in [0, rank).
// Returns InvalidArgument describing the first violation found.
Status XlaBuilder::VerifyConvolution(
    const Shape& lhs_shape, const Shape& rhs_shape,
    const ConvolutionDimensionNumbers& dimension_numbers) const {
  if (ShapeUtil::Rank(lhs_shape) != ShapeUtil::Rank(rhs_shape)) {
    return InvalidArgument(
        "Convolution arguments must have same number of "
        "dimensions. Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape).c_str(),
        ShapeUtil::HumanString(rhs_shape).c_str());
  }
  int num_dims = ShapeUtil::Rank(lhs_shape);
  if (num_dims < 2) {
    // The message states the same bound that the condition enforces (the old
    // text claimed ">= 3"): rank 2 is accepted and simply has no spatial
    // dimensions.
    return InvalidArgument(
        "Convolution expects argument arrays with >= 2 dimensions. "
        "Got: %s and %s",
        ShapeUtil::HumanString(lhs_shape).c_str(),
        ShapeUtil::HumanString(rhs_shape).c_str());
  }
  int num_spatial_dims = num_dims - 2;

  const auto check_spatial_dimensions =
      [&](const char* const field_name,
          const tensorflow::protobuf::RepeatedField<tensorflow::protobuf_int64>&
              numbers) {
        if (numbers.size() != num_spatial_dims) {
          return InvalidArgument("Expected %d elements for %s, but got %d.",
                                 num_spatial_dims, field_name, numbers.size());
        }
        for (int i = 0; i < numbers.size(); ++i) {
          if (numbers.Get(i) < 0 || numbers.Get(i) >= num_dims) {
            return InvalidArgument("Convolution %s[%d] is out of bounds: %lld",
                                   field_name, i, numbers.Get(i));
          }
        }
        return Status::OK();
      };
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("input_spatial_dimensions",
                               dimension_numbers.input_spatial_dimensions()));
  TF_RETURN_IF_ERROR(
      check_spatial_dimensions("kernel_spatial_dimensions",
                               dimension_numbers.kernel_spatial_dimensions()));
  return check_spatial_dimensions(
      "output_spatial_dimensions",
      dimension_numbers.output_spatial_dimensions());
}
// Convolution with a `padding` mode and the default dimension numbers, built
// for window_strides.size() spatial dimensions.
XlaOp XlaBuilder::Conv(const XlaOp& lhs, const XlaOp& rhs,
                       tensorflow::gtl::ArraySlice<int64> window_strides,
                       Padding padding) {
  const auto default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvWithGeneralDimensions(lhs, rhs, window_strides, padding,
                                   default_dnums);
}
// Convolution with explicit (low, high) padding pairs and the default
// dimension numbers for window_strides.size() spatial dimensions.
XlaOp XlaBuilder::ConvWithGeneralPadding(
    const XlaOp& lhs, const XlaOp& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding) {
  const auto default_dnums =
      CreateDefaultConvDimensionNumbers(window_strides.size());
  return ConvGeneral(lhs, rhs, window_strides, padding, default_dnums);
}
// Convolves `lhs` with `rhs` under caller-supplied `dimension_numbers`, using
// a Padding mode rather than explicit padding pairs.
//
// After VerifyConvolution succeeds, this function:
//   1. gathers the sizes of lhs along input_spatial_dimensions and of rhs
//      along kernel_spatial_dimensions;
//   2. turns the Padding mode into explicit padding pairs via MakePadding;
//   3. forwards everything to ConvGeneral.
// Errors are reported through NoteErrorOrReturn.
XlaOp XlaBuilder::ConvWithGeneralDimensions(
    const XlaOp& lhs, const XlaOp& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides, Padding padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    // Extent of lhs in each of its spatial dimensions.
    std::vector<int64> input_spatial_sizes;
    input_spatial_sizes.reserve(
        dimension_numbers.input_spatial_dimensions_size());
    for (const int64 dim : dimension_numbers.input_spatial_dimensions()) {
      input_spatial_sizes.push_back(lhs_shape.dimensions(dim));
    }

    // Extent of rhs in each of its spatial dimensions (the window size).
    std::vector<int64> kernel_spatial_sizes;
    kernel_spatial_sizes.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64 dim : dimension_numbers.kernel_spatial_dimensions()) {
      kernel_spatial_sizes.push_back(rhs_shape.dimensions(dim));
    }

    const auto explicit_padding =
        MakePadding(input_spatial_sizes, kernel_spatial_sizes, window_strides,
                    padding);
    return ConvGeneral(lhs, rhs, window_strides, explicit_padding,
                       dimension_numbers);
  });
}
// Convolution with explicit padding and general dimension numbers, without
// dilation. Both dilation lists are passed to ConvGeneralDilated empty.
// ConvGeneralDilated hands them to MakeWindow, which treats an empty list
// as a dilation factor of 1 in every dimension.
XlaOp XlaBuilder::ConvGeneral(
    const XlaOp& lhs, const XlaOp& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  const tensorflow::gtl::ArraySlice<int64> no_dilation = {};
  return ConvGeneralDilated(lhs, rhs, window_strides, padding,
                            /*lhs_dilation=*/no_dilation,
                            /*rhs_dilation=*/no_dilation, dimension_numbers);
}
// Most general convolution builder: explicit strides, padding, lhs (base)
// dilation and rhs (window) dilation, under arbitrary `dimension_numbers`.
//
// After VerifyConvolution succeeds, this function:
//   * builds the Window from rhs's spatial sizes and the stride, padding and
//     dilation lists (see MakeWindow);
//   * infers the result shape with ShapeInference::InferConvolveShape;
//   * emits a kConvolution instruction over {lhs, rhs}.
// Errors are reported through NoteErrorOrReturn.
XlaOp XlaBuilder::ConvGeneralDilated(
    const XlaOp& lhs, const XlaOp& rhs,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation,
    const ConvolutionDimensionNumbers& dimension_numbers) {
  return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
    TF_ASSIGN_OR_RETURN(const Shape& lhs_shape, GetShape(lhs));
    TF_ASSIGN_OR_RETURN(const Shape& rhs_shape, GetShape(rhs));
    TF_RETURN_IF_ERROR(
        VerifyConvolution(lhs_shape, rhs_shape, dimension_numbers));

    // The window extent in each spatial dimension is rhs's size there.
    std::vector<int64> kernel_spatial_sizes;
    kernel_spatial_sizes.reserve(
        dimension_numbers.kernel_spatial_dimensions_size());
    for (const int64 dim : dimension_numbers.kernel_spatial_dimensions()) {
      kernel_spatial_sizes.push_back(rhs_shape.dimensions(dim));
    }

    HloInstructionProto instr;
    TF_ASSIGN_OR_RETURN(
        *instr.mutable_window(),
        MakeWindow(kernel_spatial_sizes, window_strides, padding,
                   lhs_dilation, rhs_dilation));
    TF_ASSIGN_OR_RETURN(*instr.mutable_shape(),
                        ShapeInference::InferConvolveShape(
                            lhs_shape, rhs_shape, instr.window(),
                            dimension_numbers));
    *instr.mutable_convolution_dimension_numbers() = dimension_numbers;
    return AddInstruction(std::move(instr), HloOpcode::kConvolution,
                          {lhs, rhs});
  });
}
// Builds a Window proto with one dimension per entry of `window_dimensions`.
//
// `window_strides`, `padding`, `lhs_dilation` and `rhs_dilation` must each be
// either empty or the same length as `window_dimensions`. An empty list
// selects the default for every dimension:
//   * stride 1;
//   * padding (0, 0);
//   * base (lhs) dilation 1;
//   * window (rhs) dilation 1.
// window_reversal is always set to false.
//
// Returns InvalidArgument naming the offending list when a length does not
// match.
StatusOr<Window> XlaBuilder::MakeWindow(
    tensorflow::gtl::ArraySlice<int64> window_dimensions,
    tensorflow::gtl::ArraySlice<int64> window_strides,
    tensorflow::gtl::ArraySlice<std::pair<int64, int64>> padding,
    tensorflow::gtl::ArraySlice<int64> lhs_dilation,
    tensorflow::gtl::ArraySlice<int64> rhs_dilation) const {
  const size_t num_window_dims = window_dimensions.size();

  // Accepts a list length of 0 (use defaults) or exactly num_window_dims.
  const auto verify_size = [&](const size_t x, const char* x_name) -> Status {
    if (x != 0 && x != num_window_dims) {
      return InvalidArgument(
          "%s",
          tensorflow::strings::StrCat(
              "Window has different number of window dimensions than of ",
              x_name, "\nNumber of window dimensions: ", num_window_dims,
              "\nNumber of ", x_name, ": ", x, "\n")
              .c_str());
    }
    return Status::OK();
  };
  TF_RETURN_IF_ERROR(verify_size(window_strides.size(), "window strides"));
  TF_RETURN_IF_ERROR(verify_size(padding.size(), "padding entries"));
  TF_RETURN_IF_ERROR(verify_size(lhs_dilation.size(), "lhs dilation factors"));
  TF_RETURN_IF_ERROR(verify_size(rhs_dilation.size(), "rhs dilation factors"));

  Window window;
  for (size_t i = 0; i < num_window_dims; ++i) {
    auto* dim = window.add_dimensions();
    dim->set_size(window_dimensions[i]);
    dim->set_stride(window_strides.empty() ? 1 : window_strides[i]);
    dim->set_padding_low(padding.empty() ? 0 : padding[i].first);
    dim->set_padding_high(padding.empty() ? 0 : padding[i].second);
    dim->set_base_dilation(lhs_dilation.empty() ? 1 : lhs_dilation[i]);
    dim->set_window_dilation(rhs_dilation.empty() ? 1 : rhs_dilation[i]);
    dim->set_window_reversal(false);
  }
  return window;
}
XlaOp XlaBuilder::Fft(const XlaOp& operand, const FftType fft_type,
const tensorflow::gtl::ArraySlice<int64> fft_length) {
return NoteErrorOrReturn([&]() -> StatusOr<XlaOp> {
HloInstructionProto instr;
TF_ASSIGN_OR_RETURN(const Shape& operand_shape, GetShape(operand));
TF_ASSIGN_OR_RETURN(
*instr.mutable_shape(),
ShapeInference::InferFftShape(operand_shape, fft_type, fft_length));