[XLA:TPU] HostOffloader: Refactor HandleInputStreaming to occur before tracing the graph for custom calls. #66656

Draft · wants to merge 1 commit into base: master
13 changes: 9 additions & 4 deletions tensorflow/core/data/service/client/data_service_client.cc
@@ -42,6 +42,7 @@ limitations under the License.
#include "tensorflow/core/data/utils.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.h"
+#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/model.h"
#include "tensorflow/core/platform/env.h"
@@ -102,7 +103,10 @@ DataServiceClient::~DataServiceClient() {
          << iteration_client_id_;
}

-Status DataServiceClient::Initialize(Allocator* allocator) {
+Status DataServiceClient::Initialize(
+    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
+    Allocator* allocator) {
+  accelerator_device_info_ = accelerator_device_info;
  allocator_ = allocator;
  TF_RETURN_IF_ERROR(ValidateDataServiceParams(params_));
  VLOG(3) << "Connecting to " << params_.address
@@ -343,7 +347,7 @@ DataServiceClient::CreateWorkerClient(const std::string& protocol,
  TF_ASSIGN_OR_RETURN(DataTransferServerInfo transfer_server,
                      GetTransferServer(protocol, task_info));
  return CreateDataServiceWorkerClient(params_.protocol, transfer_server,
-                                       allocator_);
+                                       accelerator_device_info_, allocator_);
}

absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>>
@@ -356,7 +360,7 @@ DataServiceClient::CreateAlternativeWorkerClientWithGrpcFallback(
    const DataTransferServerInfo& transfer_server, const TaskInfo& task_info) {
  absl::StatusOr<std::unique_ptr<DataServiceWorkerClient>> worker =
      CreateDataServiceWorkerClient(params_.protocol, transfer_server,
-                                   allocator_);
+                                   accelerator_device_info_, allocator_);
if (worker.ok()) {
LOG(INFO) << "Successfully started client for data transfer protocol '"
<< transfer_server.protocol() << "' for worker '"
@@ -383,7 +387,8 @@ DataServiceClient::CreateWorkerClient(const TaskInfo& task_info) {
    DataTransferServerInfo info;
    info.set_protocol(kLocalTransferProtocol);
    info.set_address(task_info.worker_address());
-    return CreateDataServiceWorkerClient(params_.protocol, info, allocator_);
+    return CreateDataServiceWorkerClient(params_.protocol, info,
+                                         accelerator_device_info_, allocator_);
}
if (!params_.data_transfer_protocol.empty()) {
TF_ASSIGN_OR_RETURN(
5 changes: 4 additions & 1 deletion tensorflow/core/data/service/client/data_service_client.h
@@ -80,7 +80,9 @@ class DataServiceClient {
  DataServiceClient& operator=(const DataServiceClient&) = delete;

  // Initializes the client.
-  Status Initialize(Allocator* allocator);
+  Status Initialize(
+      const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
+      Allocator* allocator);

// Reads the next element from tf.data workers. Blocks if the next element is
// not ready.
@@ -246,6 +248,7 @@ class DataServiceClient {
  int64_t job_id_;
  int64_t iteration_client_id_;
  std::unique_ptr<DataServiceDispatcherClient> dispatcher_;
+  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

int64_t get_next_index_ TF_GUARDED_BY(mu_) = 0;
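Note (not part of the diff): with the widened signature, a caller that already holds an IteratorContext can forward both the device info and the allocator it owns; both arguments may legitimately be null, as in the host-only tests below. A minimal sketch of such a call site, mirroring the dataset-op change further down in this PR:

// Sketch only: forward an iterator's device info and allocator into the
// data service client.
tensorflow::Status InitDataServiceClient(
    tensorflow::data::DataServiceClient& client,
    tensorflow::data::IteratorContext* ctx) {
  return client.Initialize(ctx->accelerator_device_info(),
                           ctx->allocator(/*attrs=*/{}));
}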
18 changes: 12 additions & 6 deletions tensorflow/core/data/service/client/data_service_client_test.cc
@@ -134,7 +134,8 @@ TEST(DataServiceClientTest, NoSharding) {
  DataServiceParams params = GetDataServiceParams(
      dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
  DataServiceClient client(params);
-  TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
+  TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
+                                 /*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(ElementsAreArray(Range(10))));
client.Cancel();
@@ -150,7 +151,8 @@ TEST(DataServiceClientTest, DynamicSharding) {
  DataServiceParams params = GetDataServiceParams(
      dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::DYNAMIC);
  DataServiceClient client(params);
-  TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
+  TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
+                                 /*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
@@ -167,7 +169,8 @@ TEST(DataServiceClientTest, StaticSharding) {
      GetDataServiceParams(dataset_id, test_cluster.DispatcherAddress(),
                           ProcessingModeDef::FILE_OR_DATA);
  DataServiceClient client(params);
-  TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
+  TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
+                                 /*allocator=*/nullptr));
EXPECT_THAT(GetResults<int64_t>(client),
IsOkAndHolds(UnorderedElementsAreArray(Range(10))));
client.Cancel();
@@ -183,7 +186,8 @@ TEST(DataServiceClientTest, RecordBufferEvents) {
  DataServiceParams params = GetDataServiceParams(
      dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
  DataServiceClient client(params);
-  TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
+  TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
+                                 /*allocator=*/nullptr));

auto mock_context = std::make_unique<TestDataServiceContext>();
TestDataServiceContext* ctx = mock_context.get();
@@ -206,7 +210,8 @@ TEST(DataServiceClientTest, Cancel) {
  DataServiceParams params = GetDataServiceParams(
      dataset_id, test_cluster.DispatcherAddress(), ProcessingModeDef::OFF);
  DataServiceClient client(params);
-  TF_ASSERT_OK(client.Initialize(/*allocator=*/nullptr));
+  TF_ASSERT_OK(client.Initialize(/*accelerator_device_info=*/nullptr,
+                                 /*allocator=*/nullptr));
client.Cancel();
EXPECT_THAT(client.GetNext(GetTestDataServiceContext),
StatusIs(error::CANCELLED));
@@ -218,7 +223,8 @@ TEST(DataServiceClientTest, ValidationError) {
  params.target_workers = TARGET_WORKERS_LOCAL;
  DataServiceClient client(params);
  EXPECT_THAT(
-      client.Initialize(/*allocator=*/nullptr),
+      client.Initialize(/*accelerator_device_info=*/nullptr,
+                        /*allocator=*/nullptr),
StatusIs(
error::INVALID_ARGUMENT,
HasSubstr(
1 change: 1 addition & 0 deletions tensorflow/core/data/service/data_transfer.h
@@ -70,6 +70,7 @@ class DataTransferClient {
  struct Config {
    absl::string_view protocol;
    std::string address;
+    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info;
Allocator* allocator;
};
using ClientFactoryT =
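Illustrative note (not in the diff): the new Config field is populated wherever a DataTransferClient is built. A minimal sketch with placeholder protocol and address values; the brace-initializer order follows the struct above, and a null accelerator_device_info simply means the consumer is not an accelerator:

// Sketch only: building a transfer client with the extended Config.
std::unique_ptr<DataTransferClient> transfer_client;
TF_RETURN_IF_ERROR(DataTransferClient::Build(
    "grpc",
    {/*protocol=*/"grpc", /*address=*/"localhost:2222",
     /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr},
    &transfer_client));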
3 changes: 2 additions & 1 deletion tensorflow/core/data/service/test_cluster.h
@@ -168,7 +168,8 @@ DatasetClient<T>::DatasetClient(const TestCluster& cluster)
  for (size_t i = 0; i < cluster.NumWorkers(); ++i) {
    worker_clients_[cluster_.WorkerAddress(i)] =
        std::make_unique<DataServiceWorkerClient>(
-            cluster_.WorkerAddress(i), "grpc", "grpc", /*allocator=*/nullptr);
+            cluster_.WorkerAddress(i), "grpc", "grpc",
+            /*accelerator_device_info=*/nullptr, /*allocator=*/nullptr);
}
}

14 changes: 9 additions & 5 deletions tensorflow/core/data/service/worker_client.cc
@@ -38,6 +38,7 @@ limitations under the License.
#include "tensorflow/core/data/service/worker_impl.h"
#include "tensorflow/core/framework/allocator.h"
#include "tensorflow/core/framework/dataset.pb.h"
+#include "tensorflow/core/framework/device_base.h"
#include "tensorflow/core/framework/metrics.h"
#include "tensorflow/core/framework/tensor.h"
#include "tensorflow/core/framework/tensor_shape.h"
@@ -56,11 +57,13 @@ namespace tensorflow {
namespace data {

StatusOr<std::unique_ptr<DataServiceWorkerClient>>
-CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
-                              const DataTransferServerInfo& info,
-                              Allocator* allocator) {
+CreateDataServiceWorkerClient(
+    const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
+    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
+    Allocator* allocator) {
  auto client = std::make_unique<DataServiceWorkerClient>(
-      info.address(), dispatcher_protocol, info.protocol(), allocator);
+      info.address(), dispatcher_protocol, info.protocol(),
+      accelerator_device_info, allocator);
TF_RETURN_IF_ERROR(client->Initialize());
TF_RETURN_WITH_CONTEXT_IF_ERROR(
client->CheckCompatibility(info.compatibility_info()),
@@ -82,7 +85,8 @@ Status DataServiceWorkerClient::EnsureInitialized() {
    return absl::OkStatus();
  }
  TF_RETURN_IF_ERROR(DataTransferClient::Build(
-      GetDataTransferProtocol(), {protocol_, address_, allocator_}, &client_));
+      GetDataTransferProtocol(),
+      {protocol_, address_, accelerator_device_info_, allocator_}, &client_));
return absl::OkStatus();
}

18 changes: 11 additions & 7 deletions tensorflow/core/data/service/worker_client.h
@@ -37,12 +37,14 @@ constexpr const char kGrpcTransferProtocol[] = "grpc";
// Client for communicating with the tf.data service worker.
class DataServiceWorkerClient : public DataServiceClientBase {
 public:
-  DataServiceWorkerClient(const std::string& address,
-                          const std::string& protocol,
-                          const std::string& transfer_protocol,
-                          Allocator* allocator)
+  DataServiceWorkerClient(
+      const std::string& address, const std::string& protocol,
+      const std::string& transfer_protocol,
+      const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
+      Allocator* allocator)
      : DataServiceClientBase(address, protocol),
        transfer_protocol_(transfer_protocol),
+        accelerator_device_info_(accelerator_device_info),
        allocator_(allocator) {}

// Fetches an element from the worker.
@@ -66,6 +68,7 @@ class DataServiceWorkerClient : public DataServiceClientBase {

 private:
  std::string transfer_protocol_;
+  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info_;
Allocator* allocator_;

mutex mu_;
@@ -77,9 +80,10 @@
// Creates and initializes a new tf.data service worker client to read
// from the data transfer server specified in `info`.
StatusOr<std::unique_ptr<DataServiceWorkerClient>>
-CreateDataServiceWorkerClient(const std::string& dispatcher_protocol,
-                              const DataTransferServerInfo& info,
-                              Allocator* allocator);
+CreateDataServiceWorkerClient(
+    const std::string& dispatcher_protocol, const DataTransferServerInfo& info,
+    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info,
+    Allocator* allocator);

} // namespace data
} // namespace tensorflow
1 change: 1 addition & 0 deletions tensorflow/core/data/service/worker_client_test.cc
@@ -107,6 +107,7 @@ class WorkerClientTest : public ::testing::Test {
    info.set_address(GetWorkerAddress());
    info.set_protocol(data_transfer_protocol);
    return CreateDataServiceWorkerClient(kProtocol, info,
+                                         /*accelerator_device_info=*/nullptr,
/*allocator=*/nullptr);
}

11 changes: 10 additions & 1 deletion tensorflow/core/framework/dataset.h
@@ -667,7 +667,8 @@ class IteratorContext {
 public:
  struct Params {
    explicit Params(IteratorContext* ctx)
-        : allocator_getter(ctx->allocator_getter()),
+        : accelerator_device_info(ctx->accelerator_device_info()),
+          allocator_getter(ctx->allocator_getter()),
cancellation_manager(ctx->cancellation_manager()),
collective_executor(ctx->collective_executor()),
env(ctx->env()),
@@ -697,6 +698,7 @@
      // NOTE: need reinterpret_cast because function.h forward-declares Device.
      DeviceBase* device =
          reinterpret_cast<DeviceBase*>(ctx->function_library()->device());
+      accelerator_device_info = device->tensorflow_accelerator_device_info();
allocator_getter = [device](AllocatorAttributes attrs) {
return device->GetAllocator(attrs);
};
@@ -719,6 +721,9 @@
          *ctx->runner(), std::placeholders::_1);
    }

+    // If non-null, information about the GPU or TPU on which the op is placed.
+    const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info = nullptr;
+
// The Allocator to be used to allocate the output of an iterator.
std::function<Allocator*(AllocatorAttributes)> allocator_getter = nullptr;

@@ -825,6 +830,10 @@
    return params_.id_registry;
  }

+  const DeviceBase::AcceleratorDeviceInfo* accelerator_device_info() {
+    return params_.accelerator_device_info;
+  }
+
Allocator* allocator(AllocatorAttributes attrs) {
return params_.allocator_getter(attrs);
}
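A possible consumer of the new accessor (hypothetical, not part of this change): an iterator could use the device info to decide whether to request accelerator-compatible host memory from the allocator getter. Sketch, assuming the tensorflow::data namespace:

// Sketch only: pick allocator attributes based on whether the op runs next to
// an accelerator; set_gpu_compatible requests pinned host memory.
Allocator* ChooseOutputAllocator(IteratorContext* ctx) {
  AllocatorAttributes attrs;
  if (ctx->accelerator_device_info() != nullptr) {
    attrs.set_gpu_compatible(true);
  }
  return ctx->allocator(attrs);
}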
@@ -340,7 +340,8 @@ class DataServiceDatasetOp::Dataset : public DatasetBase {
    TF_RETURN_IF_ERROR(RegisterCancellationCallback(
        ctx->cancellation_manager(),
        [this]() { data_service_client_.Cancel(); }, &deregister_fn_));
-    return data_service_client_.Initialize(ctx->allocator(/*attrs=*/{}));
+    return data_service_client_.Initialize(ctx->accelerator_device_info(),
+                                           ctx->allocator(/*attrs=*/{}));
}

Status GetNextInternal(IteratorContext* ctx,
2 changes: 1 addition & 1 deletion tensorflow/core/profiler/protobuf/task.proto
@@ -9,7 +9,7 @@ option cc_enable_arenas = true;
// 'Task' contains information about a task that profiler traced.
message Task {
  // The most recent changelist number from the client that built the binary.
-  optional int32 changelist = 1;
+  optional int64 changelist = 1;
// True if the client that built the binary was mint (no local changes).
optional bool clean_build = 2;
// Build time (in ns relative to the Unix epoch).
11 changes: 11 additions & 0 deletions third_party/stablehlo/temporary.patch
@@ -506,6 +506,17 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
return success();
}

@@ -3861,8 +3819,8 @@
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
- int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t shapeCount = std::accumulate(
+ shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
11 changes: 11 additions & 0 deletions third_party/xla/third_party/stablehlo/temporary.patch
@@ -506,6 +506,17 @@ diff --ruN a/stablehlo/stablehlo/dialect/TypeInference.cpp b/stablehlo/stablehlo
return success();
}

@@ -3861,8 +3819,8 @@
if (SmallVector<int64_t> shape; operandType.hasStaticShape() &&
matchInts(outputShape, shape).succeeded()) {
int64_t operandCount = operandType.getNumElements();
- int64_t shapeCount = std::accumulate(shape.begin(), shape.end(), 1,
- std::multiplies<int64_t>());
+ int64_t shapeCount = std::accumulate(
+ shape.begin(), shape.end(), int64_t{1}, std::multiplies<int64_t>());
if (operandCount != shapeCount) {
return emitOptionalError(location,
"output_shape is incompatible with input type "
diff --ruN a/stablehlo/stablehlo/experimental/BUILD.bazel b/stablehlo/stablehlo/experimental/BUILD.bazel
--- stablehlo/stablehlo/experimental/BUILD.bazel
+++ stablehlo/stablehlo/experimental/BUILD.bazel
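The patched hunk (identical in both copies of temporary.patch) fixes an accumulator-width bug: std::accumulate deduces its accumulator type from the init argument, so a plain 1 keeps only 32 bits even though the binary op is std::multiplies<int64_t>. A standalone illustration:

#include <cstdint>
#include <functional>
#include <numeric>
#include <vector>

int64_t ElementCountNarrow(const std::vector<int64_t>& shape) {
  // Accumulator type is int (deduced from `1`); large products are silently
  // truncated back to 32 bits at every step.
  return std::accumulate(shape.begin(), shape.end(), 1,
                         std::multiplies<int64_t>());
}

int64_t ElementCountWide(const std::vector<int64_t>& shape) {
  // Accumulator type is int64_t, as intended.
  return std::accumulate(shape.begin(), shape.end(), int64_t{1},
                         std::multiplies<int64_t>());
}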
6 changes: 3 additions & 3 deletions third_party/xla/xla/service/host_offloader.cc
@@ -755,12 +755,12 @@ absl::StatusOr<bool> HostOffloader::Run(
// Run HloAliasAnalysis on module.
TF_ASSIGN_OR_RETURN(alias_analysis_, HloAliasAnalysis::Run(module));

+  // Handle streamed parameters first.
+  TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));
+
  // Iterate over all instructions and look for XLA host offload annotations.
  for (HloComputation* computation :
       module->MakeNonfusionComputations(execution_threads)) {
-    if (computation->IsEntryComputation()) {
-      TF_RETURN_IF_ERROR(HandleInputStreaming(computation));
-    }
for (HloInstruction* instruction :
computation->MakeInstructionPostOrder()) {
if (instruction->opcode() != HloOpcode::kCustomCall) {
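Per the PR title, input streaming on the entry computation is now handled once, before the walk that traces custom calls, instead of inside the per-computation loop. A condensed sketch of the resulting control flow in HostOffloader::Run, simplified from the hunk above (error handling and the custom-call processing body are elided):

// Simplified sketch of the reordered pass driver.
TF_RETURN_IF_ERROR(HandleInputStreaming(module->entry_computation()));
for (HloComputation* computation :
     module->MakeNonfusionComputations(execution_threads)) {
  for (HloInstruction* instruction :
       computation->MakeInstructionPostOrder()) {
    if (instruction->opcode() != HloOpcode::kCustomCall) {
      continue;  // only host-offload custom calls are of interest here
    }
    // ... handle the host memory offload annotation ...
  }
}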