Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[tf-data] Add prefetching to WeightedFlatMap tests. #66542

Merged
merged 1 commit into from
May 1, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
public:
// Creates an iterator over the weighted flat-map dataset.
//
// NOTE(review): this span was diff residue with the pre- and post-change
// initializer lists interleaved (members initialized twice); reconciled to
// the post-merge version. Initializer order below matches the member
// declaration order (cumulative_input_cardinalities_ first, input_impls_
// last) so initialization is predictable and -Wreorder stays quiet.
explicit Iterator(const Params& params)
    : DatasetIterator<Dataset>(params),
      cumulative_input_cardinalities_(
          dataset()->ComputeCumulativeInputCardinalities()),
      element_count_(0),
      inputs_element_count_(dataset()->inputs_.size(), 0),
      next_positions_(dataset()->inputs_.size(), 0),
      input_impls_(dataset()->inputs_.size()) {}

bool SymbolicCheckpointCompatible() const override { return true; }

Expand Down Expand Up @@ -392,17 +392,18 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
}

private:
 // NOTE(review): this span was diff residue — cumulative_input_cardinalities_
 // and input_impls_ were each declared twice; reconciled to the post-merge
 // declaration order, which the constructor's initializer list mirrors.

 // Cumulative cardinalities of the inputs, computed once at construction and
 // never mutated afterwards.
 const std::vector<uint64_t> cumulative_input_cardinalities_;

 mutable absl::Mutex mu_;
 // Counts the number of elements this iterator has produced.
 int64_t element_count_ ABSL_GUARDED_BY(mu_) = 0;
 // Counts the number of elements each input iterator has produced.
 std::vector<int64_t> inputs_element_count_ ABSL_GUARDED_BY(mu_);
 // Keeps track of the position of this iterator that each input starts to
 // scan for its next index.
 std::vector<size_t> next_positions_;
 std::vector<std::unique_ptr<IteratorBase>> input_impls_
     ABSL_GUARDED_BY(mu_);
};

const std::vector<DatasetBase*> inputs_;
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/python/data/experimental/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,8 @@ tf_py_strict_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:options",
"//tensorflow/python/framework:combinations",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:errors",
"//tensorflow/python/framework:random_seed",
"//tensorflow/python/platform:client_testlib",
"@absl_py//absl/logging",
"@absl_py//absl/testing:parameterized",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,34 @@ class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(test_base.default_test_combinations())
def testShuffledOutput(self):
  """Globally shuffling a weighted_flat_map of prefetched inputs.

  NOTE(review): this span was diff residue containing both the old
  (non-prefetched) and new (prefetched) input definitions, leaving the
  first three assignments dead; reconciled to the post-merge version.
  """
  # Prefetching each input exercises the op against asynchronous inputs.
  dataset1 = dataset_ops.Dataset.range(10).prefetch(
      buffer_size=dataset_ops.AUTOTUNE)
  dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
      buffer_size=dataset_ops.AUTOTUNE)
  dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
      buffer_size=dataset_ops.AUTOTUNE)
  # Weights 0.25/0.25/0.5 take 5, 5, and 10 elements respectively, per the
  # expected multiset asserted below.
  dataset = weighted_flat_map_op._weighted_flat_map(
      [dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))
  dataset = global_shuffle_op._global_shuffle(dataset)

  output = self.getDatasetOutput(dataset, requires_initialization=True)
  # Global shuffle permutes order, so compare as multisets.
  self.assertCountEqual(
      output, list(range(5)) + list(range(10, 15)) + list(range(20, 30)))

@combinations.generate(test_base.default_test_combinations())
def testShuffledInputs(self):
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10, 20)
dataset3 = dataset_ops.Dataset.range(20, 30)
dataset1 = dataset_ops.Dataset.range(10).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset1 = global_shuffle_op._global_shuffle(dataset1, seed=42)
dataset2 = global_shuffle_op._global_shuffle(dataset2, seed=42)
dataset3 = global_shuffle_op._global_shuffle(dataset3, seed=42)
dataset = weighted_flat_map_op._weighted_flat_map(
[dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))

output = self.getDatasetOutput(dataset, requires_initialization=True)
# Verifies that the first 5 elements are from `dataset1` in a random order.
self.assertFalse(set(output[:5]).issubset(set(range(5))))
Expand Down