/
ir_emitter.cc
2278 lines (2008 loc) · 99.5 KB
/
ir_emitter.cc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
/* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/
#include "tensorflow/compiler/xla/service/cpu/ir_emitter.h"
#include <stddef.h>
#include <stdint.h>
#include <algorithm>
#include <iterator>
#include <limits>
#include <memory>
#include <utility>
#include <vector>
#include "tensorflow/core/platform/logging.h"
// IWYU pragma: no_include "llvm/IR/Intrinsics.gen.inc"
#include "external/llvm/include/llvm/IR/BasicBlock.h"
#include "external/llvm/include/llvm/IR/Constants.h"
#include "external/llvm/include/llvm/IR/GlobalVariable.h"
#include "external/llvm/include/llvm/IR/Instructions.h"
#include "external/llvm/include/llvm/IR/Intrinsics.h"
#include "external/llvm/include/llvm/IR/LLVMContext.h"
#include "tensorflow/compiler/xla/layout_util.h"
#include "tensorflow/compiler/xla/map_util.h"
#include "tensorflow/compiler/xla/service/buffer_assignment.h"
#include "tensorflow/compiler/xla/service/cpu/cpu_runtime.h"
#include "tensorflow/compiler/xla/service/cpu/dot_op_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/cpu/ir_emission_utils.h"
#include "tensorflow/compiler/xla/service/cpu/simple_orc_jit.h"
#include "tensorflow/compiler/xla/service/elemental_ir_emitter.h"
#include "tensorflow/compiler/xla/service/hlo_opcode.h"
#include "tensorflow/compiler/xla/service/llvm_ir/fused_ir_emitter.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_loop.h"
#include "tensorflow/compiler/xla/service/llvm_ir/llvm_util.h"
#include "tensorflow/compiler/xla/service/llvm_ir/ops.h"
#include "tensorflow/compiler/xla/shape_util.h"
#include "tensorflow/compiler/xla/status_macros.h"
#include "tensorflow/compiler/xla/types.h"
#include "tensorflow/compiler/xla/util.h"
#include "tensorflow/compiler/xla/window_util.h"
#include "tensorflow/core/lib/core/errors.h"
#include "tensorflow/core/lib/gtl/array_slice.h"
#include "tensorflow/core/lib/strings/strcat.h"
#include "tensorflow/core/lib/strings/stringprintf.h"
namespace xla {
using llvm_ir::SetToFirstInsertPoint;
namespace cpu {
// Creates an emitter that appends the functions it generates to `llvm_module`.
//
// `hlo_to_profile_idx` may be null. When it is non-null, every emitted
// function takes an extra trailing profile-counters argument (see
// InitializeIrFunction). The IR builder's fast-math flags follow the module's
// `xla_enable_fast_math` debug option.
IrEmitter::IrEmitter(
    const HloModule& hlo_module, const BufferAssignment& assignment,
    llvm::Module* llvm_module,
    const std::unordered_map<const HloInstruction*, size_t>* hlo_to_profile_idx)
    : assignment_(assignment),
      module_(llvm_module),
      arch_type_(llvm::Triple(llvm_module->getTargetTriple()).getArch()),
      ir_builder_(llvm_module->getContext()),
      hlo_to_profile_idx_(hlo_to_profile_idx),
      alias_analysis_(hlo_module, assignment, &llvm_module->getContext()),
      hlo_module_config_(hlo_module.config()) {
  const bool fast_math_enabled =
      hlo_module_config_.debug_options().xla_enable_fast_math();
  ir_builder_.setFastMathFlags(llvm_ir::GetFastMathFlags(fast_math_enabled));
}
// Emits `computation` as a new LLVM function and returns it.
//
// The function name is `function_name_prefix` made unique by name_uniquer_.
// `is_entry_computation` selects the function's linkage and is forwarded to
// the profiling state. When `instruction_order` is non-null, instructions are
// visited in exactly that order; otherwise the computation's own visitation
// order is used. The result is recorded in emitted_functions_ so that later
// computations (map, reduce-window, ...) can call it.
StatusOr<llvm::Function*> IrEmitter::EmitComputation(
    HloComputation* computation, const string& function_name_prefix,
    bool is_entry_computation,
    std::vector<const HloInstruction*>* instruction_order) {
  const string function_name =
      name_uniquer_.GetUniqueName(function_name_prefix);
  const bool ordered = instruction_order != nullptr;
  VLOG(2) << "Emitting IR for CPU function [" << function_name_prefix
          << "]; ordered? " << ordered;

  // One dynamic loop bound per outer-dimension partition of the root; an empty
  // partition list naturally yields zero. This must be set before
  // InitializeIrFunction, which decides the function signature from it.
  num_dynamic_loop_bounds_ =
      computation->root_instruction()->outer_dimension_partitions().size();

  InitializeIrFunction(function_name, is_entry_computation);

  // rdtscp exists only on x86; other targets fall back to LLVM's generic
  // readcyclecounter.
  const bool use_rdtscp = arch_type_ == llvm::Triple::x86 ||
                          arch_type_ == llvm::Triple::x86_64;
  profiling_state_ = ProfilingState(is_entry_computation, use_rdtscp,
                                    GetProfileCountersArgument());

  TF_RETURN_IF_ERROR(ordered
                         ? computation->AcceptOrdered(this, *instruction_order)
                         : computation->Accept(this));

  InsertOrDie(&emitted_functions_, computation, compute_function_);
  return compute_function_;
}
// Returns a pointer to the `idx`-th formal argument of `f`.
// No bounds check is performed: `idx` must be a valid argument position.
static llvm::Argument* GetArg(llvm::Function* f, int idx) {
  return &*std::next(f->arg_begin(), idx);
}
void IrEmitter::InitializeIrFunction(const string& function_name,
bool is_entry_computation) {
// The function signature is:
// void function(i8* retval, i8* run_options, i8** params, i8** temps,
// i64* dynamic_loop_bounds, i64* prof_counters)
//
// retval: points to the returned value.
// params: address of an array with pointers to parameters.
// temps: address of an array with pointers to temporary buffers.
//
// Therefore, the generated function's signature (FunctionType) is statically
// determined - parameter unpacking is done in code generated into the
// function, rather than by a prologue dictated by the platform ABI.
//
// /--------------\
// retval ----------> | return value |
// \--------------/
//
// /-------------------------------\
// run_options -----> | xla::ExecutableRunOptions |
// \-------------------------------/
//
// /---------------------------------------------\
// params --------> | param 0 | param 1 | ..... | param N-1 |
// | addr | addr | | addr |
// \---------------------------------------------/
// | | |
// | | |
// V V V
// /---------\ /---------\ /-----------\
// | param 0 | | param 1 | | param N-1 |
// \---------/ \---------/ \-----------/
//
// /---------------------------------------------\
// temps ---------> | temp 0 | temp 1 | ..... | temp N-1 |
// | addr | addr | | addr |
// \---------------------------------------------/
// | | |
// | | |
// V V V
// /---------\ /---------\ /-----------\
// | temp 0 | | temp 1 | | temp N-1 |
// \---------/ \---------/ \-----------/
//
// /--------------------------------------------\
// dynamic loop bounds -> | outer_dim0_start | outer_dim0_limit | .....|
// (elided for aot) \--------------------------------------------/
//
// /---------------------------------------------\
// prof counters -> | counter 0 | counter 1 | ..... | counter N-1 |
// (elided for aot) \---------------------------------------------/
// Even though the type of params and temps is void** in the host's view, in
// LLVM IR this is represented by i8*, similarly to void*. It's up to the code
// to use GEPs to unravel the indirection layers.
llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());
llvm::Type* i8_ptr_ptr_type = i8_ptr_type->getPointerTo();
llvm::Type* i64_ptr_type = llvm::Type::getInt64PtrTy(module_->getContext());
std::vector<llvm::Type*> compute_function_params(
{i8_ptr_type, i8_ptr_type, i8_ptr_ptr_type, i8_ptr_ptr_type});
if (num_dynamic_loop_bounds_ > 0) {
compute_function_params.push_back(i64_ptr_type);
}
if (hlo_to_profile_idx_) {
compute_function_params.push_back(i64_ptr_type);
}
llvm::FunctionType* compute_function_type = llvm::FunctionType::get(
/*Result=*/llvm::Type::getVoidTy(module_->getContext()),
/*Params=*/compute_function_params,
/*isVarArg=*/false);
// Functions with local linkage get an inlining bonus. Because we know
// a-priori that embedded functions (non-entry functions) will not have its
// name resolved, give it local linkage.
llvm::Function::LinkageTypes linkage =
is_entry_computation ? llvm::GlobalValue::ExternalLinkage
: llvm::GlobalValue::InternalLinkage;
compute_function_ = llvm::Function::Create(/*Ty=*/compute_function_type,
/*Linkage=*/linkage,
/*Name=*/function_name.c_str(),
/*Module=*/module_);
compute_function_->setCallingConv(llvm::CallingConv::C);
// Set meaningful names for the function's arguments: useful for debugging.
llvm::Function::arg_iterator arg_iter = compute_function_->arg_begin();
arg_iter->setName("retval");
(++arg_iter)->setName("run_options");
(++arg_iter)->setName("params");
(++arg_iter)->setName("temps");
if (num_dynamic_loop_bounds_ > 0) {
(++arg_iter)->setName("dynamic_loop_bounds");
}
if (hlo_to_profile_idx_) {
(++arg_iter)->setName("prof_counters");
}
// We know a-priori that the function arguments are guaranteed to point to
// disjoint objects.
llvm::Argument* retval = GetResultArgument();
for (llvm::Argument& argument : compute_function_->args()) {
// However, the return buffer aliases the temporaries and thus cannot be
// marked noalias.
if (&argument == retval) {
continue;
}
compute_function_->addAttribute(argument.getArgNo() + 1,
llvm::Attribute::NoAlias);
}
ir_builder_.SetInsertPoint(llvm::BasicBlock::Create(
/*Context=*/module_->getContext(),
/*Name=*/"entry",
/*Parent=*/compute_function_));
}
// Nothing beyond member-wise destruction is required.
IrEmitter::~IrEmitter() = default;
// A bitcast reuses its operand's buffer. The emitted value is the operand's
// address cast to a pointer to the IR type of the bitcast's own shape; no data
// is moved.
Status IrEmitter::HandleBitcast(HloInstruction* bitcast) {
  VLOG(2) << "HandleBitcast: " << bitcast->ToString();
  llvm::Value* const operand_address = GetEmittedValueFor(bitcast->operand(0));
  llvm::Type* const result_pointer_type =
      IrShapeType(bitcast->shape())->getPointerTo();
  emitted_value_[bitcast] = ir_builder_.CreateBitCast(
      operand_address, result_pointer_type, bitcast->name().c_str());
  return Status::OK();
}
// Materializes `literal` as an unnamed, private, immutable global variable.
// The global's address becomes the emitted value for `constant`.
Status IrEmitter::HandleConstant(HloInstruction* constant,
                                 const Literal& literal) {
  VLOG(2) << "HandleConstant: " << constant->ToString();
  llvm::Constant* const initializer =
      llvm_ir::ConvertLiteralToIrConstant(literal, &ir_builder_);
  // Constructing a GlobalVariable with a Module links it into that module,
  // which then owns it; there is no matching delete.
  auto* const constant_global = new llvm::GlobalVariable(
      /*Module=*/*module_, /*Type=*/initializer->getType(),
      /*isConstant=*/true, /*Linkage=*/llvm::GlobalValue::PrivateLinkage,
      /*Initializer=*/initializer, /*Name=*/"");
  emitted_value_[constant] = constant_global;
  VLOG(2) << " emitted value: " << llvm_ir::DumpToString(*constant_global);
  VLOG(2) << " its type: "
          << llvm_ir::DumpToString(*constant_global->getType());
  return Status::OK();
}
// Array copies go through the elemental emitter. A tuple copy is shallow: only
// the top-level buffer (the array of element pointers) is memcpy'd into this
// instruction's target buffer.
Status IrEmitter::HandleCopy(HloInstruction* copy) {
  if (!ShapeUtil::IsTuple(copy->shape())) {
    return DefaultAction(copy);
  }
  TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
                      EmitTargetAddressForOp(copy));
  emitted_value_[copy] = target_address;
  return EmitMemcpy(*copy->operand(0), *copy);
}
// Returns the alignment, in bytes, that a buffer of `buffer_size` bytes can be
// assumed to have.
//
// Rationale: GLibc returns pointers aligned to 8 on 32-bit platforms and 16 on
// 64-bit platforms. TCMalloc returns pointers aligned to 8 for allocations
// smaller than 16 bytes, and to at least 16 for allocations of 16 bytes or
// more. This lower bound could be improved by allocating explicitly with
// posix_memalign, but that conflicts with letting the JIT consume
// client-created parameter buffers directly.
int IrEmitter::MinimumAlignmentForBufferSize(int64 buffer_size) {
  // No need to align empty buffers.
  if (buffer_size == 0) {
    return 1;
  }
  // Small buffers are only guaranteed 8-byte alignment.
  if (buffer_size < 16) {
    return 8;
  }
  // Larger buffers get two pointer widths: 16 on 64-bit, 8 on 32-bit targets.
  const int pointer_size = module_->getDataLayout().getPointerSize();
  const int alignment = 2 * pointer_size;
  DCHECK_GT(alignment, 0);
  return alignment;
}
// Returns the alignment, in bytes, assumed for a buffer holding one element
// of `primitive_type`; delegates to MinimumAlignmentForBufferSize.
int IrEmitter::MinimumAlignmentForPrimitiveType(PrimitiveType primitive_type) {
  const int64 element_bytes =
      ShapeUtil::ByteSizeOfPrimitiveType(primitive_type);
  DCHECK_GE(element_bytes, 0);
  DCHECK_LE(element_bytes, SIZE_MAX);
  return MinimumAlignmentForBufferSize(element_bytes);
}
// Size of `shape` in bytes, computed with the target module's data layout.
int64 IrEmitter::ByteSizeOf(const Shape& shape) const {
  const auto& target_layout = module_->getDataLayout();
  return llvm_ir::ByteSizeOf(shape, target_layout);
}
// Calculate the alignment of a buffer allocated for a given shape.
int IrEmitter::MinimumAlignmentForShape(const Shape& shape) {
int64 buffer_size = ByteSizeOf(shape);
DCHECK_GE(buffer_size, 0);
DCHECK_LE(buffer_size, SIZE_MAX);
return MinimumAlignmentForBufferSize(buffer_size);
}
// Attaches alignment metadata (via llvm_ir::SetAlignmentMetadataForLoad) to
// `load`, using the minimum alignment for a buffer of `shape`. Nothing is
// attached when that alignment is 1, which is the case for empty buffers.
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               const Shape& shape) {
  const int shape_alignment = MinimumAlignmentForShape(shape);
  if (shape_alignment <= 1) {
    return;
  }
  llvm_ir::SetAlignmentMetadataForLoad(load, shape_alignment);
}
// Size-based overload: attaches alignment metadata to `load` using the minimum
// alignment for a buffer of `buffer_size` bytes, skipping the trivial
// alignment of 1 (empty buffers).
void IrEmitter::AttachAlignmentMetadataForLoad(llvm::LoadInst* load,
                                               int64 buffer_size) {
  const int size_alignment = MinimumAlignmentForBufferSize(buffer_size);
  if (size_alignment <= 1) {
    return;
  }
  llvm_ir::SetAlignmentMetadataForLoad(load, size_alignment);
}
// Shape-based overload: forwards the byte size of `shape` to the size-based
// overload.
void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     const Shape& shape) {
  const int64 shape_bytes = ByteSizeOf(shape);
  AttachDereferenceableMetadataForLoad(load, shape_bytes);
}
// Marks `load` as producing a pointer dereferenceable for `buffer_size` bytes
// (via llvm_ir::SetDereferenceableMetadataForLoad). Non-positive sizes attach
// nothing.
void IrEmitter::AttachDereferenceableMetadataForLoad(llvm::LoadInst* load,
                                                     int64 buffer_size) {
  if (buffer_size <= 0) {
    return;
  }
  llvm_ir::SetDereferenceableMetadataForLoad(load, buffer_size);
}
// A tuple is an array of pointers, one per element, each pointing at the
// output buffer of the corresponding element. GetTupleElement therefore just
// forwards the pointer stored at `tuple_index()` in the operand's buffer.
Status IrEmitter::HandleGetTupleElement(HloInstruction* get_tuple_element,
                                        HloInstruction* operand) {
  const Shape& element_shape = get_tuple_element->shape();
  const int element_alignment = MinimumAlignmentForShape(element_shape);
  llvm::Value* const tuple_address = GetEmittedValueFor(operand);
  emitted_value_[get_tuple_element] = llvm_ir::EmitGetTupleElement(
      element_shape, get_tuple_element->tuple_index(), element_alignment,
      tuple_address, &ir_builder_);
  return Status::OK();
}
// `pred` must have element type PRED. Array-shaped selects go through the
// elemental emitter (DefaultAction). Tuple-shaped selects are emitted with
// llvm_ir::EmitTupleSelect into this instruction's target buffer, and that
// buffer becomes the emitted value.
Status IrEmitter::HandleSelect(HloInstruction* select, HloInstruction* pred,
                               HloInstruction* on_true,
                               HloInstruction* on_false) {
  TF_RET_CHECK(pred->shape().element_type() == PRED);
  if (!ShapeUtil::IsTuple(select->shape())) {
    return DefaultAction(select);
  }
  TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
                      EmitTargetAddressForOp(select));
  llvm_ir::IrArray target_array(target_address, select->shape());
  llvm_ir::EmitTupleSelect(target_array, GetIrArrayForOp(pred),
                           GetEmittedValueFor(on_true),
                           GetEmittedValueFor(on_false), &ir_builder_);
  emitted_value_[select] = target_address;
  return Status::OK();
}
// Dequeues infeed data into the buffer assigned to `infeed`.
//
// Array shapes are transferred directly into that buffer. Flat (non-nested)
// tuples transfer each element into its own buffer and then write the outer
// tuple buffer of element pointers. Nested tuples are rejected.
Status IrEmitter::HandleInfeed(HloInstruction* infeed) {
  VLOG(2) << "HandleInfeed: " << infeed->ToString();
  const Shape& infeed_shape = infeed->shape();

  // Buffer assignment chose this address for the dequeued data.
  TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
                      EmitTargetAddressForOp(infeed));

  if (!ShapeUtil::IsTuple(infeed_shape)) {
    TF_RETURN_IF_ERROR(
        EmitXfeedTransfer(XfeedKind::kInfeed, infeed_shape, target_address));
    emitted_value_[infeed] = target_address;
    return Status::OK();
  }

  TF_RET_CHECK(!ShapeUtil::IsNestedTuple(infeed_shape));
  const int64 element_count = infeed_shape.tuple_shapes_size();
  std::vector<llvm::Value*> element_addresses;
  element_addresses.reserve(element_count);
  for (int64 element = 0; element < element_count; ++element) {
    TF_ASSIGN_OR_RETURN(BufferAllocation::Slice element_slice,
                        assignment_.GetUniqueSlice(infeed, {element}));
    const Shape& element_shape =
        ShapeUtil::GetTupleElementShape(infeed_shape, element);
    // Only the outer tuple buffer comes from EmitTargetAddressForOp, which
    // handles infeed being the root instruction. Element buffers are
    // ordinary temporaries.
    llvm::Value* const element_address =
        EmitTempBufferPointer(element_slice, element_shape);
    TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kInfeed, element_shape,
                                         element_address));
    element_addresses.push_back(element_address);
  }
  llvm_ir::EmitTuple(llvm_ir::IrArray(target_address, infeed_shape),
                     element_addresses, &ir_builder_);
  emitted_value_[infeed] = target_address;
  return Status::OK();
}
// Emits the runtime calls that move one xfeed buffer between the program
// buffer at `program_buffer_address` and a buffer handed out by the CPU
// runtime:
//
//   1. acquire a runtime buffer of exactly ByteSizeOf(shape) bytes,
//   2. memcpy runtime -> program (infeed) or program -> runtime (outfeed),
//   3. release the runtime buffer.
//
// A self-describing encoding of `shape` (from
// llvm_ir::EncodeSelfDescribingShapeConstant) is passed to both the acquire
// and the release call.
//
// Returns InvalidArgument if the byte size of `shape` is not in
// (0, INT32_MAX], because the runtime entry points take the length as an i32.
Status IrEmitter::EmitXfeedTransfer(XfeedKind kind, const Shape& shape,
                                    llvm::Value* program_buffer_address) {
  const int64 length = ByteSizeOf(shape);
  if (length <= 0 || length > std::numeric_limits<int32>::max()) {
    return InvalidArgument(
        "xfeed (infeed or outfeed) buffer length %lld is outside the valid "
        "size range",
        length);
  }
  const int32 length_32 = static_cast<int32>(length);

  int32 shape_length = 0;
  TF_ASSIGN_OR_RETURN(llvm::Value * shape_ptr,
                      llvm_ir::EncodeSelfDescribingShapeConstant(
                          shape, &shape_length, &ir_builder_));

  // Any kind other than infeed is treated as outfeed.
  const bool is_infeed = kind == XfeedKind::kInfeed;
  const char* const acquire_symbol =
      is_infeed ? runtime::kAcquireInfeedBufferForDequeueSymbolName
                : runtime::kAcquireOutfeedBufferForPopulationSymbolName;
  const char* const release_symbol =
      is_infeed ? runtime::kReleaseInfeedBufferAfterDequeueSymbolName
                : runtime::kReleaseOutfeedBufferAfterPopulationSymbolName;

  llvm::Type* int32_type = ir_builder_.getInt32Ty();
  llvm::Type* i8_ptr_type = llvm::Type::getInt8PtrTy(module_->getContext());

  // The acquire function is declared to LLVM as
  //
  //   i8* acquire(i32 buffer_length, i8* shape_ptr, i32 shape_length)
  //
  // and returns the runtime buffer to copy out of (infeed) or into (outfeed).
  llvm::FunctionType* acquire_type = llvm::FunctionType::get(
      i8_ptr_type, {int32_type, i8_ptr_type, int32_type},
      /*isVarArg=*/false);
  llvm::Function* acquire_func = llvm::cast<llvm::Function>(
      module_->getOrInsertFunction(acquire_symbol, acquire_type));
  acquire_func->setCallingConv(llvm::CallingConv::C);

  // The release function is declared to LLVM as
  //
  //   void release(i32 buffer_length, i8* buffer, i8* shape_ptr,
  //                i32 shape_length)
  llvm::FunctionType* release_type = llvm::FunctionType::get(
      ir_builder_.getVoidTy(),
      {int32_type, i8_ptr_type, i8_ptr_type, int32_type},
      /*isVarArg=*/false);
  llvm::Function* release_func = llvm::cast<llvm::Function>(
      module_->getOrInsertFunction(release_symbol, release_type));
  release_func->setCallingConv(llvm::CallingConv::C);

  // The runtime is asked for a buffer of exactly `length_32` bytes. It is
  // responsible for check-failing the process on a size mismatch rather than
  // passing back a buffer that this code might overrun.
  llvm::Value* acquired_pointer = ir_builder_.CreateCall(
      acquire_func, {ir_builder_.getInt32(length_32), shape_ptr,
                     ir_builder_.getInt32(shape_length)});

  // Infeed copies runtime -> program; outfeed copies program -> runtime.
  llvm::Value* copy_dst = is_infeed ? program_buffer_address : acquired_pointer;
  llvm::Value* copy_src = is_infeed ? acquired_pointer : program_buffer_address;
  ir_builder_.CreateMemCpy(copy_dst, copy_src, length_32, /*Align=*/1);

  ir_builder_.CreateCall(release_func,
                         {ir_builder_.getInt32(length_32), acquired_pointer,
                          shape_ptr, ir_builder_.getInt32(shape_length)});
  return Status::OK();
}
// Enqueues the value of the outfeed's operand.
//
// An array operand is transferred from its buffer in one call. A flat
// (non-nested) tuple is transferred element by element, loading each
// element's pointer from the tuple buffer first. Nested tuples are rejected.
Status IrEmitter::HandleOutfeed(HloInstruction* outfeed) {
  const HloInstruction* const source = outfeed->operand(0);
  const Shape& source_shape = source->shape();
  llvm::Value* const source_address = GetEmittedValueFor(source);

  if (!ShapeUtil::IsTuple(source_shape)) {
    return EmitXfeedTransfer(XfeedKind::kOutfeed, source_shape,
                             source_address);
  }

  TF_RET_CHECK(!ShapeUtil::IsNestedTuple(source_shape));
  const int64 element_count = source_shape.tuple_shapes_size();
  for (int64 element = 0; element < element_count; ++element) {
    const Shape& element_shape =
        ShapeUtil::GetTupleElementShape(source_shape, element);
    const int element_alignment = MinimumAlignmentForShape(element_shape);
    llvm::Value* const element_address = llvm_ir::EmitGetTupleElement(
        element_shape, element, element_alignment, source_address,
        &ir_builder_);
    TF_RETURN_IF_ERROR(EmitXfeedTransfer(XfeedKind::kOutfeed, element_shape,
                                         element_address));
  }
  return Status::OK();
}
// Sort has no CPU lowering yet: emission always fails with Unimplemented, and
// neither argument is inspected.
// TODO(b/26783907): Implement sort on CPU.
Status IrEmitter::HandleSort(HloInstruction* sort, HloInstruction* operand) {
  return Unimplemented("Sort is not supported on CPU (b/26783907).");
}
// Writes the tuple's top-level buffer: one pointer per operand, each pointing
// at that operand's already-emitted buffer. The tuple buffer's address becomes
// the emitted value.
Status IrEmitter::HandleTuple(
    HloInstruction* tuple,
    tensorflow::gtl::ArraySlice<HloInstruction*> operands) {
  TF_ASSIGN_OR_RETURN(llvm::Value * tuple_address,
                      EmitTargetAddressForOp(tuple));
  std::vector<llvm::Value*> element_addresses;
  element_addresses.reserve(operands.size());
  for (size_t i = 0; i < operands.size(); ++i) {
    element_addresses.push_back(GetEmittedValueFor(operands[i]));
  }
  llvm_ir::EmitTuple(llvm_ir::IrArray(tuple_address, tuple->shape()),
                     element_addresses, &ir_builder_);
  emitted_value_[tuple] = tuple_address;
  return Status::OK();
}
// Emits a map as an element loop over the output. For each output index, the
// previously emitted LLVM function for `function` is called with the addresses
// of the corresponding element of every operand. Static operands are ignored.
Status IrEmitter::HandleMap(
    HloInstruction* map, tensorflow::gtl::ArraySlice<HloInstruction*> operands,
    HloComputation* function,
    tensorflow::gtl::ArraySlice<HloInstruction*> /*static_operands*/) {
  // Called computations are emitted before their callers, so this must exist.
  llvm::Function* const mapped_function =
      FindOrDie(emitted_functions_, function);

  auto emit_element = [this, map, operands, mapped_function](
                          const llvm_ir::IrArray::Index& index) {
    std::vector<llvm::Value*> argument_addresses;
    argument_addresses.reserve(operands.size());
    for (const HloInstruction* operand : operands) {
      argument_addresses.push_back(
          GetIrArrayForOp(operand).EmitArrayElementAddress(index,
                                                           &ir_builder_));
    }
    return EmitElementFunctionCall(mapped_function, map->shape(),
                                   argument_addresses, "map_function");
  };
  return EmitTargetElementLoop(map, emit_element);
}
// Emits reduce-window as a straightforward loop nest.
//
// For every output element, the window positions are walked, and the
// previously emitted LLVM function for `function` folds each in-bounds input
// element into an accumulator. The accumulator starts from the value of
// `reduce_window->operand(1)` (the init value). Window positions that land in
// the padding are skipped.
//
// Element types are checked with ElementTypesSameAndSupported against {F32}.
// Returns Unimplemented if the window has any dilation.
Status IrEmitter::HandleReduceWindow(HloInstruction* reduce_window,
                                     HloInstruction* operand,
                                     const Window& window,
                                     HloComputation* function) {
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*reduce_window, /*operands=*/{operand},
      /*supported_types=*/{F32}));
  // TODO(b/31410564): Implement dilation for reduce-window.
  if (window_util::HasDilation(window)) {
    return Unimplemented(
        "Dilation for reduce-window not implemented on CPU. See b/31410564.");
  }
  // The called computation should have been emitted previously.
  llvm::Function* reducer_function = FindOrDie(emitted_functions_, function);
  // Pseudo code for reduce window:
  //
  //   for (coordinates O in the output)
  //     value = init_value;
  //     for (coordinates W in the window)
  //       for each index i:
  //         input coordinates I_i = O_i * stride_i + W_i - pad_low_i
  //       if I within bounds of input:
  //         value = function(value, input(I));
  //     output(O) = value;
  //
  // This is completely un-optimized and just here to have something
  // that works.
  return EmitTargetElementLoop(
      reduce_window, [this, reduce_window, operand, window,
                      reducer_function](const llvm_ir::IrArray::Index& index) {
        // We fold inputs into the accumulator and initialize it to
        // the initial value on the reduce_window.
        PrimitiveType operand_element_type = operand->shape().element_type();
        // The accumulator is a stack slot allocated at function entry (so it
        // is not re-allocated per loop iteration) and is re-initialized for
        // every output element.
        llvm::Value* accumulator_address = llvm_ir::EmitAllocaAtFunctionEntry(
            llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
            "reduce_window_accumulator_address", &ir_builder_,
            MinimumAlignmentForPrimitiveType(operand_element_type));
        ir_builder_.CreateStore(ir_builder_.CreateLoad(GetEmittedValueFor(
                                    reduce_window->operand(1))),
                                accumulator_address);

        // One loop per window dimension, iterating over the window extents.
        llvm_ir::ForLoopNest loops(&ir_builder_);
        std::vector<int64> window_size;
        for (const auto& dim : window.dimensions()) {
          window_size.push_back(dim.size());
        }
        const llvm_ir::IrArray::Index window_index = loops.AddLoopsForShape(
            ShapeUtil::MakeShape(operand_element_type, window_size), "window");
        CHECK_EQ(window_index.size(), index.size());

        SetToFirstInsertPoint(loops.GetInnerLoopBodyBasicBlock(), &ir_builder_);

        // Compute the input coordinate for this (output, window) pair and AND
        // together the per-dimension bounds checks.
        // NOTE(review): for a rank-0 output (index.size() == 0) the loop below
        // never runs, in_bounds_condition stays null and the CHECK after it
        // fires. SelectAndScatter seeds its condition with getInt1(true)
        // instead.
        llvm_ir::IrArray::Index input_index(index.size());
        llvm::Value* in_bounds_condition = nullptr;
        for (size_t i = 0; i < index.size(); ++i) {
          llvm::Value* strided_index = ir_builder_.CreateNSWMul(
              index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
          input_index[i] = ir_builder_.CreateNSWSub(
              ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
              ir_builder_.getInt64(window.dimensions(i).padding_low()));

          // We need to check if 0 <= input_index[i] < bound, as
          // otherwise we are in the padding so that we can skip the
          // computation. That is equivalent to input_index[i] < bound
          // as an *unsigned* comparison, since a negative value will
          // wrap to a large positive value.
          llvm::Value* index_condition = ir_builder_.CreateICmpULT(
              input_index[i], ir_builder_.getInt64(ShapeUtil::GetDimension(
                                  operand->shape(), i)));
          if (in_bounds_condition == nullptr) {
            in_bounds_condition = index_condition;
          } else {
            in_bounds_condition =
                ir_builder_.CreateAnd(in_bounds_condition, index_condition);
          }
        }
        CHECK(in_bounds_condition != nullptr);

        llvm_ir::LlvmIfData if_data = llvm_ir::EmitIfThenElse(
            in_bounds_condition, "in-bounds", &ir_builder_);
        SetToFirstInsertPoint(if_data.true_block, &ir_builder_);

        // We are not in the padding, so carry out the computation.
        // The reducer is called with (accumulator, input element) addresses;
        // its result is stored back into the accumulator.
        llvm_ir::IrArray input_array(GetIrArrayForOp(operand));
        llvm::Value* input_value_address =
            input_array.EmitArrayElementAddress(input_index, &ir_builder_);
        llvm::Value* result = EmitElementFunctionCall(
            reducer_function, reduce_window->shape(),
            {accumulator_address, input_value_address}, "reducer_function");
        ir_builder_.CreateStore(result, accumulator_address);

        // After the whole window has been visited, the accumulator holds this
        // output element's value.
        SetToFirstInsertPoint(loops.GetOuterLoopExitBasicBlock(), &ir_builder_);
        return ir_builder_.CreateLoad(accumulator_address);
      });
}
Status IrEmitter::HandleSelectAndScatter(HloInstruction* select_and_scatter) {
CHECK_EQ(select_and_scatter->operand_count(), 3);
const auto operand = select_and_scatter->operand(0);
const auto source = select_and_scatter->operand(1);
const auto init_value = select_and_scatter->operand(2);
const Window& window = select_and_scatter->window();
PrimitiveType operand_element_type = operand->shape().element_type();
const int64 rank = ShapeUtil::Rank(operand->shape());
CHECK_EQ(rank, ShapeUtil::Rank(source->shape()));
CHECK_EQ(rank, window.dimensions_size());
// TODO(b/31410564): Implement dilation for select-and-scatter.
if (window_util::HasDilation(window)) {
return Unimplemented(
"Dilation for select-and-scatter not implemented on CPU. "
"See b/31410564.");
}
// The select and scatter computations should have been emitted previously.
llvm::Function* select_function =
FindOrDie(emitted_functions_, select_and_scatter->select());
llvm::Function* scatter_function =
FindOrDie(emitted_functions_, select_and_scatter->scatter());
// Pseudo code for select-and-scatter:
//
// initialized_flag is initially off for every window, and is turned on after
// the first iteration is completed and the first operand value is selected.
//
// output(*) = init_value
// for (coordinates S in the source) {
// initialized_flag = false
// for (coordinates W in the window) {
// I = S * stride + W - pad_low
// if I within bounds of operand:
// if !initialized_flag or select(selected_value, operand(I)) == false:
// selected_value = operand(I)
// selected_index = I
// initialized_flag = true
// }
// output(selected_index) = scatter(output(selected_index), source(S))
// }
//
// Initialize the output array with the given init_value.
TF_RETURN_IF_ERROR(EmitTargetElementLoop(
select_and_scatter,
[this, init_value](const llvm_ir::IrArray::Index& target_index) {
llvm::Value* init_value_addr = GetEmittedValueFor(init_value);
return ir_builder_.CreateLoad(init_value_addr);
}));
// Create a loop to iterate over the source array to scatter to the output.
llvm_ir::ForLoopNest source_loops(&ir_builder_);
const llvm_ir::IrArray::Index source_index =
source_loops.AddLoopsForShape(source->shape(), "source");
SetToFirstInsertPoint(source_loops.GetInnerLoopBodyBasicBlock(),
&ir_builder_);
// Allocate space to keep the currently selected value, its index, and
// the boolean initialized_flag, which is initially set to false.
llvm::Value* selected_value_address = llvm_ir::EmitAllocaAtFunctionEntry(
llvm_ir::PrimitiveTypeToIrType(operand_element_type, &ir_builder_),
"selected_value_address", &ir_builder_,
MinimumAlignmentForPrimitiveType(operand_element_type));
llvm::Value* selected_index_address =
llvm_ir::EmitAllocaAtFunctionEntryWithCount(
ir_builder_.getInt64Ty(), ir_builder_.getInt32(rank),
"selected_index_address", &ir_builder_);
llvm::Value* initialized_flag_address = llvm_ir::EmitAllocaAtFunctionEntry(
ir_builder_.getInt1Ty(), "initialized_flag_address", &ir_builder_);
ir_builder_.CreateStore(ir_builder_.getInt1(false), initialized_flag_address);
// Create the inner loop to iterate over the window.
llvm_ir::ForLoopNest window_loops(&ir_builder_);
std::vector<int64> window_size;
for (const auto& dim : window.dimensions()) {
window_size.push_back(dim.size());
}
const llvm_ir::IrArray::Index window_index = window_loops.AddLoopsForShape(
ShapeUtil::MakeShape(operand_element_type, window_size), "window");
SetToFirstInsertPoint(window_loops.GetInnerLoopBodyBasicBlock(),
&ir_builder_);
// Compute the operand index to visit and evaluate the condition whether the
// operand index is within the bounds. The unsigned comparison includes
// checking whether the operand index >= 0.
llvm_ir::IrArray::Index operand_index(source_index.size());
llvm::Value* in_bounds_condition = ir_builder_.getInt1(true);
for (int64 i = 0; i < rank; ++i) {
llvm::Value* strided_index = ir_builder_.CreateNSWMul(
source_index[i], ir_builder_.getInt64(window.dimensions(i).stride()));
operand_index[i] = ir_builder_.CreateNSWSub(
ir_builder_.CreateNSWAdd(strided_index, window_index[i]),
ir_builder_.getInt64(window.dimensions(i).padding_low()));
llvm::Value* index_condition = ir_builder_.CreateICmpULT(
operand_index[i],
ir_builder_.getInt64(ShapeUtil::GetDimension(operand->shape(), i)));
in_bounds_condition =
ir_builder_.CreateAnd(in_bounds_condition, index_condition);
}
CHECK(in_bounds_condition != nullptr);
// Only need to do something if the operand index is within the bounds. First
// check if the initialized_flag is set.
llvm_ir::LlvmIfData if_in_bounds =
llvm_ir::EmitIfThenElse(in_bounds_condition, "in-bounds", &ir_builder_);
SetToFirstInsertPoint(if_in_bounds.true_block, &ir_builder_);
llvm_ir::LlvmIfData if_initialized =
llvm_ir::EmitIfThenElse(ir_builder_.CreateLoad(initialized_flag_address),
"initialized", &ir_builder_);
// If the initialized_flag is false, initialize the selected value and index
// with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.false_block, &ir_builder_);
const auto save_operand_index = [&](
const llvm_ir::IrArray::Index& operand_index) {
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
ir_builder_.CreateStore(operand_index[i], selected_index_address_slot);
}
};
llvm_ir::IrArray operand_array(GetIrArrayForOp(operand));
llvm::Value* operand_data =
operand_array.EmitReadArrayElement(operand_index, &ir_builder_);
ir_builder_.CreateStore(operand_data, selected_value_address);
save_operand_index(operand_index);
ir_builder_.CreateStore(ir_builder_.getInt1(true), initialized_flag_address);
// If the initialized_flag is true, call the `select` function to potentially
// update the selected value and index with the currently visiting operand.
SetToFirstInsertPoint(if_initialized.true_block, &ir_builder_);
const Shape output_shape = ShapeUtil::MakeShape(PRED, {});
llvm::Value* operand_address =
operand_array.EmitArrayElementAddress(operand_index, &ir_builder_);
llvm::Value* result = EmitElementFunctionCall(
select_function, output_shape, {selected_value_address, operand_address},
"select_function");
// If the 'select' function returns false, update the selected value and the
// index to the currently visiting operand.
llvm::Value* cond = ir_builder_.CreateICmpNE(
result, llvm::ConstantInt::get(
llvm_ir::PrimitiveTypeToIrType(PRED, &ir_builder_), 0),
"boolean_predicate");
llvm_ir::LlvmIfData if_select_lhs =
llvm_ir::EmitIfThenElse(cond, "if-select-lhs", &ir_builder_);
SetToFirstInsertPoint(if_select_lhs.false_block, &ir_builder_);
ir_builder_.CreateStore(ir_builder_.CreateLoad(operand_address),
selected_value_address);
save_operand_index(operand_index);
// After iterating over the window elements, scatter the source element to
// the selected index of the output. The value we store at the output
// location is computed by calling the `scatter` function with the source
// value and the current output value.
SetToFirstInsertPoint(window_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
llvm_ir::IrArray::Index selected_index;
for (int64 i = 0; i < rank; ++i) {
llvm::Value* selected_index_address_slot = ir_builder_.CreateInBoundsGEP(
selected_index_address, {ir_builder_.getInt32(i)});
selected_index.push_back(
ir_builder_.CreateLoad(selected_index_address_slot));
}
llvm_ir::IrArray source_array(GetIrArrayForOp(source));
llvm::Value* source_value_address =
source_array.EmitArrayElementAddress(source_index, &ir_builder_);
llvm_ir::IrArray output_array(GetIrArrayForOp(select_and_scatter));
llvm::Value* output_value_address =
output_array.EmitArrayElementAddress(selected_index, &ir_builder_);
llvm::Value* scatter_value = EmitElementFunctionCall(
scatter_function, source->shape(),
{output_value_address, source_value_address}, "scatter_function");
output_array.EmitWriteArrayElement(selected_index, scatter_value,
&ir_builder_);
SetToFirstInsertPoint(source_loops.GetOuterLoopExitBasicBlock(),
&ir_builder_);
return Status::OK();
}
// Emits IR for a dot instruction.
//
// Checks that `dot`, `lhs` and `rhs` share a single element type that is F32
// or F64, obtains the output buffer for `dot`, and delegates the actual code
// generation to DotOpEmitter::EmitDotOperation with neither operand
// transposed. On success the output buffer address is recorded in
// emitted_value_ under `dot`.
//
// Returns a non-OK status if the element-type check, the target address
// emission, or the dot emission fails; in that case emitted_value_ is not
// updated.
Status IrEmitter::HandleDot(HloInstruction* dot, HloInstruction* lhs,
                            HloInstruction* rhs) {
  TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
      /*instruction=*/*dot, /*operands=*/{lhs, rhs},
      /*supported_types=*/{F32, F64}));

  llvm_ir::IrArray lhs_array(GetIrArrayForOp(lhs));
  llvm_ir::IrArray rhs_array(GetIrArrayForOp(rhs));

  // Bind by const reference instead of copying the Shape message. This is
  // safe whether shape() returns a reference owned by `dot` or a value (a
  // temporary bound to a const reference lives as long as the reference).
  const Shape& target_shape = dot->shape();
  TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
                      EmitTargetAddressForOp(dot));
  llvm_ir::IrArray target_array(target_address, target_shape);
  AddAliasingInformationToIrArray(*dot, &target_array);

  VLOG(2) << "HandleDot: ";
  VLOG(2) << "  lhs operand: "
          << llvm_ir::DumpToString(*lhs_array.GetBasePointer());
  VLOG(2) << "  rhs operand: "
          << llvm_ir::DumpToString(*rhs_array.GetBasePointer());
  VLOG(2) << "  target: "
          << llvm_ir::DumpToString(*target_array.GetBasePointer());

  // Dot operation is complicated so we delegate to a helper class.
  TF_RETURN_IF_ERROR(DotOpEmitter::EmitDotOperation(
      *dot, /*transpose_lhs=*/false, /*transpose_rhs=*/false, target_array,
      lhs_array, rhs_array, GetExecutableRunOptionsArgument(), &ir_builder_,
      hlo_module_config_));

  // Record where the result lives so later emission can look it up.
  emitted_value_[dot] = target_address;
  return Status::OK();
}
Status IrEmitter::HandleConvolution(HloInstruction* convolution,
HloInstruction* lhs, HloInstruction* rhs,
const Window& window) {
TF_RETURN_IF_ERROR(ElementTypesSameAndSupported(
/*instruction=*/*convolution, /*operands=*/{lhs, rhs},
/*supported_types=*/{F32}));
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
if (PotentiallyImplementedAsEigenConvolution(*convolution)) {
const Shape& lhs_shape = lhs->shape();
const Shape& rhs_shape = rhs->shape();
const Shape& convolution_shape = convolution->shape();
// The input, kernel and output agree with respect to layout.
if (LayoutUtil::IsMonotonicWithDim0Major(lhs_shape.layout()) &&
LayoutUtil::IsMonotonicWithDim0Major(rhs_shape.layout()) &&
LayoutUtil::IsMonotonicWithDim0Major(convolution_shape.layout())) {
llvm::Value* lhs_address = GetEmittedValueFor(lhs);
llvm::Value* rhs_address = GetEmittedValueFor(rhs);
TF_ASSIGN_OR_RETURN(llvm::Value * target_address,
EmitTargetAddressForOp(convolution));
const ConvolutionDimensionNumbers& dnums =
convolution->convolution_dimension_numbers();
// Input tensor.
const Shape& input_shape = convolution->operand(0)->shape();
int64 input_batch = input_shape.dimensions(dnums.batch_dimension());
int64 input_rows = input_shape.dimensions(dnums.spatial_dimensions(0));
int64 input_cols = input_shape.dimensions(dnums.spatial_dimensions(1));
int64 input_channels = input_shape.dimensions(dnums.feature_dimension());
// Kernel tensor.
const Shape& kernel_shape = convolution->operand(1)->shape();
int64 kernel_rows =
kernel_shape.dimensions(dnums.kernel_spatial_dimensions(0));
int64 kernel_cols =
kernel_shape.dimensions(dnums.kernel_spatial_dimensions(1));
int64 kernel_channels =
kernel_shape.dimensions(dnums.kernel_input_feature_dimension());
int64 kernel_filters =
kernel_shape.dimensions(dnums.kernel_output_feature_dimension());
// Output tensor.
const Shape& convolution_shape = convolution->shape();
int64 output_rows =
convolution_shape.dimensions(dnums.spatial_dimensions(0));
int64 output_cols =
convolution_shape.dimensions(dnums.spatial_dimensions(1));
// Extract the window stride for the convolution.
const Window& window = convolution->window();
int64 row_stride = window.dimensions(0).stride();
int64 col_stride = window.dimensions(1).stride();
int64 padding_top = window.dimensions(0).padding_low();
int64 padding_bottom = window.dimensions(0).padding_high();
int64 padding_left = window.dimensions(1).padding_low();
int64 padding_right = window.dimensions(1).padding_high();
int64 lhs_row_dilation = window.dimensions(0).base_dilation();
int64 lhs_col_dilation = window.dimensions(1).base_dilation();
int64 rhs_row_dilation = window.dimensions(0).window_dilation();
int64 rhs_col_dilation = window.dimensions(1).window_dilation();
// Args have been computed, make the call.
llvm::Type* float_ptr_type = ir_builder_.getFloatTy()->getPointerTo();
llvm::Type* int64_type = ir_builder_.getInt64Ty();
llvm::Type* int8_ptr_type = ir_builder_.getInt8Ty()->getPointerTo();
llvm::FunctionType* conv_type = llvm::FunctionType::get(
ir_builder_.getVoidTy(),
{int8_ptr_type, float_ptr_type, float_ptr_type, float_ptr_type,
int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type,
int64_type, int64_type, int64_type, int64_type},
/*isVarArg=*/false);
bool multi_threaded_eigen =
hlo_module_config_.debug_options().xla_cpu_multi_thread_eigen();
const char* fn_name =
(multi_threaded_eigen
? runtime::kEigenConvF32SymbolName
: runtime::kEigenSingleThreadedConvF32SymbolName);
llvm::Function* conv_func = llvm::cast<llvm::Function>(
module_->getOrInsertFunction(fn_name, conv_type));
conv_func->setCallingConv(llvm::CallingConv::C);
conv_func->setDoesNotThrow();
conv_func->setOnlyAccessesArgMemory();
ir_builder_.CreateCall(
conv_func,
{
GetExecutableRunOptionsArgument(),
ir_builder_.CreateBitCast(target_address, float_ptr_type),
ir_builder_.CreateBitCast(lhs_address, float_ptr_type),
ir_builder_.CreateBitCast(rhs_address, float_ptr_type),
ir_builder_.getInt64(input_batch),
ir_builder_.getInt64(input_rows),
ir_builder_.getInt64(input_cols),
ir_builder_.getInt64(input_channels),
ir_builder_.getInt64(kernel_rows),
ir_builder_.getInt64(kernel_cols),
ir_builder_.getInt64(kernel_channels),
ir_builder_.getInt64(kernel_filters),
ir_builder_.getInt64(output_rows),
ir_builder_.getInt64(output_cols),
ir_builder_.getInt64(row_stride),
ir_builder_.getInt64(col_stride),
ir_builder_.getInt64(padding_top),
ir_builder_.getInt64(padding_bottom),
ir_builder_.getInt64(padding_left),
ir_builder_.getInt64(padding_right),
ir_builder_.getInt64(lhs_row_dilation),
ir_builder_.getInt64(lhs_col_dilation),
ir_builder_.getInt64(rhs_row_dilation),
ir_builder_.getInt64(rhs_col_dilation),
});
emitted_value_[convolution] = target_address;