Skip to content

Commit

Permalink
#tf-data Simplify the implementation for WeightedFlatMap global shuffle.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 628535728
  • Loading branch information
yangustc07 authored and tensorflower-gardener committed Apr 26, 2024
1 parent dcd302e commit 0b6dd79
Showing 1 changed file with 17 additions and 67 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -235,8 +235,6 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
: DatasetIterator<Dataset>(params),
input_impls_(dataset()->inputs_.size()),
element_count_(0),
inputs_element_count_(dataset()->inputs_.size(), 0),
next_positions_(dataset()->inputs_.size(), 0),
cumulative_input_cardinalities_(
dataset()->ComputeCumulativeInputCardinalities()) {}

Expand Down Expand Up @@ -273,64 +271,30 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
input_dataset_index =
IntervalIndex(cumulative_input_cardinalities_, parent_index);
IteratorContext::Params params(ctx);
params.index_mapper = GetWeightedFlatMapIndexMapper(
ctx->index_mapper(), input_dataset_index);
params.index_mapper =
GetWeightedFlatMapIndexMapper(parent_index, input_dataset_index);
IteratorContext global_shuffle_ctx(params);
TF_RETURN_IF_ERROR(input_impls_[input_dataset_index]->GetNext(
&global_shuffle_ctx, out_tensors, end_of_sequence));
ctx->MergeCheckpoint(global_shuffle_ctx.checkpoint());
}
++inputs_element_count_[input_dataset_index];
++element_count_;
return absl::OkStatus();
}

// Returns the index mapper for an input given its `input_dataset_index`.
IndexMapperFn GetWeightedFlatMapIndexMapper(
IndexMapperFn parent_index_mapper, size_t input_dataset_index = 0)
IndexMapperFn GetWeightedFlatMapIndexMapper(size_t parent_index,
size_t input_dataset_index)
ABSL_EXCLUSIVE_LOCKS_REQUIRED(mu_) {
size_t last_position = this->cumulative_input_cardinalities_.back();
return [this, parent_index_mapper, input_dataset_index, last_position](
return [this, parent_index, input_dataset_index](
size_t element_position) -> absl::StatusOr<size_t> {
// This index mapper function scans the position of the
// `WeightedFlatMap` elements to find the first element that matches the
// `input_dataset_index`. It updates this position each time the
// function is called so that it does not start from the beginning the
// next time it is called. For example, if there are 2 inputs: input0
// and input1 with elements [0, 1, 2], and [10, 11, 12] and the output
// is shuffled to return [1, 12, 10, 2, 11, 0]. The first time each
// input is called, the following is what each variable has before
// returning.
//                     input0       input1
// element_position    0            0
// index               1            1
// next_position       1            2 (next_position = 1 is skipped
//                                    because it is for input0)
// index               1 (for 1)    2 (for 12)
// The second time around, input0 will start scanning from
// `next_positions_[0]`, which is 1, and input1 will start scanning from
// `next_positions_[1]`, which is 2.
while (this->next_positions_[input_dataset_index] < last_position) {
// `index` is the shuffled index of this dataset, not any of the
// inputs.
size_t index = this->next_positions_[input_dataset_index];
if (parent_index_mapper != nullptr) {
TF_ASSIGN_OR_RETURN(index, parent_index_mapper(index));
}
++(this->next_positions_[input_dataset_index]);
// If the shuffled `index` belongs to dataset `input_dataset_index`,
// computes the local offset into that input and returns it. Otherwise,
// continues scanning.
if (IntervalIndex(this->cumulative_input_cardinalities_, index) ==
input_dataset_index) {
// Finds the offset in input `input_dataset_index`.
if (input_dataset_index > 0) {
index -= cumulative_input_cardinalities_[input_dataset_index - 1];
}
return index;
}
if (input_dataset_index == 0 ||
cumulative_input_cardinalities_.empty() ||
parent_index >= cumulative_input_cardinalities_.back()) {
return parent_index;
}
return last_position;
return parent_index -
cumulative_input_cardinalities_[input_dataset_index - 1];
};
}

Expand All @@ -339,10 +303,7 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
absl::MutexLock l(&mu_);
TF_RETURN_IF_ERROR(
writer->WriteScalar(prefix(), kInputNumElements, element_count_));
for (int i = 0; i < inputs_element_count_.size(); ++i) {
TF_RETURN_IF_ERROR(writer->WriteScalar(
prefix(), absl::StrCat(kInputNumElements, "[", i, "]"),
inputs_element_count_[i]));
for (int i = 0; i < input_impls_.size(); ++i) {
TF_RETURN_IF_ERROR(SaveInput(ctx, writer, input_impls_[i]));
}
return absl::OkStatus();
Expand All @@ -352,11 +313,9 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
IteratorStateReader* reader) override {
absl::MutexLock l(&mu_);
if (ctx->restored_element_count().has_value()) {
std::vector<int64_t> input_element_counts(dataset()->inputs_.size(), 0);
element_count_ = *ctx->restored_element_count();
// Restores all inputs' element counts and next positions.
std::fill(inputs_element_count_.begin(), inputs_element_count_.end(),
0);
std::fill(next_positions_.begin(), next_positions_.end(), 0);
for (int64_t count = 0; count < element_count_; ++count) {
if (element_count_ >= cumulative_input_cardinalities_.back()) {
break;
Expand All @@ -367,13 +326,12 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
}
auto input_dataset_index =
IntervalIndex(cumulative_input_cardinalities_, parent_index);
++inputs_element_count_[input_dataset_index];
next_positions_[input_dataset_index] = count + 1;
++input_element_counts[input_dataset_index];
}
// Restores all inputs.
for (int i = 0; i < inputs_element_count_.size(); ++i) {
for (int i = 0; i < input_element_counts.size(); ++i) {
IteratorContext::Params params(ctx);
params.restored_element_count = inputs_element_count_[i];
params.restored_element_count = input_element_counts[i];
IteratorContext ctx_copy(params);
TF_RETURN_IF_ERROR(RestoreInput(&ctx_copy, reader, input_impls_[i]));
ctx->MergeCheckpoint(ctx_copy.checkpoint());
Expand All @@ -382,11 +340,8 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
}
TF_RETURN_IF_ERROR(
reader->ReadScalar(prefix(), kInputNumElements, &element_count_));
for (int i = 0; i < inputs_element_count_.size(); ++i) {
for (int i = 0; i < input_impls_.size(); ++i) {
TF_RETURN_IF_ERROR(RestoreInput(ctx, reader, input_impls_[i]));
TF_RETURN_IF_ERROR(reader->ReadScalar(
prefix(), absl::StrCat(kInputNumElements, "[", i, "]"),
&inputs_element_count_[i]));
}
return absl::OkStatus();
}
Expand All @@ -397,11 +352,6 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
ABSL_GUARDED_BY(mu_);
// Counts the number of elements this iterator has produced.
int64_t element_count_ ABSL_GUARDED_BY(mu_) = 0;
// Counts the number of elements each input iterator has produced.
std::vector<int64_t> inputs_element_count_ ABSL_GUARDED_BY(mu_);
// Keeps track of the position of this iterator that each input starts to
// scan for its next index.
std::vector<size_t> next_positions_;
const std::vector<uint64_t> cumulative_input_cardinalities_;
};

Expand Down

0 comments on commit 0b6dd79

Please sign in to comment.