Skip to content

Commit

Permalink
[xla:spmd] Fix a problem in generating code for concat that requires …
Browse files Browse the repository at this point in the history
…padding.

When an internal concat operand requires padding, previously, we don't zero out
the padded elements and the generated code produces incorrect results.

Fix two existing tests.

PiperOrigin-RevId: 634066344
  • Loading branch information
bixia1 authored and Copybara-Service committed May 15, 2024
1 parent ee7382e commit a9e6e73
Show file tree
Hide file tree
Showing 2 changed files with 12 additions and 3 deletions.
3 changes: 2 additions & 1 deletion xla/service/spmd/spmd_partitioner.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2581,7 +2581,8 @@ Status SpmdPartitioningVisitor::HandleConcatenate(HloInstruction* hlo) {
int64_t offset = 0;
auto state = MakePartitioningState();
for (HloInstruction* operand : hlo->operands()) {
auto spmd_operand = GetPartitionedHlo(operand).Reshard(sharding).hlo();
auto spmd_operand =
GetPartitionedHlo(operand).Reshard(sharding).PadWithZero().hlo();
std::vector<HloInstruction*> start_indices(
hlo->shape().rank(), b_.AddInstruction(HloInstruction::CreateConstant(
LiteralUtil::Zero(S32))));
Expand Down
12 changes: 10 additions & 2 deletions xla/service/spmd/spmd_partitioner_test.cc
Original file line number Diff line number Diff line change
Expand Up @@ -2207,13 +2207,17 @@ ENTRY entry {
AllOf(op::Copy(op::DynamicSlice(op::Pad(op::Parameter(), op::Constant()),
op::Constant(), op::Reshape())),
op::Shape("f32[14,129]"));
auto param0_adjusted =
AllOf(op::Select(op::Compare(op::Add(), op::Broadcast(op::Constant())),
param0, op::Broadcast(op::Constant())),
op::Shape("f32[14,129]"));
auto param1 = AllOf(op::Copy(op::DynamicSlice(op::Parameter(), op::Constant(),
op::Reshape())),
op::Shape("f32[14,58]"));
EXPECT_THAT(root, AllOf(op::DynamicSlice(
AllOf(op::AllReduce(op::DynamicUpdateSlice(
op::DynamicUpdateSlice(
op::Broadcast(), param0,
op::Broadcast(), param0_adjusted,
op::Constant(), op::Multiply()),
param1, op::Constant(), op::Add())),
op::Shape("f32[14,374]")),
Expand All @@ -2238,11 +2242,15 @@ ENTRY entry {

const auto root = module->entry_computation()->root_instruction();
auto param0 = AllOf(op::Parameter(0), op::Shape("f32[7,129]"));
auto param0_adjusted =
AllOf(op::Select(op::Compare(op::Add(), op::Broadcast(op::Constant())),
param0, op::Broadcast(op::Constant())),
op::Shape("f32[7,129]"));
auto param1 = AllOf(op::Parameter(1), op::Shape("f32[7,58]"));
EXPECT_THAT(root, AllOf(op::DynamicSlice(
AllOf(op::AllReduce(op::DynamicUpdateSlice(
op::DynamicUpdateSlice(
op::Broadcast(), param0,
op::Broadcast(), param0_adjusted,
op::Constant(), op::Multiply()),
param1, op::Constant(), op::Add())),
op::Shape("f32[7,374]")),
Expand Down

0 comments on commit a9e6e73

Please sign in to comment.