/
dataset_ops.py
2101 lines (1721 loc) · 76.6 KB
/
dataset_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python wrappers for Datasets."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import abc
import collections
import threading
import numpy as np
import six
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import random_seed
from tensorflow.python.data.util import sparse
from tensorflow.python.eager import context
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor as sparse_tensor_lib
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import gen_dataset_ops
from tensorflow.python.ops import gen_io_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import script_ops
from tensorflow.python.util import deprecation
from tensorflow.python.util.tf_export import tf_export
@tf_export("data.Dataset")
class Dataset(object):
"""Represents a potentially large set of elements.
A `Dataset` can be used to represent an input pipeline as a
collection of elements (nested structures of tensors) and a "logical
plan" of transformations that act on those elements.
"""
__metaclass__ = abc.ABCMeta
def __init__(self):
  """Initializes the base `Dataset`; nothing is set up at this level."""
@abc.abstractmethod
def _as_variant_tensor(self):
  """Builds the scalar `tf.variant` tensor that stands for this dataset.

  Subclasses are expected to override this method; the implementation here
  only raises.

  Returns:
    A scalar `tf.Tensor` of `tf.variant` type, which represents this dataset.

  Raises:
    NotImplementedError: Always, when this base implementation is called.
  """
  unimplemented = "Dataset._as_variant_tensor"
  raise NotImplementedError(unimplemented)
def make_initializable_iterator(self, shared_name=None):
  """Builds an uninitialized `Iterator` over the elements of this dataset.

  Note: The iterator that is returned has not been initialized yet. Run its
  `iterator.initializer` operation before fetching elements from it:

  ```python
  dataset = ...
  iterator = dataset.make_initializable_iterator()
  # ...
  sess.run(iterator.initializer)
  ```

  Args:
    shared_name: (Optional.) If non-empty, the returned iterator will be
      shared under the given name across multiple sessions that share the
      same devices (e.g. when using a remote server).

  Returns:
    An `Iterator` over the elements of this dataset.

  Raises:
    RuntimeError: If eager execution is enabled.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "dataset.make_initializable_iterator is not supported when eager "
        "execution is enabled.")

  # `None` is replaced by the empty string before being handed to the op.
  resource_name = "" if shared_name is None else shared_name
  flat_types = nest.flatten(
      sparse.as_dense_types(self.output_types, self.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(self.output_shapes, self.output_classes))
  resource = gen_dataset_ops.iterator(
      container="",
      shared_name=resource_name,
      output_types=flat_types,
      output_shapes=flat_shapes)

  # The initializer op is created colocated with the iterator resource.
  with ops.colocate_with(resource):
    init_op = gen_dataset_ops.make_iterator(self._as_variant_tensor(), resource)

  return iterator_ops.Iterator(resource, init_op, self.output_types,
                               self.output_shapes, self.output_classes)
def make_one_shot_iterator(self):
  """Builds a self-initializing `Iterator` over the elements of this dataset.

  Note: The returned iterator will be initialized automatically.
  A "one-shot" iterator does not currently support re-initialization.

  Returns:
    An `Iterator` over the elements of this dataset.

  Raises:
    RuntimeError: If eager execution is enabled.
    ValueError: If the dataset factory function cannot be added to the
      default graph; when the failure mentions capturing a stateful node, the
      message points to `Dataset.make_initializable_iterator()` instead.
  """
  if context.executing_eagerly():
    raise RuntimeError(
        "dataset.make_one_shot_iterator is not supported when eager "
        "execution is enabled.")

  # NOTE(mrry): We capture by value here to ensure that `_make_dataset()` is
  # a 0-argument function.
  @function.Defun(capture_by_value=True)
  def _make_dataset():
    return self._as_variant_tensor()  # pylint: disable=protected-access

  try:
    _make_dataset.add_to_graph(ops.get_default_graph())
  except ValueError as err:
    # Any failure other than the stateful-capture one is re-raised as-is.
    if "Cannot capture a stateful node" not in str(err):
      six.reraise(ValueError, err)
    raise ValueError(
        "Failed to create a one-shot iterator for a dataset. "
        "`Dataset.make_one_shot_iterator()` does not support datasets that "
        "capture stateful objects, such as a `Variable` or `LookupTable`. "
        "In these cases, use `Dataset.make_initializable_iterator()`. "
        "(Original error: %s)" % err)

  flat_types = nest.flatten(
      sparse.as_dense_types(self.output_types, self.output_classes))
  flat_shapes = nest.flatten(
      sparse.as_dense_shapes(self.output_shapes, self.output_classes))
  resource = gen_dataset_ops.one_shot_iterator(
      dataset_factory=_make_dataset,
      output_types=flat_types,
      output_shapes=flat_shapes)
  # A one-shot iterator has no separate initializer op, hence `None`.
  return iterator_ops.Iterator(resource, None, self.output_types,
                               self.output_shapes, self.output_classes)
@abc.abstractproperty
def output_classes(self):
  """The class of each component of an element of this dataset.

  The expected values are `tf.Tensor` and `tf.SparseTensor`.

  Returns:
    A nested structure of Python `type` objects corresponding to each
    component of an element of this dataset.

  Raises:
    NotImplementedError: Always, when this base implementation is read.
  """
  unimplemented = "Dataset.output_classes"
  raise NotImplementedError(unimplemented)
@abc.abstractproperty
def output_shapes(self):
  """The shape of each component of an element of this dataset.

  Returns:
    A nested structure of `tf.TensorShape` objects corresponding to each
    component of an element of this dataset.

  Raises:
    NotImplementedError: Always, when this base implementation is read.
  """
  unimplemented = "Dataset.output_shapes"
  raise NotImplementedError(unimplemented)
@abc.abstractproperty
def output_types(self):
  """The type of each component of an element of this dataset.

  Returns:
    A nested structure of `tf.DType` objects corresponding to each component
    of an element of this dataset.

  Raises:
    NotImplementedError: Always, when this base implementation is read.
  """
  unimplemented = "Dataset.output_types"
  raise NotImplementedError(unimplemented)
def __repr__(self):
  """Returns a string with this dataset's class name, shapes and types."""

  def _without_quotes(structure):
    # Stringifying a structure of strings quotes every leaf; strip the quotes.
    return str(structure).replace("'", "")

  shapes_text = _without_quotes(nest.map_structure(str, self.output_shapes))
  types_text = _without_quotes(nest.map_structure(repr, self.output_types))
  return "<%s shapes: %s, types: %s>" % (type(self).__name__, shapes_text,
                                         types_text)
@staticmethod
def from_tensors(tensors):
  """Builds a one-element `Dataset` whose element is the given tensors.

  Args:
    tensors: A nested structure of tensors.

  Returns:
    Dataset: A `Dataset`.
  """
  single_element_dataset = TensorDataset(tensors)
  return single_element_dataset
@staticmethod
def from_tensor_slices(tensors):
  """Builds a `Dataset` whose elements are slices of the given tensors.

  Args:
    tensors: A nested structure of tensors, each having the same size in the
      0th dimension.

  Returns:
    Dataset: A `Dataset`.
  """
  sliced_dataset = TensorSliceDataset(tensors)
  return sliced_dataset
@staticmethod
@deprecation.deprecated(None, "Use `tf.data.Dataset.from_tensor_slices()`.")
def from_sparse_tensor_slices(sparse_tensor):
  """Splits the given rank-N `tf.SparseTensor` row-wise into a `Dataset`.

  Args:
    sparse_tensor: A `tf.SparseTensor`.

  Returns:
    Dataset: A `Dataset` of rank-(N-1) sparse tensors.
  """
  sliced_dataset = SparseTensorSliceDataset(sparse_tensor)
  return sliced_dataset
class _GeneratorState(object):
  """Stores outstanding iterators created from a Python generator.

  This class keeps track of potentially multiple iterators that may have
  been created from a generator, e.g. in the case that the dataset is
  repeated, or nested within a parallel computation.
  """

  def __init__(self, generator):
    """Creates an empty registry of iterators produced by `generator`.

    Args:
      generator: A callable that takes no arguments; its return value is
        passed to `iter()` each time a new iterator is needed.
    """
    self._generator = generator
    self._lock = threading.Lock()
    self._next_id = 0  # GUARDED_BY(self._lock)
    # Looking up an ID that has no iterator yet starts a fresh pass by calling
    # `generator()` and storing `iter()` of the result under that ID.
    self._iterators = collections.defaultdict(lambda: iter(generator()))

  def get_next_id(self):
    """Returns a new iterator ID as a scalar `np.int64` array.

    IDs are handed out sequentially starting from 0; the counter is updated
    while holding `self._lock`.
    """
    with self._lock:
      ret = self._next_id
      self._next_id += 1
    # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
    # casting in `py_func()` will create an array of `np.int32` on Windows,
    # leading to a runtime error.
    return np.array(ret, dtype=np.int64)

  def get_iterator(self, iterator_id):
    """Returns the iterator for `iterator_id`, creating it on first use."""
    return self._iterators[iterator_id]

  def iterator_completed(self, iterator_id):
    """Forgets the iterator stored for `iterator_id`, if there is one.

    Unknown IDs are ignored so that cleanup is idempotent and also succeeds
    when no iterator was ever stored for the ID (e.g. because calling
    `generator()` raised before an iterator could be created).
    """
    # `pop()` with a default neither raises for a missing key nor triggers the
    # `defaultdict` factory.
    self._iterators.pop(iterator_id, None)
@staticmethod
def from_generator(generator, output_types, output_shapes=None):
  """Creates a `Dataset` whose elements are generated by `generator`.

  The `generator` argument must be a callable object that returns
  an object that supports the `iter()` protocol (e.g. a generator function).
  The elements generated by `generator` must be compatible with the given
  `output_types` and (optional) `output_shapes` arguments.

  For example:

  ```python
  import itertools

  def gen():
    for i in itertools.count(1):
      yield (i, [1] * i)

  ds = Dataset.from_generator(
      gen, (tf.int64, tf.int64), (tf.TensorShape([]), tf.TensorShape([None])))
  value = ds.make_one_shot_iterator().get_next()

  sess.run(value)  # (1, array([1]))
  sess.run(value)  # (2, array([1, 1]))
  ```

  NOTE: The current implementation of `Dataset.from_generator()` uses
  @{tf.py_func} and inherits the same constraints. In particular, it
  requires the `Dataset`- and `Iterator`-related operations to be placed
  on a device in the same process as the Python program that called
  `Dataset.from_generator()`. The body of `generator` will not be
  serialized in a `GraphDef`, and you should not use this method if you
  need to serialize your model and restore it in a different environment.

  NOTE: If `generator` depends on mutable global variables or other external
  state, be aware that the runtime may invoke `generator` multiple times
  (in order to support repeating the `Dataset`) and at any time
  between the call to `Dataset.from_generator()` and the production of the
  first element from the generator. Mutating global variables or external
  state can cause undefined behavior, and we recommend that you explicitly
  cache any external state in `generator` before calling
  `Dataset.from_generator()`.

  Args:
    generator: A callable object that takes no arguments and returns an
      object that supports the `iter()` protocol.
    output_types: A nested structure of `tf.DType` objects corresponding to
      each component of an element yielded by `generator`.
    output_shapes: (Optional.) A nested structure of `tf.TensorShape`
      objects corresponding to each component of an element yielded by
      `generator`.

  Returns:
    Dataset: A `Dataset`.

  Raises:
    TypeError: If `generator` is not callable.
  """
  if not callable(generator):
    raise TypeError("`generator` must be callable.")
  # Normalize `output_shapes` to a structure of `tf.TensorShape` objects that
  # mirrors `output_types`; unspecified shapes become fully unknown.
  if output_shapes is None:
    output_shapes = nest.map_structure(
        lambda _: tensor_shape.TensorShape(None), output_types)
  else:
    output_shapes = nest.map_structure_up_to(
        output_types, tensor_shape.as_shape, output_shapes)

  flattened_types = nest.flatten(output_types)
  flattened_shapes = nest.flatten(output_shapes)

  generator_state = Dataset._GeneratorState(generator)

  def get_iterator_id_fn(unused_dummy):
    """Creates a unique `iterator_id` for each pass over the dataset.

    The returned `iterator_id` disambiguates between multiple concurrently
    existing iterators.

    Args:
      unused_dummy: Ignored value.

    Returns:
      A `tf.int64` tensor whose value uniquely identifies an iterator in
      `generator_state`.
    """
    return script_ops.py_func(
        generator_state.get_next_id, [], dtypes.int64, stateful=True)

  def generator_next_fn(iterator_id_t):
    """Generates the next element from iterator with ID `iterator_id_t`.

    We map this function across an infinite repetition of the
    `iterator_id_t`, and raise `StopIteration` to terminate the iteration.

    Args:
      iterator_id_t: A `tf.int64` tensor whose value uniquely identifies
        the iterator in `generator_state` from which to generate an element.

    Returns:
      A nested structure of tensors representing an element from the iterator.
    """

    def generator_py_func(iterator_id):
      """A `py_func` that will be called to invoke the iterator."""
      # `next()` raises `StopIteration` when there are no more
      # elements remaining to be generated.
      values = next(generator_state.get_iterator(iterator_id))

      # Use the same _convert function from the py_func() implementation to
      # convert the returned values to arrays early, so that we can inspect
      # their values.
      # pylint: disable=protected-access
      ret_arrays = [
          script_ops.FuncRegistry._convert(ret, dtype=dtype.as_numpy_dtype)
          for ret, dtype in zip(
              nest.flatten_up_to(output_types, values), flattened_types)
      ]
      # pylint: enable=protected-access

      # Additional type and shape checking to ensure that the components
      # of the generated element match the `output_types` and `output_shapes`
      # arguments.
      for (ret_array, expected_dtype, expected_shape) in zip(
          ret_arrays, flattened_types, flattened_shapes):
        if ret_array.dtype != expected_dtype.as_numpy_dtype:
          raise TypeError(
              "`generator` yielded an element of type %s where an element "
              "of type %s was expected." % (ret_array.dtype,
                                            expected_dtype.as_numpy_dtype))
        if not expected_shape.is_compatible_with(ret_array.shape):
          raise ValueError(
              "`generator` yielded an element of shape %s where an element "
              "of shape %s was expected." % (ret_array.shape, expected_shape))

      return ret_arrays

    flat_values = script_ops.py_func(
        generator_py_func, [iterator_id_t], flattened_types, stateful=True)

    # The `py_func()` op drops the inferred shapes, so we add them back in
    # here. `output_shapes` was normalized above and is never `None` at this
    # point, so the shapes are always applied.
    for ret_t, shape in zip(flat_values, flattened_shapes):
      ret_t.set_shape(shape)

    return nest.pack_sequence_as(output_types, flat_values)

  def finalize_fn(iterator_id_t):
    """Releases host-side state for the iterator with ID `iterator_id_t`."""

    def finalize_py_func(iterator_id):
      generator_state.iterator_completed(iterator_id)
      # We return a dummy value so that the `finalize_fn` has a valid
      # signature.
      # NOTE(mrry): Explicitly create an array of `np.int64` because implicit
      # casting in `py_func()` will create an array of `np.int32` on Windows,
      # leading to a runtime error.
      return np.array(0, dtype=np.int64)

    return script_ops.py_func(
        finalize_py_func, [iterator_id_t], dtypes.int64, stateful=True)

  # This function associates each traversal of `generator` with a unique
  # iterator ID.
  def flat_map_fn(dummy_arg):
    # The `get_iterator_id_fn` gets a unique ID for the current instance of
    # the generator.
    # The `generator_next_fn` gets the next element from the iterator with the
    # given ID, and raises StopIteration when that iterator contains no
    # more elements.
    return _GeneratorDataset(dummy_arg, get_iterator_id_fn, generator_next_fn,
                             finalize_fn)

  # A single-element dataset that, each time it is evaluated, contains a
  # freshly-generated and unique (for the returned dataset) int64
  # ID that will be used to identify the appropriate Python state, which
  # is encapsulated in `generator_state`, and captured in
  # `get_iterator_id_fn`.
  dummy = 0
  id_dataset = Dataset.from_tensors(dummy)

  # A dataset that contains all of the elements generated by a
  # single iterator created from `generator`, identified by the
  # iterator ID contained in `id_dataset`. Lifting the iteration
  # into a flat_map here enables multiple repetitions and/or nested
  # versions of the returned dataset to be created, because it forces
  # the generation of a new ID for each version.
  return id_dataset.flat_map(flat_map_fn)
@staticmethod
def range(*args):
  """Creates a `Dataset` of a step-separated range of values.

  For example:

  ```python
  Dataset.range(5) == [0, 1, 2, 3, 4]
  Dataset.range(2, 5) == [2, 3, 4]
  Dataset.range(1, 5, 2) == [1, 3]
  Dataset.range(1, 5, -2) == []
  Dataset.range(5, 1) == []
  Dataset.range(5, 1, -2) == [5, 3]
  ```

  Args:
    *args: follow same semantics as python's xrange.
      len(args) == 1 -> start = 0, stop = args[0], step = 1
      len(args) == 2 -> start = args[0], stop = args[1], step = 1
      len(args) == 3 -> start = args[0], stop = args[1], step = args[2]

  Returns:
    Dataset: A `RangeDataset`.

  Raises:
    ValueError: if len(args) == 0.
  """
  return RangeDataset(*args)
@staticmethod
def zip(datasets):
  """Builds a `Dataset` by zipping together the given datasets.

  The semantics are similar to Python's built-in `zip()` function; the main
  difference is that the `datasets` argument can be an arbitrary nested
  structure of `Dataset` objects. For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { 1, 2, 3 }
  b = { 4, 5, 6 }
  c = { (7, 8), (9, 10), (11, 12) }
  d = { 13, 14 }

  # The nested structure of the `datasets` argument determines the
  # structure of elements in the resulting dataset.
  Dataset.zip((a, b)) == { (1, 4), (2, 5), (3, 6) }
  Dataset.zip((b, a)) == { (4, 1), (5, 2), (6, 3) }

  # The `datasets` argument may contain an arbitrary number of
  # datasets.
  Dataset.zip((a, b, c)) == { (1, 4, (7, 8)),
                              (2, 5, (9, 10)),
                              (3, 6, (11, 12)) }

  # The number of elements in the resulting dataset is the same as
  # the size of the smallest dataset in `datasets`.
  Dataset.zip((a, d)) == { (1, 13), (2, 14) }
  ```

  Args:
    datasets: A nested structure of datasets.

  Returns:
    Dataset: A `Dataset`.
  """
  zipped_dataset = ZipDataset(datasets)
  return zipped_dataset
def concatenate(self, dataset):
  """Builds a `Dataset` by appending the given dataset to this dataset.

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { 1, 2, 3 }
  b = { 4, 5, 6, 7 }

  # Input dataset and dataset to be concatenated should have same
  # nested structures and output types.
  # c = { (8, 9), (10, 11), (12, 13) }
  # d = { 14.0, 15.0, 16.0 }
  # a.concatenate(c) and a.concatenate(d) would result in error.

  a.concatenate(b) == { 1, 2, 3, 4, 5, 6, 7 }
  ```

  Args:
    dataset: `Dataset` to be concatenated.

  Returns:
    Dataset: A `Dataset`.
  """
  combined_dataset = ConcatenateDataset(self, dataset)
  return combined_dataset
def prefetch(self, buffer_size):
  """Builds a `Dataset` that prefetches elements from this dataset.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
      maximum number of elements that will be buffered when prefetching.

  Returns:
    Dataset: A `Dataset`.
  """
  prefetching_dataset = PrefetchDataset(self, buffer_size)
  return prefetching_dataset
@staticmethod
def list_files(file_pattern, shuffle=None):
  """Builds a dataset of the names of all files matching a pattern.

  Example:
    If we had the following files on our filesystem:
      - /path/to/dir/a.txt
      - /path/to/dir/b.py
      - /path/to/dir/c.py
    If we pass "/path/to/dir/*.py" as the pattern, the dataset would
    produce:
      - /path/to/dir/b.py
      - /path/to/dir/c.py

  NOTE: The order of the file names returned can be non-deterministic even
  when `shuffle` is `False`.

  Args:
    file_pattern: A string or scalar string `tf.Tensor`, representing
      the filename pattern that will be matched.
    shuffle: (Optional.) If `True`, the file names will be shuffled randomly.
      Defaults to `True`.

  Returns:
    Dataset: A `Dataset` of strings corresponding to file names.
  """
  # TODO(b/73959787): Add a `seed` argument and make the `shuffle=False`
  # behavior deterministic (e.g. by sorting the filenames).
  shuffle_files = True if shuffle is None else shuffle

  filenames = gen_io_ops.matching_files(file_pattern)
  dataset = Dataset.from_tensor_slices(filenames)
  if not shuffle_files:
    return dataset

  # NOTE(mrry): The shuffle buffer size must be greater than zero, but the
  # list of files might be empty.
  num_files = array_ops.shape(filenames, out_type=dtypes.int64)[0]
  return dataset.shuffle(math_ops.maximum(num_files, 1))
def repeat(self, count=None):
  """Builds a `Dataset` that repeats this dataset `count` times.

  NOTE: If this dataset is a function of global state (e.g. a random number
  generator), then different repetitions may produce different elements.

  Args:
    count: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      number of times the dataset should be repeated. The default behavior
      (if `count` is `None` or `-1`) is for the dataset to be repeated
      indefinitely.

  Returns:
    Dataset: A `Dataset`.
  """
  repeated_dataset = RepeatDataset(self, count)
  return repeated_dataset
def _enumerate(self, start=0):
  """Zips a running index, beginning at `start`, with this dataset.

  The index dataset is a `Dataset.range` from `start` up to the largest
  `int64` value, and comes first in each resulting `(index, element)` pair.
  """
  int64_max = np.iinfo(dtypes.int64.as_numpy_dtype).max
  indices = Dataset.range(start, int64_max)
  return Dataset.zip((indices, self))
def shuffle(self, buffer_size, seed=None, reshuffle_each_iteration=None):
  """Builds a `Dataset` that randomly shuffles the elements of this dataset.

  Args:
    buffer_size: A `tf.int64` scalar `tf.Tensor`, representing the
      number of elements from this dataset from which the new
      dataset will sample.
    seed: (Optional.) A `tf.int64` scalar `tf.Tensor`, representing the
      random seed that will be used to create the distribution. See
      @{tf.set_random_seed} for behavior.
    reshuffle_each_iteration: (Optional.) A boolean, which if true indicates
      that the dataset should be pseudorandomly reshuffled each time it is
      iterated over. (Defaults to `True`.)

  Returns:
    Dataset: A `Dataset`.
  """
  shuffled_dataset = ShuffleDataset(self, buffer_size, seed,
                                    reshuffle_each_iteration)
  return shuffled_dataset
def cache(self, filename=""):
  """Builds a `Dataset` that caches the elements of this dataset.

  Args:
    filename: A `tf.string` scalar `tf.Tensor`, representing the name of a
      directory on the filesystem to use for caching tensors in this Dataset.
      If a filename is not provided, the dataset will be cached in memory.

  Returns:
    Dataset: A `Dataset`.
  """
  cached_dataset = CacheDataset(self, filename)
  return cached_dataset
def take(self, count):
  """Builds a `Dataset` holding at most `count` elements of this dataset.

  Args:
    count: A `tf.int64` scalar `tf.Tensor`, representing the number of
      elements of this dataset that should be taken to form the new dataset.
      If `count` is -1, or if `count` is greater than the size of this
      dataset, the new dataset will contain all elements of this dataset.

  Returns:
    Dataset: A `Dataset`.
  """
  truncated_dataset = TakeDataset(self, count)
  return truncated_dataset
def skip(self, count):
  """Builds a `Dataset` that omits the first `count` elements of this dataset.

  Args:
    count: A `tf.int64` scalar `tf.Tensor`, representing the number
      of elements of this dataset that should be skipped to form the
      new dataset. If `count` is greater than the size of this
      dataset, the new dataset will contain no elements. If `count`
      is -1, skips the entire dataset.

  Returns:
    Dataset: A `Dataset`.
  """
  remaining_dataset = SkipDataset(self, count)
  return remaining_dataset
def shard(self, num_shards, index):
  """Creates a `Dataset` that includes only 1/`num_shards` of this dataset.

  This dataset operator is very useful when running distributed training, as
  it allows each worker to read a unique subset.

  When reading a single input file, you can skip elements as follows:

  ```python
  d = tf.data.TFRecordDataset(FLAGS.input_file)
  d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
  d = d.repeat(FLAGS.num_epochs)
  d = d.shuffle(FLAGS.shuffle_buffer_size)
  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
  ```

  Important caveats:

  - Be sure to shard before you use any randomizing operator (such as
    shuffle).
  - Generally it is best if the shard operator is used early in the dataset
    pipeline. For example, when reading from a set of TFRecord files, shard
    before converting the dataset to input samples. This avoids reading every
    file on every worker. The following is an example of an efficient
    sharding strategy within a complete pipeline:

  ```python
  d = Dataset.list_files(FLAGS.pattern)
  d = d.shard(FLAGS.num_workers, FLAGS.worker_index)
  d = d.repeat(FLAGS.num_epochs)
  d = d.shuffle(FLAGS.shuffle_buffer_size)
  d = d.repeat()
  d = d.interleave(tf.data.TFRecordDataset,
                   cycle_length=FLAGS.num_readers, block_length=1)
  d = d.map(parser_fn, num_parallel_calls=FLAGS.num_map_threads)
  ```

  Args:
    num_shards: A `tf.int64` scalar `tf.Tensor`, representing the number of
      shards operating in parallel.
    index: A `tf.int64` scalar `tf.Tensor`, representing the worker index.

  Returns:
    Dataset: A `Dataset`.

  Raises:
    ValueError: if `num_shards` or `index` are illegal values. Note: error
      checking is done on a best-effort basis, and errors are not guaranteed
      to be caught upon dataset creation. (e.g. providing a placeholder
      tensor bypasses the early checking, and will instead result in an error
      during a session.run call.)
  """
  num_shards = ops.convert_to_tensor(
      num_shards, name="num_shards", dtype=dtypes.int64)
  num_shards_static = tensor_util.constant_value(num_shards)
  index = ops.convert_to_tensor(index, name="index", dtype=dtypes.int64)
  index_static = tensor_util.constant_value(index)

  # The static values are `None` when they cannot be determined at graph
  # construction time; each check below only runs on known values.
  if num_shards_static is not None and num_shards_static < 1:
    raise ValueError("num_shards must be >= 1; got: %s" % num_shards_static)
  if index_static is not None and index_static < 0:
    raise ValueError("index must be >= 0; got: %s" % index_static)
  if (index_static is not None and num_shards_static is not None and
      index_static >= num_shards_static):
    # Valid indices are 0 .. num_shards - 1, so the bound is strict.
    raise ValueError("index must be < num_shards; %s is not < %s" %
                     (index_static, num_shards_static))

  def filter_fn(elem_index, _):
    # Keep only the elements whose position is congruent to `index`.
    mod_result = math_ops.mod(elem_index, num_shards)
    return math_ops.equal(mod_result, index)

  # Number the elements, keep this shard's positions, then drop the numbering.
  return self._enumerate().filter(filter_fn).map(lambda _, elem: elem)
def batch(self, batch_size):
  """Builds a `Dataset` that groups consecutive elements into batches.

  NOTE: If the number of elements (`N`) in this dataset is not an exact
  multiple of `batch_size`, the final batch will contain smaller tensors with
  shape `N % batch_size` in the batch dimension. If your program depends on
  the batches having the same shape, consider using the
  @{tf.contrib.data.batch_and_drop_remainder} transformation instead.

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.

  Returns:
    Dataset: A `Dataset`.
  """
  batched_dataset = BatchDataset(self, batch_size)
  return batched_dataset
def padded_batch(self, batch_size, padded_shapes, padding_values=None):
  """Combines consecutive elements of this dataset into padded batches.

  This transformation combines multiple consecutive elements of the input
  dataset into a single element. Like @{tf.data.Dataset.batch}, the tensors
  in the resulting element have an additional outer dimension, which will be
  `batch_size` for all but the last element, and `N % batch_size` for the
  last element (where `N` is the number of elements in this dataset). Unlike
  @{tf.data.Dataset.batch}, the elements may have different shapes for some
  of their components, and this transformation will pad each component to
  the respective shape in `padded_shapes`. The `padded_shapes` argument
  determines the resulting shape for each dimension of each component in an
  output element:

  * If the dimension is a constant (e.g. `tf.Dimension(37)`), the component
    will be padded out to that length in that dimension.
  * If the dimension is unknown (e.g. `tf.Dimension(None)`), the component
    will be padded out to the maximum length of all elements in that
    dimension.

  NOTE: If the number of elements (`N`) in this dataset is not an exact
  multiple of `batch_size`, the final batch will contain smaller tensors with
  shape `N % batch_size` in the batch dimension. If your program depends on
  the batches having the same shape, consider using the
  @{tf.contrib.data.padded_batch_and_drop_remainder} transformation instead.

  See also @{tf.contrib.data.dense_to_sparse_batch}, which combines elements
  that may have different shapes into a @{tf.SparseTensor}.

  Args:
    batch_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements of this dataset to combine in a single batch.
    padded_shapes: A nested structure of `tf.TensorShape` or
      `tf.int64` vector tensor-like objects representing the shape
      to which the respective component of each input element should
      be padded prior to batching. Any unknown dimensions
      (e.g. `tf.Dimension(None)` in a `tf.TensorShape` or `-1` in a
      tensor-like object) will be padded to the maximum size of that
      dimension in each batch.
    padding_values: (Optional.) A nested structure of scalar-shaped
      `tf.Tensor`, representing the padding values to use for the
      respective components. Defaults are `0` for numeric types and
      the empty string for string types.

  Returns:
    Dataset: A `Dataset`.
  """
  return PaddedBatchDataset(self, batch_size, padded_shapes, padding_values)
def map(self, map_func, num_parallel_calls=None):
  """Applies `map_func` to the elements of this dataset.

  Args:
    map_func: A function mapping a nested structure of tensors (having
      shapes and types defined by `self.output_shapes` and
      `self.output_types`) to another nested structure of tensors.
    num_parallel_calls: (Optional.) A `tf.int32` scalar `tf.Tensor`,
      representing the number of elements to process in parallel. If not
      specified, elements will be processed sequentially.

  Returns:
    Dataset: A `MapDataset` when `num_parallel_calls` is `None`, otherwise a
      `ParallelMapDataset`.
  """
  # Only an omitted (`None`) value selects the sequential implementation; any
  # other value, including a tensor, is forwarded to the parallel one.
  if num_parallel_calls is not None:
    return ParallelMapDataset(self, map_func, num_parallel_calls)
  return MapDataset(self, map_func)
def flat_map(self, map_func):
  """Maps `map_func` across this dataset and flattens the result.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      `Dataset`.

  Returns:
    Dataset: A `FlatMapDataset` built from this dataset and `map_func`.
  """
  flattened = FlatMapDataset(self, map_func)
  return flattened
def interleave(self, map_func, cycle_length, block_length=1):
  """Maps `map_func` across this dataset, and interleaves the results.

  For example, you can use `Dataset.interleave()` to process many input files
  concurrently:

  ```python
  # Preprocess 4 files concurrently, and interleave blocks of 16 records from
  # each file.
  filenames = ["/var/data/file1.txt", "/var/data/file2.txt", ...]
  dataset = (Dataset.from_tensor_slices(filenames)
             .interleave(lambda x:
                 TextLineDataset(x).map(parse_fn, num_parallel_calls=1),
                 cycle_length=4, block_length=16))
  ```

  The `cycle_length` and `block_length` arguments control the order in which
  elements are produced. `cycle_length` controls the number of input elements
  that are processed concurrently. If you set `cycle_length` to 1, this
  transformation will handle one input element at a time, and will produce
  identical results to @{tf.data.Dataset.flat_map}. In general,
  this transformation will apply `map_func` to `cycle_length` input elements,
  open iterators on the returned `Dataset` objects, and cycle through them
  producing `block_length` consecutive elements from each iterator, and
  consuming the next input element each time it reaches the end of an
  iterator.

  For example:

  ```python
  # NOTE: The following examples use `{ ... }` to represent the
  # contents of a dataset.
  a = { 1, 2, 3, 4, 5 }

  # NOTE: New lines indicate "block" boundaries.
  a.interleave(lambda x: Dataset.from_tensors(x).repeat(6),
               cycle_length=2, block_length=4) == {
      1, 1, 1, 1,
      2, 2, 2, 2,
      1, 1,
      2, 2,
      3, 3, 3, 3,
      4, 4, 4, 4,
      3, 3,
      4, 4,
      5, 5, 5, 5,
      5, 5,
  }
  ```

  NOTE: The order of elements yielded by this transformation is
  deterministic, as long as `map_func` is a pure function. If
  `map_func` contains any stateful operations, the order in which
  that state is accessed is undefined.

  Args:
    map_func: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      `Dataset`.
    cycle_length: The number of elements from this dataset that will be
      processed concurrently.
    block_length: The number of consecutive elements to produce from each
      input element before cycling to another input element. Defaults to 1.

  Returns:
    Dataset: A `Dataset`.
  """
  interleaved = InterleaveDataset(self, map_func, cycle_length, block_length)
  return interleaved
def filter(self, predicate):
  """Filters this dataset according to `predicate`.

  Args:
    predicate: A function mapping a nested structure of tensors (having shapes
      and types defined by `self.output_shapes` and `self.output_types`) to a
      scalar `tf.bool` tensor.

  Returns:
    Dataset: A `FilterDataset` built from this dataset and `predicate`.
  """
  filtered = FilterDataset(self, predicate)
  return filtered
def apply(self, transformation_func):
  """Runs a custom transformation function on this dataset.

  `apply` enables chaining of custom `Dataset` transformations, which are
  represented as functions that take one `Dataset` argument and return a
  transformed `Dataset`.

  For example:

  ```
  dataset = (dataset.map(lambda x: x ** 2)
             .apply(group_by_window(key_func, reduce_func, window_size))
             .map(lambda x: x ** 3))
  ```

  Args:
    transformation_func: A function that takes one `Dataset` argument and
      returns a `Dataset`.

  Returns:
    Dataset: The `Dataset` returned by applying `transformation_func` to this
      dataset.

  Raises:
    TypeError: If `transformation_func` returns something that is not a
      `Dataset`.
  """
  transformed = transformation_func(self)
  # Hand back the result only once it is known to be chainable.
  if isinstance(transformed, Dataset):
    return transformed
  raise TypeError("`transformation_func` must return a Dataset.")
class TensorDataset(Dataset):
"""A `Dataset` with a single element, viz. a nested structure of tensors."""
def __init__(self, tensors):
  """See `Dataset.from_tensors()` for details.

  Converts every component of `tensors` inside the "tensors" name scope and
  records the serialized tensors together with the classes, shapes and
  dtypes of the converted structure.

  Args:
    tensors: A nested structure of values. Components for which
      `sparse_tensor_lib.is_sparse` is true go through
      `SparseTensor.from_value`; all others go through
      `ops.convert_to_tensor`.
  """
  super(TensorDataset, self).__init__()
  with ops.name_scope("tensors"):
    converted = []
    for index, component in enumerate(nest.flatten(tensors)):
      if sparse_tensor_lib.is_sparse(component):
        converted.append(
            sparse_tensor_lib.SparseTensor.from_value(component))
      else:
        converted.append(
            ops.convert_to_tensor(component, name="component_%d" % index))
    # Rebuild the caller's nesting around the converted components.
    tensors = nest.pack_sequence_as(tensors, converted)

  self._tensors = sparse.serialize_sparse_tensors(tensors)
  self._output_classes = sparse.get_classes(tensors)
  # Shapes and dtypes are read from the converted (pre-serialization)
  # components, flattened once and reused for both structures.
  flat_components = nest.flatten(tensors)
  self._output_shapes = nest.pack_sequence_as(
      tensors, [component.get_shape() for component in flat_components])
  self._output_types = nest.pack_sequence_as(
      tensors, [component.dtype for component in flat_components])
def _as_variant_tensor(self):
  """Builds the `tensor_dataset` op for this dataset.

  Returns:
    The result of `gen_dataset_ops.tensor_dataset` applied to the flattened
    stored tensors, with the flattened dense form of the output shapes.
  """
  flat_tensors = nest.flatten(self._tensors)
  dense_shapes = sparse.as_dense_shapes(
      self.output_shapes, self.output_classes)
  return gen_dataset_ops.tensor_dataset(
      flat_tensors, output_shapes=nest.flatten(dense_shapes))
@property
def output_classes(self):
  """Returns the component classes held in `self._output_classes`."""
  return self._output_classes
@property
def output_shapes(self):
  """Returns the component shapes held in `self._output_shapes`."""
  return self._output_shapes