Skip to content

Commit

Permalink
[XLA:TPU] HostOffloader: Refactor HandleInputStreaming to occur befor…
Browse files Browse the repository at this point in the history
…e tracing the graph for custom calls.

Currently, input streaming gets handled in arbitrary order, after we have already possibly searched the graph for the corresponding streamed custom calls. This refactors the input streaming to occur before any graph searching, which makes more sense and prepares for output streaming.

PiperOrigin-RevId: 629208199
  • Loading branch information
jvstokes authored and tensorflower-gardener committed Apr 29, 2024
1 parent b1dce7f commit 0a5b586
Show file tree
Hide file tree
Showing 2 changed files with 8 additions and 5 deletions.
7 changes: 5 additions & 2 deletions third_party/xla/xla/python/py_compile_only_client.cc
Original file line number Diff line number Diff line change
Expand Up @@ -60,6 +60,7 @@ limitations under the License.
#include "xla/python/ifrt/tuple.h"
#include "xla/python/ifrt/value.h"
#include "xla/python/nb_class_ptr.h"
#include "xla/python/pjrt_ifrt/pjrt_array.h"
#include "xla/python/py_client.h"
#include "xla/service/computation_placer.h"
#include "xla/tsl/concurrency/ref_count.h"
Expand Down Expand Up @@ -231,8 +232,10 @@ class CompileOnlyIfRtClient final
absl::StatusOr<std::unique_ptr<PjRtLayout>> GetDefaultLayoutForDevice(
ifrt::DType dtype, absl::Span<const int64_t> dims,
ifrt::Device* device) const override {
return absl::UnimplementedError(
"GetDefaultLayout not supported for CompileOnlyIfRtClient.");
TF_ASSIGN_OR_RETURN(PrimitiveType element_type, ToPrimitiveType(dtype));
TF_ASSIGN_OR_RETURN(xla::Layout layout,
topology_->GetDefaultLayout(element_type, dims));
return std::make_unique<PjRtXlaLayout>(std::move(layout));
}

private:
Expand Down
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/host_offloader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -761,12 +761,12 @@ absl::StatusOr<bool> HostOffloader::Run(
// Run HloAliasAnalysis on module.
TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

// Handle streamed parameters first.
TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));

// Iterate over all instructions and look for XLA host offload annotations.
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
if (computation->IsEntryComputation()) {
TF_RETURN_IF_ERROR(HandleInputStreaming(computation));
}
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kCustomCall) {
Expand Down

0 comments on commit 0a5b586

Please sign in to comment.