[XLA:TPU] HostOffloader: Refactor HandleInputStreaming to occur before tracing the graph for custom calls.

Currently, input streaming is handled in an arbitrary order, possibly after we have already searched the graph for the corresponding streamed custom calls. This change moves input streaming so that it runs before any graph searching, which is a more natural ordering and prepares for output streaming.

PiperOrigin-RevId: 629208199
jvstokes authored and tensorflower-gardener committed May 9, 2024
1 parent 3b3e62e commit 9485229
Showing 14 changed files with 87 additions and 30 deletions.
13 changes: 9 additions & 4 deletions tensorflow/core/data/service/client/data_service_client.cc
@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/data/utils.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/platform/env.h"
@@ -102,7 +103,10 @@ DataServiceClient::~DataServiceClient() {
<< iteration_client_id_;
}

Status DataServiceClient::Initialize(Allocator* allocator) {
Status DataServiceClient::Initialize(
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator) {
accelerator_device_info_ = accelerator_device_info;
allocator_ = allocator;
TF_RETURN_IF_ERROR(ValidateDataServiceParams(params_));
VLOG(3) << "Connecting to " << params_.address
@@ -343,7 +347,7 @@ DataServiceClient::CreateWorkerClient(const std::string& protocol,
TF_ASSIGN_OR_RETURN(DataTransferServerInfo transfer_server,
GetTransferServer(protocol, task_info));
return CreateDataServiceWorkerClient(params_.protocol, transfer_server,
allocator_);
accelerator_device_info_, allocator_);
}

absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>>
@@ -356,7 +360,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback(
const DataTransferServerInfo& transfer_server, const TaskInfo& task_info) {
absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>> worker =
CreateDataServiceWorkerClient(params_.protocol, transfer_server,
allocator_);
accelerator_device_info_, allocator_);
if (worker.ok()) {
LOG(INFO) << "Successfully started client for data transfer protocol '"
<< transfer_server.protocol() << "' for worker '"
@@ -383,7 +387,8 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) {
DataTransferServerInfo info;
info.set_protocol(kLocalTransferProtocol);
info.set_address(task_info.worker_address());
return CreateDataServiceWorkerClient(params_.protocol, info, allocator_);
return CreateDataServiceWorkerClient(params_.protocol, info,
accelerator_device_info_, allocator_);
}
if (!params_.data_transfer_protocol.empty()) {
TF_ASSIGN_OR_RETURN(
5 changes: 4 additions & 1 deletion tensorflow/core/data/service/client/data_service_client.h
@@ -80,7 +80,9 @@ class DataServiceClient {
DataServiceClient& operator=(const DataServiceClient&) = delete;

// Initializes the client.
Status Initialize(Allocator* allocator);
Status Initialize(
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator);

// Reads the next element from tf.data workers. Blocks if the next element is
// not ready.
@@ -246,6 +248,7 @@ class DataServiceClient {
int64_t job_id_;
int64_t iteration_client_id_;
std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0;
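For context, a minimal sketch of how a caller might thread the new argument through to Initialize. InitFromIteratorContext and device_info are illustrative names, not part of this change; the sketch assumes the tensorflow::data namespace.

#include "tensorflow/core/data/service/client/data_service_client.h"
#include "tensorflow/core/framework/dataset.h"

Status InitFromIteratorContext(IteratorContext* ctx, DataServiceClient& client) {
  // May be nullptr when the iterator is not co-located with a GPU/TPU.
  const DeviceBase::AcceleratorDeviceInfo* device_info =
      ctx->accelerator_device_info();
  Allocator* allocator = ctx->allocator(/*attrs=*/{});
  return client.Initialize(device_info, allocator);
}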
18 changes: 12 additions & 6 deletions tensorflow/core/data/service/client/data_service_client_test.cc
@@ -134,7 +134,8 @@ TEST(DataServiceClientTest, NoSharding) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(ElementsAreArray(Range(10))));
client.Cancel();
@@ -150,7 +151,8 @@ TEST(DataServiceClientTest, DynamicSharding) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::DYNAMIC);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
@@ -167,7 +169,8 @@ TEST(DataServiceClientTest, StaticSharding) {
GetDataServiceParams(dataset_id, test_cluster.DispatcherAddress(),
ProcessingModeDef::FILE_OR_DATA);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
@@ -183,7 +186,8 @@ TEST(DataServiceClientTest, RecordBufferEvents) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));

auto mock_context = std::make_unique<TestDataServiceContext>();
TestDataServiceContext* ctx = mock_context.get();
@@ -206,7 +210,8 @@ TEST(DataServiceClientTest, Cancel) {
DataServiceParams params = GetDataServiceParams(
dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
DataServiceClient client(params);
TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr));
client.Cancel();
EXPECT_THAT(client.GetNext(GetTestDataServiceContext),
StatusIs(error::CANCELLED));
@@ -218,7 +223,8 @@ TEST(DataServiceClientTest, ValidationError) {
params.target_workers = TARGET_WORKERS_LOCAL;
DataServiceClient client(params);
EXPECT_THAT(
client.Initialize(/*allocator=*/nullptr),
client.Initialize(/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr),
StatusIs(
error::INVALID_ARGUMENT,
HasSubstr(
1 change: 1 addition & 0 deletions tensorflow/core/data/service/data_transfer.h
@@ -70,6 +70,7 @@ class DataTransferClient {
struct Config {
absl::string_view protocol;
std::string address;
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info;
Allocator* allocator;
};
using ClientFactoryT =
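Since Config is a plain aggregate, call sites build it positionally (as the worker_client.cc hunk below does with its member fields). A hedged sketch with placeholder values and an illustrative helper name, assuming the tensorflow::data namespace:

#include "tensorflow/core/data/service/data_transfer.h"

// Field order matches the struct above: protocol, address,
// accelerator_device_info, allocator.
DataTransferClient::Config MakeConfig(
    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
    Allocator* allocator) {
  return {/*protocol=*/"grpc", /*address=*/"localhost:5050",
          accelerator_device_info, allocator};
}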
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/test_cluster.h
@@ -168,7 +168,8 @@ DatasetClient<T>::DatasetClient(const TestCluster& cluster)
for (size_t i = 0; i < cluster.NumWorkers(); ++i) {
worker_clients_[cluster_.WorkerAddress(i)] =
std::make_unique<DataServiceWorkerClient>(
cluster_.WorkerAddress(i), "grpc", "grpc", /*allocator=*/nullptr);
cluster_.WorkerAddress(i), "grpc", "grpc",
/*accelerator_device_info=*/nullptr, /*allocator=*/nullptr);
}
}

14 changes: 9 additions & 5 deletions tensorflow/core/data/service/worker_client.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/data/service/worker_impl.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.pb.h"
#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -56,11 +57,13 @@ namespace tensorflow {
namespace data {

StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
const DataTransferServerInfo& info,
Allocator* allocator) {
CreateDataServiceWorkerClient(
const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator) {
auto client = std::make_unique<DataServiceWorkerClient>(
info.address(), dispatcher_protocol, info.protocol(), allocator);
info.address(), dispatcher_protocol, info.protocol(),
accelerator_device_info, allocator);
TF_RETURN_IF_ERROR(client->Initialize());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
client->CheckCompatibility(info.compatibility_info()),
@@ -82,7 +85,8 @@ Status DataServiceWorkerClient::EnsureInitialized() {
return absl::OkStatus();
}
TF_RETURN_IF_ERROR(DataTransferClient::Build(
GetDataTransferProtocol(), {protocol_, address_, allocator_}, &client_));
GetDataTransferProtocol(),
{protocol_, address_, accelerator_device_info_, allocator_}, &client_));
return absl::OkStatus();
}

18 changes: 11 additions & 7 deletions tensorflow/core/data/service/worker_client.h
@@ -37,12 +37,14 @@ constexpr const char kGrpcTransferProtocol[] = "grpc";
// Client for communicating with the tf.data service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
public:
DataServiceWorkerClient(const std::string& address,
const std::string& protocol,
const std::string& transfer_protocol,
Allocator* allocator)
DataServiceWorkerClient(
const std::string& address, const std::string& protocol,
const std::string& transfer_protocol,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator)
: DataServiceClientBase(address, protocol),
transfer_protocol_(transfer_protocol),
accelerator_device_info_(accelerator_device_info),
allocator_(allocator) {}

// Fetches an element from the worker.
@@ -66,6 +68,7 @@ class DataServiceWorkerClient : public DataServiceClientBase {

private:
std::string transfer_protocol_;
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

mutex mu_;
@@ -77,9 +80,10 @@ class DataServiceWorkerClient : public DataServiceClientBase {
// Creates and initializes a new tf.data service worker client to read
// from the data transfer server specified in `info`.
StatusOr<std::unique_ptr<DataServiceWorkerClient>>
CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
const DataTransferServerInfo& info,
Allocator* allocator);
CreateDataServiceWorkerClient(
const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
Allocator* allocator);

} // namespace data
} // namespace tensorflow
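A hedged sketch of calling the updated factory. MakeWorkerClient is an illustrative wrapper, and, as the test change below shows, nullptr is an acceptable accelerator_device_info when no accelerator is involved. Assumes the tensorflow::data namespace.

#include "tensorflow/core/data/service/worker_client.h"

absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>> MakeWorkerClient(
    const DataTransferServerInfo& transfer_server,
    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
    Allocator* allocator) {
  // The dispatcher protocol is separate from the data transfer protocol
  // carried inside `transfer_server`.
  return CreateDataServiceWorkerClient(/*dispatcher_protocol=*/"grpc",
                                       transfer_server,
                                       accelerator_device_info, allocator);
}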
1 change: 1 addition & 0 deletions tensorflow/core/data/service/worker_client_test.cc
@@ -107,6 +107,7 @@ class WorkerClientTest : public ::testing::Test {
info.set_address(GetWorkerAddress());
info.set_protocol(data_transfer_protocol);
return CreateDataServiceWorkerClient(kProtocol, info,
/*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr);
}

11 changes: 10 additions & 1 deletion tensorflow/core/framework/dataset.h
@@ -667,7 +667,8 @@ class IteratorContext {
public:
struct Params {
explicit Params(IteratorContext* ctx)
: allocator_getter(ctx->allocator_getter()),
: accelerator_device_info(ctx->accelerator_device_info()),
allocator_getter(ctx->allocator_getter()),
cancellation_manager(ctx->cancellation_manager()),
collective_executor(ctx->collective_executor()),
env(ctx->env()),
@@ -697,6 +698,7 @@ class IteratorContext {
// NOTE: need reinterpret_cast because function.h forward-declares Device.
DeviceBase* device =
reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
accelerator_device_info = device->tensorflow_accelerator_device_info();
allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
@@ -719,6 +721,9 @@ class IteratorContext {
*ctx->runner(), std::placeholders::_1);
}

// If non-null, information about the GPU or TPU on which the op is placed.
const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = nullptr;

// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

@@ -825,6 +830,10 @@ class IteratorContext {
return params_.id_registry;
}

const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info() {
return params_.accelerator_device_info;
}

Allocator* allocator(AllocatorAttributes attrs) {
return params_.allocator_getter(attrs);
}
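With the new accessor, downstream code holding an IteratorContext can ask whether the op sits next to a GPU/TPU. A tiny illustrative helper (IsOnAccelerator is not part of this change):

#include "tensorflow/core/framework/dataset.h"

bool IsOnAccelerator(tensorflow::data::IteratorContext* ctx) {
  // accelerator_device_info() returns nullptr on CPU-only placements.
  return ctx->accelerator_device_info() != nullptr;
}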
@@ -340,7 +340,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
TF_RETURN_IF_ERROR(RegisterCancellationCallback(
ctx->cancellation_manager(),
[this]() { data_service_client_.Cancel(); }, &deregister_fn_));
return data_service_client_.Initialize(ctx->allocator(/*attrs=*/{}));
return data_service_client_.Initialize(ctx->accelerator_device_info(),
ctx->allocator(/*attrs=*/{}));
}

Status GetNextInternal(IteratorContext* ctx,
2 changes: 1 addition & 1 deletion tensorflow/core/profiler/protobuf/task.proto
@@ -9,7 +9,7 @@ option cc_enable_arenas = true;
// 'Task' contains information about a task that profiler traced.
message Task {
// The most recent changelist number from the client that built the binary.
optional int32 changelist = 1;
optional int64 changelist = 1;
// True if the client that built the binary was mint (no local changes).
optional bool clean_build = 2;
// Build time (in ns relative to the Unix epoch).
11 changes: 11 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -506,6 +506,17 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
return success();
}

@@ -3861,8 +3819,8 @@
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
- int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t shapeCount = std::accumulate(
+ shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
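The std::accumulate tweak in this patch is a 64-bit correctness fix: the accumulator type is deduced from the init argument, so a plain 1 keeps the running product in int and large shape products get narrowed back down. A standalone illustration, not taken from the patch:

#include <cstdint>
#include <functional>
#include <iostream>
#include <numeric>
#include <vector>

int main() {
  std::vector<int64_t> shape = {1 << 20, 1 << 20};  // 2^40 elements in total.

  // With an int init, each intermediate product is narrowed back to int.
  int64_t truncated = std::accumulate(shape.begin(), shape.end(), 1,
                                      std::multiplies<int64_t>());

  // With an int64_t init, the whole computation stays in 64-bit arithmetic.
  int64_t correct = std::accumulate(shape.begin(), shape.end(), int64_t{1},
                                    std::multiplies<int64_t>());

  // Typically prints "0 vs 1099511627776".
  std::cout << truncated << " vs " << correct << "\n";
}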
11 changes: 11 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch
@@ -506,6 +506,17 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
return success();
}

@@ -3861,8 +3819,8 @@
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
- int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t shapeCount = std::accumulate(
+ shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/host_offloader.cc
@@ -755,12 +755,12 @@ absl::StatusOr<bool> HostOffloader::Run(
// Run HloAliasAnalysis on module.
TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

// Handle streamed parameters first.
TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));

// Iterate over all instructions and look for XLA host offload annotations.
for (HloComputation* computation :
module->MakeNonfusionComputations(execution_threads)) {
if (computation->IsEntryComputation()) {
TF_RETURN_IF_ERROR(HandleInputStreaming(computation));
}
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kCustomCall) {
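Putting the hunk above in context, a simplified reconstruction of the pass's new control flow. This is a sketch, not the verbatim implementation: HandleInputStreaming, alias_analysis_, and the loop structure appear in the diff, while the custom-call handling body and everything else is abbreviated.

absl::StatusOr<bool> HostOffloader::Run(
    HloModule* module,
    const absl::flat_hash_set<absl::string_view>& execution_threads) {
  TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

  // Streamed entry parameters are rewritten first...
  TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));

  // ...so the scan for host-offload custom calls already sees the result.
  bool changed = false;
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
    for (HloInstruction* instruction :
         computation->MakeInstructionPostOrder()) {
      if (instruction->opcode() != HloOpcode::kCustomCall) {
        continue;
      }
      // Handle the offload/load annotation here and set `changed` on rewrite.
    }
  }
  return changed;
}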
