
#tf-data Add prefetching to WeightedFlatMap tests.
PiperOrigin-RevId: 628535728
yangustc07 authored and tensorflower-gardener committed Apr 30, 2024
1 parent 0278194 commit 2bf1704
Showing 25 changed files with 481 additions and 113 deletions.
@@ -233,12 +233,12 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impls_(dataset()->inputs_.size()),
cumulative_input_cardinalities_(
dataset()->ComputeCumulativeInputCardinalities()),
element_count_(0),
inputs_element_count_(dataset()->inputs_.size(), 0),
next_positions_(dataset()->inputs_.size(), 0),
cumulative_input_cardinalities_(
dataset()->ComputeCumulativeInputCardinalities()) {}
input_impls_(dataset()->inputs_.size()) {}

bool SymbolicCheckpointCompatible() const override { return true; }

@@ -392,17 +392,18 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
}

private:
const std::vector<uint64_t> cumulative_input_cardinalities_;

mutable absl::Mutex mu_;
std::vector<std::unique_ptr<IteratorBase>> input_impls_
ABSL_GUARDED_BY(mu_);
// Counts the number of elements this iterator has produced.
int64_t element_count_ ABSL_GUARDED_BY(mu_) = 0;
// Counts the number of elements each input iterator has produced.
std::vector<int64_t> inputs_element_count_ ABSL_GUARDED_BY(mu_);
// Keeps track of the position of this iterator at which each input starts
// to scan for its next index.
std::vector<size_t> next_positions_;
const std::vector<uint64_t> cumulative_input_cardinalities_;
std::vector<std::unique_ptr<IteratorBase>> input_impls_
ABSL_GUARDED_BY(mu_);
};

const std::vector<DatasetBase*> inputs_;
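The reordering above makes the constructor's mem-initializer list match the members' declaration order shown in the second hunk. A minimal standalone sketch of why that matters, with illustrative names that are not from this commit:

```c++
#include <cstddef>
#include <vector>

// C++ initializes non-static data members in declaration order; the order of
// the mem-initializer list is ignored. When the two orders disagree, an
// initializer that reads a member declared later sees it before construction,
// and compilers warn (-Wreorder).
struct Counter {
  std::vector<int> values_;  // declared first, so constructed first
  std::size_t size_;         // declared second, so constructed second

  // Safe: the list order matches the declaration order, so values_ is fully
  // constructed by the time size_(values_.size()) runs. Swapping only the
  // two declarations above would turn the same expression into undefined
  // behavior, which is the trap the reordering in this commit avoids.
  explicit Counter(std::size_t n) : values_(n, 0), size_(values_.size()) {}
};

int main() {
  Counter c(8);
  return c.size_ == 8 ? 0 : 1;
}
```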
4 changes: 0 additions & 4 deletions tensorflow/python/data/experimental/kernel_tests/BUILD
@@ -758,12 +758,8 @@ tf_py_strict_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:options",
"//tensorflow/python/framework:combinations",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:errors",
"//tensorflow/python/framework:random_seed",
"//tensorflow/python/platform:client_testlib",
"@absl_py//absl/logging",
"@absl_py//absl/testing:parameterized",
],
)
@@ -97,26 +97,34 @@ class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(test_base.default_test_combinations())
def testShuffledOutput(self):
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10, 20)
dataset3 = dataset_ops.Dataset.range(20, 30)
dataset1 = dataset_ops.Dataset.range(10).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset = weighted_flat_map_op._weighted_flat_map(
[dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))
dataset = global_shuffle_op._global_shuffle(dataset)

output = self.getDatasetOutput(dataset, requires_initialization=True)
self.assertCountEqual(
output, list(range(5)) + list(range(10, 15)) + list(range(20, 30)))

@combinations.generate(test_base.default_test_combinations())
def testShuffledInputs(self):
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10, 20)
dataset3 = dataset_ops.Dataset.range(20, 30)
dataset1 = dataset_ops.Dataset.range(10).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset1 = global_shuffle_op._global_shuffle(dataset1, seed=42)
dataset2 = global_shuffle_op._global_shuffle(dataset2, seed=42)
dataset3 = global_shuffle_op._global_shuffle(dataset3, seed=42)
dataset = weighted_flat_map_op._weighted_flat_map(
[dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))

output = self.getDatasetOutput(dataset, requires_initialization=True)
# Verifies that the first 5 elements are from `dataset1` in a random order.
self.assertFalse(set(output[:5]).issubset(set(range(5))))
1 change: 1 addition & 0 deletions third_party/xla/xla/BUILD
@@ -644,6 +644,7 @@ xla_cc_test(
":util",
":xla_data_proto_cc",
"@com_google_absl//absl/base",
"@com_google_absl//absl/hash",
"@com_google_absl//absl/random",
"@com_google_absl//absl/strings",
"@com_google_absl//absl/types:span",
30 changes: 28 additions & 2 deletions third_party/xla/xla/literal.h
@@ -365,10 +365,36 @@ class LiteralBase {
}

CHECK(LayoutUtil::IsDenseArray(subshape));
const int64_t size_bytes = literal.size_bytes(index);
const int64_t bytes_to_hash = std::min(size_bytes, kByteLimit);
// When layout insensitive, we need to hash the data bytes in logical
// order rather than physical order.
const bool use_physical_order =
kIsLayoutSensitive || !subshape.has_layout();
auto data = absl::MakeConstSpan(
static_cast<const char*>(literal.untyped_data(index)),
std::min(kByteLimit, literal.size_bytes(index)));
state = H::combine(std::move(state), data);
size_bytes);
if (use_physical_order) {
state = H::combine(std::move(state), data.first(bytes_to_hash));
return;
}
const int64_t elem_size =
ShapeUtil::ByteSizeOfPrimitiveType(subshape.element_type());
absl::Span<const int64_t> minor_to_major =
subshape.layout().minor_to_major();
DimensionVector elem_index(subshape.dimensions_size());
absl::Span<int64_t> elem_index_span(elem_index.data(),
elem_index.size());
int64_t bytes_hashed = 0;
while (bytes_hashed < bytes_to_hash) {
int64_t offset =
elem_size * IndexUtil::MultidimensionalIndexToLinearIndex(
subshape, minor_to_major, elem_index);
state =
H::combine(std::move(state), data.subspan(offset, elem_size));
if (!IndexUtil::BumpIndices(subshape, elem_index_span)) return;
bytes_hashed += elem_size;
}
});

return std::move(state);
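The new layout-insensitive path walks elements in logical order and maps each multidimensional index back to its physical byte offset. A self-contained sketch of that mapping, independent of XLA's IndexUtil (names here are illustrative):

```c++
#include <cstdint>
#include <vector>

// Linear (physical) offset of a logical index under a given layout: walk the
// dimensions from most-minor to most-major, accumulating strides.
int64_t LinearIndex(const std::vector<int64_t>& dims,
                    const std::vector<int64_t>& minor_to_major,
                    const std::vector<int64_t>& index) {
  int64_t linear = 0;
  int64_t stride = 1;
  for (int64_t dim : minor_to_major) {
    linear += index[dim] * stride;
    stride *= dims[dim];
  }
  return linear;
}

int main() {
  const std::vector<int64_t> dims = {2, 3};  // a 2x3 array
  const std::vector<int64_t> element = {0, 1};
  // Row-major ({1, 0}): element (0, 1) sits at offset 1.
  // Column-major ({0, 1}): the same element sits at offset 2.
  // Hashing raw bytes in physical order would therefore disagree across
  // layouts, which is why the layout-insensitive hash above iterates in
  // logical order instead.
  int64_t row_major = LinearIndex(dims, {1, 0}, element);  // == 1
  int64_t col_major = LinearIndex(dims, {0, 1}, element);  // == 2
  return (row_major == 1 && col_major == 2) ? 0 : 1;
}
```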
21 changes: 21 additions & 0 deletions third_party/xla/xla/literal_test.cc
@@ -28,6 +28,7 @@ limitations under the License.
#include <vector>

#include "absl/base/casts.h"
#include "absl/hash/hash.h"
#include "absl/random/random.h"
#include "absl/strings/match.h"
#include "absl/types/span.h"
@@ -931,6 +932,16 @@ TEST_F(LiteralUtilTest, TestR4RelayoutEquivalence) {
EXPECT_EQ(literal_r4_2x2x3x3_dim0minor_, dim0major_relaid_to_dim0minor);
}

template <bool kIsLayoutSensitive>
struct HashTester {
template <typename H>
friend H AbslHashValue(H h, const HashTester& key) {
return Literal::Hash<H, kIsLayoutSensitive, /*kByteLimit=*/64>(
std::move(h), *key.literal);
}
const Literal* literal;
};

TEST_F(LiteralUtilTest, TestR2LinearLayout) {
// Test expected memory layout of R2 dim0-minor (column-major) literal.
auto mat_dim0minor = LiteralUtil::CreateR2WithLayout<int32_t>(
@@ -942,6 +953,8 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) {
auto relaid_mat_to_dim0major = mat_dim0minor.Relayout(layout_r2_dim0major_);
EXPECT_THAT(relaid_mat_to_dim0major.data<int32_t>(),
ElementsAre(1, 2, 3, 4, 5, 6));
EXPECT_EQ(absl::HashOf(HashTester<false>{&mat_dim0minor}),
absl::HashOf(HashTester<false>{&relaid_mat_to_dim0major}));

// Test expected memory layout of R2 created with dim0-major (row-major).
auto mat_dim0major = LiteralUtil::CreateR2WithLayout<int32_t>(
@@ -953,6 +966,14 @@ TEST_F(LiteralUtilTest, TestR2LinearLayout) {
auto relaid_mat_to_dim0minor = mat_dim0major.Relayout(layout_r2_dim0minor_);
EXPECT_THAT(relaid_mat_to_dim0minor.data<int32_t>(),
ElementsAre(1, 4, 2, 5, 3, 6));
EXPECT_EQ(absl::HashOf(HashTester<false>{&mat_dim0major}),
absl::HashOf(HashTester<false>{&relaid_mat_to_dim0minor}));

// Test that layout sensitive hashes are equal.
EXPECT_EQ(absl::HashOf(HashTester<true>{&mat_dim0minor}),
absl::HashOf(HashTester<true>{&relaid_mat_to_dim0minor}));
EXPECT_EQ(absl::HashOf(HashTester<true>{&mat_dim0major}),
absl::HashOf(HashTester<true>{&relaid_mat_to_dim0major}));
}

TEST_F(LiteralUtilTest, TestR3LinearLayout) {
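`HashTester` plugs into Abseil's standard hashing extension point. A minimal sketch of that mechanism, assuming only Abseil's hash library (`Point` is an illustrative type, not from this commit):

```c++
#include <cstdint>
#include <utility>

#include "absl/hash/hash.h"

// Any type with a friend AbslHashValue overload is usable with absl::HashOf
// and absl::Hash<T>; HashTester above uses the same hook to forward to
// Literal::Hash with specific template parameters.
struct Point {
  int64_t x;
  int64_t y;

  template <typename H>
  friend H AbslHashValue(H h, const Point& p) {
    return H::combine(std::move(h), p.x, p.y);
  }
};

int main() {
  // Equal values hash equally; this mirrors how the test compares literals
  // relaid out into different layouts.
  return absl::HashOf(Point{1, 2}) == absl::HashOf(Point{1, 2}) ? 0 : 1;
}
```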
3 changes: 3 additions & 0 deletions third_party/xla/xla/service/gpu/BUILD
@@ -4064,6 +4064,7 @@ cc_library(
"//xla:types",
"//xla:util",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_handle",
"//xla/stream_executor:stream_executor_headers",
"//xla/stream_executor/gpu:gpu_executor_header",
"@com_google_absl//absl/base:core_headers",
@@ -4440,6 +4441,7 @@ cc_library(
"//xla:util",
"//xla/service:hlo_module_config",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_handle",
"//xla/stream_executor/gpu:asm_compiler",
"@local_tsl//tsl/platform:errors",
"@local_tsl//tsl/platform:logging",
@@ -4476,6 +4478,7 @@ xla_cc_test(
"//xla/service:hlo_module_config",
"//xla/stream_executor",
"//xla/stream_executor:device_memory_allocator",
"//xla/stream_executor:device_memory_handle",
"//xla/stream_executor:platform_manager",
"@local_tsl//tsl/platform:ml_dtypes",
"@local_tsl//tsl/platform:status",
16 changes: 9 additions & 7 deletions third_party/xla/xla/service/gpu/buffer_comparator.cc
@@ -28,7 +28,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/stream_executor/device_description.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/device_memory_handle.h"
#include "xla/stream_executor/kernel.h"
#include "xla/stream_executor/stream_executor.h"
#include "xla/util.h"
@@ -60,10 +60,10 @@ static absl::StatusOr<bool> DeviceCompare(se::Stream* stream,
void* kernel_symbol) {
se::StreamExecutor* executor = stream->parent();

se::ScopedDeviceMemory<uint64_t> out_param(
executor, executor->AllocateScalar<uint64_t>());
se::DeviceMemoryHandle out_param(executor,
executor->AllocateScalar<uint64_t>());

TF_RETURN_IF_ERROR(stream->MemZero(out_param.ptr(), sizeof(uint64_t)));
TF_RETURN_IF_ERROR(stream->MemZero(out_param.memory_ptr(), sizeof(uint64_t)));
if (current.size() != expected.size()) {
return Internal("Mismatched buffer size: %d bytes vs. %d bytes",
current.size(), expected.size());
@@ -87,14 +87,16 @@ static absl::StatusOr<bool> DeviceCompare(se::Stream* stream,
LaunchDimensions dim =
CalculateLaunchDimensions(buffer_shape, gpu_device_info);

se::DeviceMemory<uint64_t> as_uint64(out_param.memory());
TF_RETURN_IF_ERROR(stream->ThenLaunch(
dim.thread_counts_per_block(), dim.block_counts(), comparison_kernel,
current_typed, expected_typed, static_cast<float>(kTolerance),
buffer_size, out_param.cref()));
buffer_size, as_uint64));

uint64_t result = -1;
CHECK_EQ(out_param->size(), sizeof(result));
TF_RETURN_IF_ERROR(stream->Memcpy(&result, *out_param, sizeof(result)));
CHECK_EQ(out_param.memory().size(), sizeof(result));
TF_RETURN_IF_ERROR(
stream->Memcpy(&result, out_param.memory(), sizeof(result)));
TF_RETURN_IF_ERROR(stream->BlockHostUntilDone());
return result == 0;
}
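The same `se::ScopedDeviceMemory<T>` to `se::DeviceMemoryHandle` migration recurs in the test and infeed files below: ownership moves into an untyped RAII handle, and typed views are rebuilt only where a call site needs them. A toy model of that design choice, not the real StreamExecutor API:

```c++
#include <cstddef>
#include <cstdint>
#include <cstdlib>

struct MemoryBase {  // analogous to se::DeviceMemoryBase
  void* opaque = nullptr;
  std::size_t size = 0;
};

template <typename T>
struct TypedMemory : MemoryBase {  // analogous to se::DeviceMemory<T>
  explicit TypedMemory(const MemoryBase& base) : MemoryBase(base) {}
};

// Analogous to se::DeviceMemoryHandle: owns the allocation without carrying
// an element type, exposing it via memory()/memory_ptr() as in the diff.
class MemoryHandle {
 public:
  explicit MemoryHandle(std::size_t bytes) {
    memory_.opaque = std::malloc(bytes);
    memory_.size = bytes;
  }
  ~MemoryHandle() { std::free(memory_.opaque); }
  MemoryHandle(const MemoryHandle&) = delete;
  MemoryHandle& operator=(const MemoryHandle&) = delete;

  const MemoryBase& memory() const { return memory_; }
  MemoryBase* memory_ptr() { return &memory_; }

 private:
  MemoryBase memory_;
};

int main() {
  MemoryHandle out(sizeof(uint64_t));
  // Rebuild the typed view only at the call site that needs one, mirroring
  // `se::DeviceMemory<uint64_t> as_uint64(out_param.memory());` above.
  TypedMemory<uint64_t> as_uint64(out.memory());
  return as_uint64.size == sizeof(uint64_t) ? 0 : 1;
}
```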
19 changes: 10 additions & 9 deletions third_party/xla/xla/service/gpu/buffer_comparator_test.cc
@@ -26,7 +26,7 @@ limitations under the License.
#include "xla/service/hlo_module_config.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/device_memory_handle.h"
#include "xla/stream_executor/platform.h"
#include "xla/stream_executor/platform_manager.h"
#include "xla/stream_executor/stream.h"
@@ -56,25 +56,26 @@ class BufferComparatorTest : public testing::Test {
const std::vector<ElementType>& expected) {
auto stream = stream_exec_->CreateStream().value();

se::ScopedDeviceMemory<ElementType> current_buffer(
se::DeviceMemoryHandle current_buffer(
stream_exec_, stream_exec_->AllocateArray<ElementType>(current.size()));
se::ScopedDeviceMemory<ElementType> expected_buffer(
se::DeviceMemoryHandle expected_buffer(
stream_exec_,
stream_exec_->AllocateArray<ElementType>(expected.size()));

TF_CHECK_OK(stream->Memcpy(current_buffer.ptr(), current.data(),
current_buffer->size()));
TF_CHECK_OK(stream->Memcpy(expected_buffer.ptr(), expected.data(),
expected_buffer->size()));
TF_CHECK_OK(stream->Memcpy(current_buffer.memory_ptr(), current.data(),
current_buffer.memory().size()));
TF_CHECK_OK(stream->Memcpy(expected_buffer.memory_ptr(), expected.data(),
expected_buffer.memory().size()));
TF_CHECK_OK(stream->BlockHostUntilDone());

BufferComparator comparator(
ShapeUtil::MakeShape(
primitive_util::NativeToPrimitiveType<ElementType>(),
{static_cast<int64_t>(current_buffer->ElementCount())}),
{static_cast<int64_t>(current.size())}),
HloModuleConfig());
return comparator
.CompareEqual(stream.get(), *current_buffer, *expected_buffer)
.CompareEqual(stream.get(), current_buffer.memory(),
expected_buffer.memory())
.value();
}

12 changes: 6 additions & 6 deletions third_party/xla/xla/service/gpu/infeed_manager.cc
@@ -29,7 +29,7 @@ limitations under the License.
#include "xla/shape.h"
#include "xla/shape_tree.h"
#include "xla/shape_util.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/device_memory_handle.h"
#include "xla/util.h"
#include "tsl/platform/errors.h"
#include "tsl/platform/statusor.h"
Expand All @@ -48,7 +48,7 @@ InfeedManager::InfeedManager(se::StreamExecutor* executor)
: BlockingXfeedQueue(/*max_pending_xfeeds=*/kMaxInfeedsInFlight),
stream_(executor->CreateStream().value()) {}

static absl::StatusOr<se::ScopedDeviceMemory<uint8_t>> CopyBufferToDevice(
static absl::StatusOr<se::DeviceMemoryHandle> CopyBufferToDevice(
se::Stream* stream, int64_t size, const void* source) {
if (size > std::numeric_limits<int32_t>::max()) {
return InvalidArgument("GPU infeed of %d bytes exceeds maximum of %d bytes",
@@ -60,9 +60,9 @@ static absl::StatusOr<se::ScopedDeviceMemory<uint8_t>> CopyBufferToDevice(
}

se::StreamExecutor* executor = stream->parent();
se::ScopedDeviceMemory<uint8_t> buffer(
executor, executor->AllocateArray<uint8_t>(size));
TF_RETURN_IF_ERROR(stream->Memcpy(buffer.ptr(), source, size));
se::DeviceMemoryHandle buffer(executor,
executor->AllocateArray<uint8_t>(size));
TF_RETURN_IF_ERROR(stream->Memcpy(buffer.memory_ptr(), source, size));

return std::move(buffer);
}
@@ -77,7 +77,7 @@ absl::Status InfeedManager::TransferLiteralToInfeed(

// For a tuple, we transfer each of its elements to the device and enqueue the
// resulting destination device addresses with the infeed manager.
ShapeTree<se::ScopedDeviceMemory<uint8_t>> buffer_tree(literal_shape);
ShapeTree<se::DeviceMemoryHandle> buffer_tree(literal_shape);
for (auto& leaf : buffer_tree.leaves()) {
const Shape& sub_shape = ShapeUtil::GetSubshape(literal_shape, leaf.first);
CHECK(sub_shape.IsArray()) << ShapeUtil::HumanStringWithLayout(sub_shape);
4 changes: 2 additions & 2 deletions third_party/xla/xla/service/gpu/infeed_manager.h
@@ -27,7 +27,7 @@ limitations under the License.
#include "xla/literal.h"
#include "xla/service/gpu/xfeed_queue.h"
#include "xla/shape_tree.h"
#include "xla/stream_executor/device_memory_allocator.h"
#include "xla/stream_executor/device_memory_handle.h"
#include "xla/stream_executor/stream_executor.h"

namespace xla {
@@ -47,7 +47,7 @@ namespace gpu {

// Client-side class used to enqueue infeed buffers.
class InfeedManager
: public BlockingXfeedQueue<ShapeTree<se::ScopedDeviceMemory<uint8_t>>> {
: public BlockingXfeedQueue<ShapeTree<se::DeviceMemoryHandle>> {
public:
explicit InfeedManager(se::StreamExecutor* executor);

1 change: 1 addition & 0 deletions third_party/xla/xla/service/gpu/kernels/BUILD
@@ -179,6 +179,7 @@ xla_cc_test(
"//xla:types",
"//xla:xla_data_proto_cc",
"//xla/stream_executor", # build_cleaner: keep
"//xla/stream_executor:device_memory_handle",
"//xla/stream_executor:platform_manager",
"//xla/stream_executor/gpu:gpu_init",
"//xla/stream_executor/gpu:gpu_stream_header",
