Skip to content

Commit

Permalink
#tf-data Add prefetching to WeightedFlatMap tests.
Browse files Browse the repository at this point in the history
PiperOrigin-RevId: 629563974
  • Loading branch information
yangustc07 authored and tensorflower-gardener committed Apr 30, 2024
1 parent bc19528 commit 2876b29
Show file tree
Hide file tree
Showing 3 changed files with 21 additions and 16 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -233,12 +233,12 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
public:
explicit Iterator(const Params& params)
: DatasetIterator<Dataset>(params),
input_impls_(dataset()->inputs_.size()),
cumulative_input_cardinalities_(
dataset()->ComputeCumulativeInputCardinalities()),
element_count_(0),
inputs_element_count_(dataset()->inputs_.size(), 0),
next_positions_(dataset()->inputs_.size(), 0),
cumulative_input_cardinalities_(
dataset()->ComputeCumulativeInputCardinalities()) {}
input_impls_(dataset()->inputs_.size()) {}

bool SymbolicCheckpointCompatible() const override { return true; }

Expand Down Expand Up @@ -392,17 +392,18 @@ class WeightedFlatMapDatasetOp::Dataset : public DatasetBase {
}

private:
const std::vector<uint64_t> cumulative_input_cardinalities_;

mutable absl::Mutex mu_;
std::vector<std::unique_ptr<IteratorBase>> input_impls_
ABSL_GUARDED_BY(mu_);
// Counts the number of elements this iterator has produced.
int64_t element_count_ ABSL_GUARDED_BY(mu_) = 0;
// Counts the number of elements each input iterator has produced.
std::vector<int64_t> inputs_element_count_ ABSL_GUARDED_BY(mu_);
// Keeps track of the position of this iterator that each input starts to
// scan for its next index.
std::vector<size_t> next_positions_;
const std::vector<uint64_t> cumulative_input_cardinalities_;
std::vector<std::unique_ptr<IteratorBase>> input_impls_
ABSL_GUARDED_BY(mu_);
};

const std::vector<DatasetBase*> inputs_;
Expand Down
4 changes: 0 additions & 4 deletions tensorflow/python/data/experimental/kernel_tests/BUILD
Original file line number Diff line number Diff line change
Expand Up @@ -758,12 +758,8 @@ tf_py_strict_test(
"//tensorflow/python/data/ops:dataset_ops",
"//tensorflow/python/data/ops:options",
"//tensorflow/python/framework:combinations",
"//tensorflow/python/framework:constant_op",
"//tensorflow/python/framework:dtypes",
"//tensorflow/python/framework:errors",
"//tensorflow/python/framework:random_seed",
"//tensorflow/python/platform:client_testlib",
"@absl_py//absl/logging",
"@absl_py//absl/testing:parameterized",
],
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -97,26 +97,34 @@ class GlobalShuffleTest(test_base.DatasetTestBase, parameterized.TestCase):

@combinations.generate(test_base.default_test_combinations())
def testShuffledOutput(self):
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10, 20)
dataset3 = dataset_ops.Dataset.range(20, 30)
dataset1 = dataset_ops.Dataset.range(10).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset = weighted_flat_map_op._weighted_flat_map(
[dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))
dataset = global_shuffle_op._global_shuffle(dataset)

output = self.getDatasetOutput(dataset, requires_initialization=True)
self.assertCountEqual(
output, list(range(5)) + list(range(10, 15)) + list(range(20, 30)))

@combinations.generate(test_base.default_test_combinations())
def testShuffledInputs(self):
dataset1 = dataset_ops.Dataset.range(10)
dataset2 = dataset_ops.Dataset.range(10, 20)
dataset3 = dataset_ops.Dataset.range(20, 30)
dataset1 = dataset_ops.Dataset.range(10).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset2 = dataset_ops.Dataset.range(10, 20).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset3 = dataset_ops.Dataset.range(20, 30).prefetch(
buffer_size=dataset_ops.AUTOTUNE)
dataset1 = global_shuffle_op._global_shuffle(dataset1, seed=42)
dataset2 = global_shuffle_op._global_shuffle(dataset2, seed=42)
dataset3 = global_shuffle_op._global_shuffle(dataset3, seed=42)
dataset = weighted_flat_map_op._weighted_flat_map(
[dataset1, dataset2, dataset3], np.asarray([0.25, 0.25, 0.5]))

output = self.getDatasetOutput(dataset, requires_initialization=True)
# Verifies that the first 5 elements are from `dataset1` in a random order.
self.assertFalse(set(output[:5]).issubset(set(range(5))))
Expand Down

0 comments on commit 2876b29

Please sign in to comment.