/
data_flow_ops.py
1661 lines (1351 loc) · 62.5 KB
/
data_flow_ops.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#==============================================================================
"""Data Flow Operations."""
# pylint: disable=g-bad-name
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import hashlib
import threading
import six
from tensorflow.python.framework import dtypes as _dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_data_flow_ops
from tensorflow.python.ops import math_ops
# go/tf-wildcard-import
# pylint: disable=wildcard-import
from tensorflow.python.ops.gen_data_flow_ops import *
# pylint: enable=wildcard-import
def _as_type_list(dtypes):
  """Normalizes `dtypes` into a new Python list of types.

  Args:
    dtypes: A single dtype, or a list or tuple of dtypes. Must not be None.

  Returns:
    A fresh list. A lone dtype is wrapped into a one-element list; a list or
    tuple is copied into a new list.
  """
  assert dtypes is not None
  is_sequence = isinstance(dtypes, (list, tuple))
  return list(dtypes) if is_sequence else [dtypes]
def _as_shape_list(shapes, dtypes, unknown_dim_allowed=False,
                   unknown_rank_allowed=False):
  """Convert shapes to a list of `TensorShape` objects.

  Args:
    shapes: None, a single shape (a `TensorShape`, or a list/tuple whose
      entries are all ints or None), or a list/tuple of shapes.
    dtypes: Unused; accepted so callers can pass their dtype list.
    unknown_dim_allowed: If True, shapes may contain unknown (None)
      dimensions, but `shapes` must then be a non-empty sequence of shapes
      (a single shape or None is rejected).
    unknown_rank_allowed: If True, shapes of unknown rank are accepted.

  Returns:
    None if `shapes` is None (only reachable when `unknown_dim_allowed` is
    False); otherwise a list of `TensorShape` objects.

  Raises:
    ValueError: If `unknown_dim_allowed` is True and `shapes` is not a
      non-empty sequence of shapes, or if a shape is not fully defined /
      has unknown rank when that is not allowed.
    TypeError: If `shapes` is neither a `TensorShape` nor a list or tuple.
  """
  # `collections.Sequence` was removed in Python 3.10; the ABCs live in
  # `collections.abc` on Python 3 and directly in `collections` on Python 2.
  try:
    from collections import abc as collections_abc  # pylint: disable=g-import-not-at-top
  except ImportError:
    collections_abc = collections
  del dtypes
  if unknown_dim_allowed:
    # With partial shapes a single shape such as [None, 3] would be
    # ambiguous, so an explicit list of shapes is required.
    if (not isinstance(shapes, collections_abc.Sequence)
        or not shapes
        or any(shape is None or isinstance(shape, int) for shape in shapes)):
      raise ValueError(
          "When providing partial shapes, a list of shapes must be provided.")
  if shapes is None:
    return None
  if isinstance(shapes, tensor_shape.TensorShape):
    shapes = [shapes]
  if not isinstance(shapes, (tuple, list)):
    raise TypeError(
        "shapes must be a TensorShape or a list or tuple of TensorShapes.")
  if all(shape is None or isinstance(shape, int) for shape in shapes):
    # We have a single shape.
    shapes = [shapes]
  shapes = [tensor_shape.as_shape(shape) for shape in shapes]
  if not unknown_dim_allowed:
    if any(not shape.is_fully_defined() for shape in shapes):
      raise ValueError("All shapes must be fully defined: %s" % shapes)
  if not unknown_rank_allowed:
    if any(shape.dims is None for shape in shapes):
      raise ValueError("All shapes must have a defined rank: %s" % shapes)
  return shapes
def _as_name_list(names, dtypes):
  """Normalizes optional component names into a list matching `dtypes`.

  Args:
    names: None, a single name, or a list or tuple of names.
    dtypes: The list of component dtypes the names must line up with.

  Returns:
    None when `names` is None, otherwise a new list of names.

  Raises:
    ValueError: If the number of names differs from the number of dtypes.
  """
  if names is None:
    return None
  name_list = list(names) if isinstance(names, (list, tuple)) else [names]
  if len(name_list) != len(dtypes):
    raise ValueError("List of names must have the same length as the list "
                     "of dtypes")
  return name_list
def _shape_common(s1, s2):
  """Returns a TensorShape keeping only what `s1` and `s2` agree on.

  If either rank is unknown, or the ranks differ, the result is a shape of
  unknown rank. Otherwise each dimension is kept when both inputs have the
  same known value there, and becomes unknown (None) otherwise.

  Args:
    s1: Anything convertible to a `TensorShape`.
    s2: Anything convertible to a `TensorShape`.

  Returns:
    A `TensorShape`.
  """
  first = tensor_shape.TensorShape(s1)
  second = tensor_shape.TensorShape(s2)
  same_known_rank = (first.ndims is not None and second.ndims is not None
                     and first.ndims == second.ndims)
  if not same_known_rank:
    return tensor_shape.unknown_shape()
  merged = []
  for dim_a, dim_b in zip(first.as_list(), second.as_list()):
    merged.append(dim_a if (dim_a is not None and dim_a == dim_b) else None)
  return tensor_shape.TensorShape(merged)
# pylint: disable=protected-access
class QueueBase(object):
  """Base class for queue implementations.

  A queue is a TensorFlow data structure that stores tensors across
  multiple steps, and exposes operations that enqueue and dequeue
  tensors.

  Each queue element is a tuple of one or more tensors, where each
  tuple component has a static dtype, and may have a static shape. The
  queue implementations support versions of enqueue and dequeue that
  handle single elements, versions that support enqueuing and
  dequeuing a batch of elements at once.

  See @{tf.FIFOQueue} and
  @{tf.RandomShuffleQueue} for concrete
  implementations of this class, and instructions on how to create
  them.
  """

  def __init__(self, dtypes, shapes, names, queue_ref):
    """Constructs a queue object from a queue reference.

    The two optional lists, `shapes` and `names`, must be of the same length
    as `dtypes` if provided. The values at a given index `i` indicate the
    shape and name to use for the corresponding queue component in `dtypes`.

    Args:
      dtypes: A list of types. The length of dtypes must equal the number
        of tensors in each element.
      shapes: Constraints on the shapes of tensors in an element:
        A list of shape tuples or None. This list is the same length
        as dtypes. If the shape of any tensors in the element are constrained,
        all must be; shapes can be None if the shapes should not be
        constrained.
      names: Optional list of names. If provided, the `enqueue()` and
        `dequeue()` methods will use dictionaries with these names as keys.
        Must be None or a list or tuple of the same length as `dtypes`.
      queue_ref: The queue reference, i.e. the output of the queue op.

    Raises:
      ValueError: If one of the arguments is invalid.
    """
    self._dtypes = dtypes
    if shapes is not None:
      if len(shapes) != len(dtypes):
        raise ValueError("Queue shapes must have the same length as dtypes")
      self._shapes = [tensor_shape.TensorShape(s) for s in shapes]
    else:
      self._shapes = [tensor_shape.unknown_shape() for _ in self._dtypes]
    if names is not None:
      if len(names) != len(dtypes):
        raise ValueError("Queue names must have the same length as dtypes")
      self._names = names
    else:
      self._names = None
    self._queue_ref = queue_ref
    # Last path segment of the queue op's name; used to build the default
    # names of the ops created by the methods below.
    self._name = self._queue_ref.op.name.split("/")[-1]

  @staticmethod
  def from_list(index, queues):
    """Create a queue using the queue reference from `queues[index]`.

    Args:
      index: An integer scalar tensor that determines the input that gets
        selected.
      queues: A list of `QueueBase` objects.

    Returns:
      A `QueueBase` object.

    Raises:
      TypeError: When `queues` is not a list of `QueueBase` objects,
        or when the data types of `queues` are not all the same.
    """
    if ((not queues) or
        (not isinstance(queues, list)) or
        (not all(isinstance(x, QueueBase) for x in queues))):
      raise TypeError("A list of queues expected")

    dtypes = queues[0].dtypes
    if not all(dtypes == q.dtypes for q in queues[1:]):
      raise TypeError("Queues do not have matching component dtypes.")

    names = queues[0].names
    if not all(names == q.names for q in queues[1:]):
      raise TypeError("Queues do not have matching component names.")

    # Per component, keep only the static shape information that every
    # candidate queue agrees on, since the selection is made at runtime.
    queue_shapes = [q.shapes for q in queues]
    reduced_shapes = [
        six.moves.reduce(_shape_common, s) for s in zip(*queue_shapes)]

    queue_refs = array_ops.stack([x.queue_ref for x in queues])
    selected_queue = array_ops.gather(queue_refs, index)
    return QueueBase(dtypes=dtypes, shapes=reduced_shapes, names=names,
                     queue_ref=selected_queue)

  @property
  def queue_ref(self):
    """The underlying queue reference."""
    return self._queue_ref

  @property
  def name(self):
    """The name of the underlying queue."""
    return self._queue_ref.op.name

  @property
  def dtypes(self):
    """The list of dtypes for each component of a queue element."""
    return self._dtypes

  @property
  def shapes(self):
    """The list of shapes for each component of a queue element."""
    return self._shapes

  @property
  def names(self):
    """The list of names for each component of a queue element."""
    return self._names

  def _uses_resource_handle(self):
    """Returns True if `queue_ref` is a resource handle.

    Resource handles are served by the `*_v2` queue ops. Any other handle
    dtype (presumably a legacy reference-typed handle) uses the non-V2 ops.
    """
    return self._queue_ref.dtype == _dtypes.resource

  def _check_enqueue_dtypes(self, vals):
    """Validate and convert `vals` to a list of `Tensor`s.

    The `vals` argument can be a Tensor, a list or tuple of tensors, or a
    dictionary with tensor values.

    If it is a dictionary, the queue must have been constructed with a
    `names` attribute and the dictionary keys must match the queue names.
    If the queue was constructed with a `names` attribute, `vals` must
    be a dictionary.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      A list of `Tensor` objects, one per queue component.

    Raises:
      ValueError: If `vals` is invalid, including when the number of values
        does not match the number of queue components.
    """
    if isinstance(vals, dict):
      if not self._names:
        raise ValueError("Queue must have names to enqueue a dictionary")
      if sorted(self._names) != sorted(vals.keys()):
        raise ValueError("Keys in dictionary to enqueue do not match "
                         "names of Queue.  Dictionary: (%s), Queue: (%s)" %
                         (sorted(vals.keys()), sorted(self._names)))
      # The order of values in `self._names` indicates the order in which the
      # tensors in the dictionary `vals` must be listed.
      vals = [vals[k] for k in self._names]
    else:
      if self._names:
        raise ValueError("You must enqueue a dictionary in a Queue with names")
      if not isinstance(vals, (list, tuple)):
        vals = [vals]

    # `zip` below would silently drop surplus values (e.g. a Python list meant
    # as a single vector component), so reject a count mismatch explicitly.
    if len(vals) != len(self._dtypes):
      raise ValueError(
          "Expected %d components to enqueue, but got %d."
          % (len(self._dtypes), len(vals)))

    tensors = []
    for i, (val, dtype) in enumerate(zip(vals, self._dtypes)):
      tensors.append(ops.convert_to_tensor(val, dtype=dtype,
                                           name="component_%d" % i))

    return tensors

  def _scope_vals(self, vals):
    """Return a list of values to pass to `name_scope()`.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary.

    Returns:
      The values in vals as a list (a list or tuple input is returned as-is).
    """
    if isinstance(vals, (list, tuple)):
      return vals
    elif isinstance(vals, dict):
      # Materialize the view: on Python 3 `dict.values()` is not a list.
      return list(vals.values())
    else:
      return [vals]

  def enqueue(self, vals, name=None):
    """Enqueues one element to this queue.

    If the queue is full when this operation executes, it will block
    until the element has been enqueued.

    At runtime, this operation may raise an error if the queue is closed
    (see @{tf.QueueBase.close}) before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is closed
    (see @{tf.Session.close}),
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary containing
        the values to enqueue.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a new tuple of tensors to the queue.

    Raises:
      ValueError: If `vals` does not match the queue's components, or a
        value's static shape is incompatible with its component shape.
    """
    with ops.name_scope(name, "%s_enqueue" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      for val, shape in zip(vals, self._shapes):
        val.get_shape().assert_is_compatible_with(shape)

      if self._uses_resource_handle():
        return gen_data_flow_ops._queue_enqueue_v2(
            self._queue_ref, vals, name=scope)
      else:
        return gen_data_flow_ops._queue_enqueue(
            self._queue_ref, vals, name=scope)

  def enqueue_many(self, vals, name=None):
    """Enqueues zero or more elements to this queue.

    This operation slices each component tensor along the 0th dimension to
    make multiple queue elements. All of the tensors in `vals` must have the
    same size in the 0th dimension.

    If the queue is full when this operation executes, it will block
    until all of the elements have been enqueued.

    At runtime, this operation may raise an error if the queue is closed
    (see @{tf.QueueBase.close}) before or during its execution. If the
    queue is closed before this operation runs,
    `tf.errors.CancelledError` will be raised. If this operation is
    blocked, and either (i) the queue is closed by a close operation
    with `cancel_pending_enqueues=True`, or (ii) the session is closed
    (see @{tf.Session.close}),
    `tf.errors.CancelledError` will be raised.

    Args:
      vals: A tensor, a list or tuple of tensors, or a dictionary
        from which the queue elements are taken.
      name: A name for the operation (optional).

    Returns:
      The operation that enqueues a batch of tuples of tensors to the queue.

    Raises:
      ValueError: If `vals` does not match the queue's components, a value
        has rank 0, the leading (batch) dimensions disagree, or a value's
        per-element shape is incompatible with its component shape.
    """
    with ops.name_scope(name, "%s_EnqueueMany" % self._name,
                        self._scope_vals(vals)) as scope:
      vals = self._check_enqueue_dtypes(vals)

      # NOTE(mrry): Not using a shape function because we need access to
      # the `QueueBase` object.
      # `merge_with` raises if the statically known batch sizes disagree.
      batch_dim = vals[0].get_shape().with_rank_at_least(1)[0]
      for val, shape in zip(vals, self._shapes):
        batch_dim = batch_dim.merge_with(
            val.get_shape().with_rank_at_least(1)[0])
        val.get_shape()[1:].assert_is_compatible_with(shape)

      if self._uses_resource_handle():
        return gen_data_flow_ops._queue_enqueue_many_v2(
            self._queue_ref, vals, name=scope)
      else:
        return gen_data_flow_ops._queue_enqueue_many(
            self._queue_ref, vals, name=scope)

  def _dequeue_return_value(self, tensors):
    """Return the value to return from a dequeue op.

    If the queue has names, return a dictionary with the
    names as keys.  Otherwise return either a single tensor
    or a list of tensors depending on the length of `tensors`.

    Args:
      tensors: List of tensors from the dequeue op.

    Returns:
      A single tensor, a list of tensors, or a dictionary
      of tensors.
    """
    if self._names:
      # The returned values in `tensors` are in the same order as
      # the names in `self._names`.
      return {n: tensors[i] for i, n in enumerate(self._names)}
    elif len(tensors) == 1:
      return tensors[0]
    else:
      return tensors

  def dequeue(self, name=None):
    """Dequeues one element from this queue.

    If the queue is empty when this operation executes, it will block
    until there is an element to dequeue.

    At runtime, this operation may raise an error if the queue is closed
    (see @{tf.QueueBase.close}) before or during its execution. If the
    queue is closed, the queue is empty, and there are no pending
    enqueue operations that can fulfill this request,
    `tf.errors.OutOfRangeError` will be raised. If the session is closed
    (see @{tf.Session.close}),
    `tf.errors.CancelledError` will be raised.

    Args:
      name: A name for the operation (optional).

    Returns:
      The tuple of tensors that was dequeued.
    """
    if name is None:
      name = "%s_Dequeue" % self._name
    if self._uses_resource_handle():
      ret = gen_data_flow_ops._queue_dequeue_v2(
          self._queue_ref, self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops._queue_dequeue(
          self._queue_ref, self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the `QueueBase` object.
    op = ret[0].op
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(shape)

    return self._dequeue_return_value(ret)

  def dequeue_many(self, n, name=None):
    """Dequeues and concatenates `n` elements from this queue.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor.  All of the
    components in the dequeued tuple will have size `n` in the 0th dimension.

    If the queue is closed and there are less than `n` elements left, then an
    `OutOfRange` exception is raised.

    At runtime, this operation may raise an error if the queue is closed
    (see @{tf.QueueBase.close}) before or during its execution. If the
    queue is closed, the queue contains fewer than `n` elements, and
    there are no pending enqueue operations that can fulfill this
    request, `tf.errors.OutOfRangeError` will be raised. If the
    session is closed (see @{tf.Session.close}),
    `tf.errors.CancelledError` will be raised.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueMany" % self._name

    if self._uses_resource_handle():
      ret = gen_data_flow_ops._queue_dequeue_many_v2(
          self._queue_ref, n=n, component_types=self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops._queue_dequeue_many(
          self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    op = ret[0].op
    # When `n` is not a graph-time constant, `constant_value` yields None and
    # the batch dimension stays unknown.
    batch_dim = tensor_shape.Dimension(tensor_util.constant_value(op.inputs[1]))
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def dequeue_up_to(self, n, name=None):
    """Dequeues and concatenates up to `n` elements from this queue.

    **Note** This operation is not supported by all queues.  If a queue does not
    support DequeueUpTo, then a `tf.errors.UnimplementedError` is raised.

    This operation concatenates queue-element component tensors along
    the 0th dimension to make a single component tensor. If the queue
    has not been closed, all of the components in the dequeued tuple
    will have size `n` in the 0th dimension.

    If the queue is closed and there are more than `0` but fewer than
    `n` elements remaining, then instead of raising a
    `tf.errors.OutOfRangeError` like @{tf.QueueBase.dequeue_many},
    less than `n` elements are returned immediately.  If the queue is
    closed and there are `0` elements left in the queue, then a
    `tf.errors.OutOfRangeError` is raised just like in `dequeue_many`.
    Otherwise the behavior is identical to `dequeue_many`.

    Args:
      n: A scalar `Tensor` containing the number of elements to dequeue.
      name: A name for the operation (optional).

    Returns:
      The tuple of concatenated tensors that was dequeued.
    """
    if name is None:
      name = "%s_DequeueUpTo" % self._name

    if self._uses_resource_handle():
      ret = gen_data_flow_ops._queue_dequeue_up_to_v2(
          self._queue_ref, n=n, component_types=self._dtypes, name=name)
    else:
      ret = gen_data_flow_ops._queue_dequeue_up_to(
          self._queue_ref, n=n, component_types=self._dtypes, name=name)

    # NOTE(mrry): Not using a shape function because we need access to
    # the Queue object.
    # The batch size may be smaller than `n`, so it is left unknown.
    op = ret[0].op
    for output, shape in zip(op.values(), self._shapes):
      output.set_shape(tensor_shape.TensorShape([None]).concatenate(shape))

    return self._dequeue_return_value(ret)

  def close(self, cancel_pending_enqueues=False, name=None):
    """Closes this queue.

    This operation signals that no more elements will be enqueued in
    the given queue. Subsequent `enqueue` and `enqueue_many`
    operations will fail. Subsequent `dequeue` and `dequeue_many`
    operations will continue to succeed if sufficient elements remain
    in the queue. Subsequent `dequeue` and `dequeue_many` operations
    that would block will fail immediately.

    If `cancel_pending_enqueues` is `True`, all pending requests will also
    be cancelled.

    Args:
      cancel_pending_enqueues: (Optional.) A boolean, defaulting to
        `False` (described above).
      name: A name for the operation (optional).

    Returns:
      The operation that closes the queue.
    """
    if name is None:
      name = "%s_Close" % self._name
    if self._uses_resource_handle():
      return gen_data_flow_ops._queue_close_v2(
          self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)
    else:
      return gen_data_flow_ops._queue_close(
          self._queue_ref, cancel_pending_enqueues=cancel_pending_enqueues,
          name=name)

  def size(self, name=None):
    """Compute the number of elements in this queue.

    Args:
      name: A name for the operation (optional).

    Returns:
      A scalar tensor containing the number of elements in this queue.
    """
    if name is None:
      name = "%s_Size" % self._name
    if self._uses_resource_handle():
      return gen_data_flow_ops._queue_size_v2(self._queue_ref, name=name)
    else:
      return gen_data_flow_ops._queue_size(self._queue_ref, name=name)
class RandomShuffleQueue(QueueBase):
  """A queue whose dequeue operations return elements in random order.

  The methods available on this class are documented on @{tf.QueueBase}.
  """

  def __init__(self, capacity, min_after_dequeue, dtypes, shapes=None,
               names=None, seed=None, shared_name=None,
               name="random_shuffle_queue"):
    """Builds a bounded queue that hands out elements in random order.

    The queue holds at most `capacity` elements, accepts any number of
    concurrent producers and consumers, and delivers each element exactly
    once. Every element is a fixed-length tuple of tensors; the dtypes of
    the tuple components come from `dtypes`, and their shapes may be pinned
    with `shapes`.

    With `shapes` given, each enqueued component must have exactly that
    fixed shape. Without it, elements may differ in shape, but
    `dequeue_many` cannot be used.

    `min_after_dequeue` sets how many elements must remain in the queue once
    a `dequeue` or `dequeue_many` finishes, which guarantees a minimum amount
    of mixing. Those operations block until enough elements are present.
    After the queue is closed this lower bound no longer applies.

    Args:
      capacity: An integer upper bound on how many elements the queue can
        hold.
      min_after_dequeue: An integer; see above.
      dtypes: A list of `DType` objects, one per tensor in a queue element.
      shapes: (Optional.) `None`, or a list of fully-defined `TensorShape`
        objects with one entry per dtype.
      names: (Optional.) `None`, or a list of strings with one entry per
        dtype. When given, dequeue methods return dictionaries keyed by
        these names.
      seed: A Python integer used to derive the random seed; see
        @{tf.set_random_seed} for how it interacts with the graph seed.
      shared_name: (Optional.) If non-empty, the queue is shared under this
        name across multiple sessions.
      name: Optional name for the queue operation.
    """
    dtypes = _as_type_list(dtypes)
    shapes = _as_shape_list(shapes, dtypes)
    names = _as_name_list(names, dtypes)

    first_seed, second_seed = random_seed.get_seed(seed)
    if first_seed is None and second_seed is None:
      # No seed at all: the op receives zeros for both seeds.
      first_seed, second_seed = 0, 0
    elif seed is None and shared_name is not None:
      # A graph seed exists but no op seed was given. For shared queues,
      # derive the second seed only from the first seed and `shared_name`,
      # because `get_seed()` otherwise tends to depend on the id of the most
      # recently created op.
      seed_source = (str(first_seed) + shared_name).encode("utf-8")
      digest = hashlib.md5(seed_source).hexdigest()
      second_seed = int(digest[:8], 16) & 0x7FFFFFFF

    queue_ref = gen_data_flow_ops._random_shuffle_queue_v2(
        component_types=dtypes,
        shapes=shapes,
        capacity=capacity,
        min_after_dequeue=min_after_dequeue,
        seed=first_seed,
        seed2=second_seed,
        shared_name=shared_name,
        name=name)

    super(RandomShuffleQueue, self).__init__(dtypes, shapes, names, queue_ref)
class FIFOQueue(QueueBase):
  """A queue whose elements are dequeued in the order they were enqueued.

  The methods available on this class are documented on @{tf.QueueBase}.
  """

  def __init__(self, capacity, dtypes, shapes=None, names=None,
               shared_name=None, name="fifo_queue"):
    """Builds a bounded first-in first-out queue.

    The queue holds at most `capacity` elements, accepts any number of
    concurrent producers and consumers, and delivers each element exactly
    once. Every element is a fixed-length tuple of tensors; the dtypes of
    the tuple components come from `dtypes`, and their shapes may be pinned
    with `shapes`.

    With `shapes` given, each enqueued component must have exactly that
    fixed shape. Without it, elements may differ in shape, but
    `dequeue_many` cannot be used.

    Args:
      capacity: An integer upper bound on how many elements the queue can
        hold.
      dtypes: A list of `DType` objects, one per tensor in a queue element.
      shapes: (Optional.) `None`, or a list of fully-defined `TensorShape`
        objects with one entry per dtype.
      names: (Optional.) `None`, or a list of strings with one entry per
        dtype. When given, dequeue methods return dictionaries keyed by
        these names.
      shared_name: (Optional.) If non-empty, the queue is shared under this
        name across multiple sessions.
      name: Optional name for the queue operation.
    """
    component_types = _as_type_list(dtypes)
    component_shapes = _as_shape_list(shapes, component_types)
    component_names = _as_name_list(names, component_types)

    queue_ref = gen_data_flow_ops._fifo_queue_v2(
        component_types=component_types,
        shapes=component_shapes,
        capacity=capacity,
        shared_name=shared_name,
        name=name)

    super(FIFOQueue, self).__init__(
        component_types, component_shapes, component_names, queue_ref)
class PaddingFIFOQueue(QueueBase):
  """A first-in first-out queue that pads variable-sized tensors when batching.

  Unlike a plain FIFO queue, components may have partially-unknown shapes
  and `dequeue_many` still works; see the constructor for details.

  The methods available on this class are documented on @{tf.QueueBase}.
  """

  def __init__(self, capacity, dtypes, shapes, names=None, shared_name=None,
               name="padding_fifo_queue"):
    """Builds a bounded first-in first-out queue that supports padding.

    The queue holds at most `capacity` elements, accepts any number of
    concurrent producers and consumers, and delivers each element exactly
    once. Every element is a fixed-length tuple of tensors; the dtypes of
    the tuple components come from `dtypes` and their shapes from `shapes`.

    `shapes` is required. Each component must have a known rank, but any
    dimension may be left as None. Enqueued values may vary in size along
    such a dimension, and `dequeue_many` pads that dimension with zeros up
    to the largest size found in the batch it returns.

    Args:
      capacity: An integer upper bound on how many elements the queue can
        hold.
      dtypes: A list of `DType` objects, one per tensor in a queue element.
      shapes: A list of `TensorShape` objects with one entry per dtype. A
        `None` dimension is dynamic and accepts values of varying size along
        that dimension.
      names: (Optional.) `None`, or a list of strings with one entry per
        dtype. When given, dequeue methods return dictionaries keyed by
        these names.
      shared_name: (Optional.) If non-empty, the queue is shared under this
        name across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If `shapes` is not a list of shapes, if the numbers of
        dtypes and shapes differ, or if `names` is given and the numbers of
        dtypes and names differ.
    """
    component_types = _as_type_list(dtypes)
    component_shapes = _as_shape_list(
        shapes, component_types, unknown_dim_allowed=True)
    component_names = _as_name_list(names, component_types)

    num_types, num_shapes = len(component_types), len(component_shapes)
    if num_types != num_shapes:
      raise ValueError("Shapes must be provided for all components, "
                       "but received %d dtypes and %d shapes."
                       % (num_types, num_shapes))

    queue_ref = gen_data_flow_ops._padding_fifo_queue_v2(
        component_types=component_types,
        shapes=component_shapes,
        capacity=capacity,
        shared_name=shared_name,
        name=name)

    super(PaddingFIFOQueue, self).__init__(
        component_types, component_shapes, component_names, queue_ref)
class PriorityQueue(QueueBase):
  """A queue implementation that dequeues elements in prioritized order.

  See @{tf.QueueBase} for a description of the methods on
  this class.
  """

  def __init__(self, capacity, types, shapes=None, names=None, shared_name=None,
               name="priority_queue"):
    """Creates a queue that dequeues elements ordered by their priority.

    A `PriorityQueue` has bounded capacity; supports multiple concurrent
    producers and consumers; and provides exactly-once delivery.

    A `PriorityQueue` holds a list of up to `capacity` elements. Each
    element is a fixed-length tuple of tensors whose dtypes are
    described by `types`, and whose shapes are optionally described
    by the `shapes` argument.

    If the `shapes` argument is specified, each component of a queue
    element must have the respective fixed shape. If it is
    unspecified, different queue elements may have different shapes,
    but the use of `dequeue_many` is disallowed.

    Enqueues and Dequeues to the `PriorityQueue` must include an additional
    tuple entry at the beginning: the `priority`.  The priority must be
    an int64 scalar (for `enqueue`) or an int64 vector (for `enqueue_many`).

    Args:
      capacity: An integer. The upper bound on the number of elements
        that may be stored in this queue.
      types:  A list of `DType` objects. The length of `types` must equal
        the number of tensors in each queue element, except the first priority
        element. The first tensor in each element is the priority,
        which must be type int64.
      shapes: (Optional.) A list of fully-defined `TensorShape` objects,
        with the same length as `types`, or `None`.
      names: (Optional.) A single name or a list of strings naming the
        components in the queue, or `None`. Because the priority is stored as
        an extra first component, the list must have `len(types) + 1`
        entries, the first of which names the priority. If specified, the
        dequeue methods return a dictionary with the names as keys.
      shared_name: (Optional.) If non-empty, this queue will be shared under
        the given name across multiple sessions.
      name: Optional name for the queue operation.

    Raises:
      ValueError: If `names` is specified and does not have `len(types) + 1`
        entries, or if `shapes` is invalid.
    """
    types = _as_type_list(types)
    shapes = _as_shape_list(shapes, types)

    # The queue op stores the int64 priority as an implicit first component.
    priority_dtypes = [_dtypes.int64] + types
    priority_shapes = ([()] + shapes) if shapes else shapes
    # Normalize and validate names like the other queue classes do, and do it
    # before the queue op is created so a bad argument leaves no stray op.
    names = _as_name_list(names, priority_dtypes)

    queue_ref = gen_data_flow_ops._priority_queue_v2(
        component_types=types, shapes=shapes, capacity=capacity,
        shared_name=shared_name, name=name)

    super(PriorityQueue, self).__init__(
        priority_dtypes, priority_shapes, names, queue_ref)
# TODO(josh11b): class BatchQueue(QueueBase):
class Barrier(object):
"""Represents a key-value map that persists across graph executions."""
def __init__(self, types, shapes=None, shared_name=None, name="barrier"):
  """Creates a barrier that persists across different graph executions.

  A barrier is a key-value map whose keys are strings and whose values are
  tuples of tensors.

  At runtime the barrier holds 'complete' and 'incomplete' elements. A
  complete element has a defined tensor for every component of its value
  tuple and can be read with take_many. An incomplete element still has
  undefined components and can be filled in with insert_many.

  `take_many` emits values in a specific order. It only emits complete
  values, and it emits them in the order in which their very first
  component was inserted into the barrier. For example, given this
  sequence of insertions and removals:

    barrier = Barrier((tf.string, tf.int32), shapes=((), ()))
    barrier.insert_many(0, keys=["k1", "k2"], values=["a", "b"]).run()
    barrier.insert_many(1, keys=["k1"], values=[1]).run()
    barrier.insert_many(0, keys=["k3"], values=["c"]).run()
    barrier.insert_many(1, keys=["k3"], values=[3]).run()
    barrier.insert_many(1, keys=["k2"], values=[2]).run()

    (indices, keys, values) = barrier.take_many(2)
    (indices_val, keys_val, values0_val, values1_val) =
       session.run([indices, keys, values[0], values[1]])

  The output will be (up to permutation of "k1" and "k2"):

    indices_val == (-2**63, -2**63)
    keys_val == ("k1", "k2")
    values0_val == ("a", "b")
    values1_val == (1, 2)

  "k2" was inserted into the barrier before "k3". "k3" was completed first,
  but both are complete by the time take_many is called, so "k2" takes
  priority and "k1" and "k2" are returned first. "k3" stays in the barrier
  until the next `take_many`. "k1" and "k2" were first inserted together,
  so they share the same index (-2**63). "k3" gets index -2**63 + 1 because
  it was the next newly inserted key.

  Args:
    types: A single dtype or a tuple of dtypes, giving the dtypes of the
      tensor elements that make up a value in this barrier.
    shapes: Optional. Constraints on the shapes of tensors in the values.
      Either a single tensor shape tuple, a tuple of tensor shape tuples
      (one per barrier-element tuple component), or None if shapes should
      not be constrained.
    shared_name: Optional. If non-empty, this barrier will be shared under
      the given name across multiple sessions.
    name: Optional name for the barrier op.

  Raises:
    ValueError: If one of the `shapes` indicates no elements.
  """
  self._types = _as_type_list(types)

  if shapes is None:
    # Unconstrained: one fully unknown shape per component.
    self._shapes = [tensor_shape.unknown_shape() for _ in self._types]
  else:
    normalized_shapes = _as_shape_list(shapes, self._types)
    self._shapes = [tensor_shape.TensorShape(s) for s in normalized_shapes]
    # Reject zero-element components; unknown shapes report None here and
    # therefore pass.
    for position, component_shape in enumerate(self._shapes):
      if component_shape.num_elements() == 0:
        raise ValueError("Empty tensors are not supported, but received "
                         "shape '%s' at index %d" % (component_shape, position))

  self._barrier_ref = gen_data_flow_ops._barrier(
      component_types=self._types,
      shapes=self._shapes,
      shared_name=shared_name,
      name=name)
  # Final path component of the op name, used to build default op names.
  self._name = self._barrier_ref.op.name.rsplit("/", 1)[-1]
@property
def barrier_ref(self):
  """The handle tensor of the underlying barrier op."""
  handle = self._barrier_ref
  return handle
@property
def name(self):
  """Full op name of the underlying barrier (including any name scope)."""
  barrier_op = self._barrier_ref.op
  return barrier_op.name
def insert_many(self, component_index, keys, values, name=None):
  """For each key, assigns the respective value to the specified component.

  This operation updates each element at component_index.

  Args:
    component_index: The component of the value that is being assigned.
    keys: A vector of keys, with length n.
    values: An any-dimensional tensor of values, which are associated with the
      respective keys. The first dimension must have length n.
    name: Optional name for the op. Defaults to
      "<barrier name>_BarrierInsertMany".

  Returns:
    The operation that performs the insertion.

  Raises:
    errors.InvalidArgumentError: When the op is run, if inserting keys and
      values without elements.
  """
  if name is None:
    # Prefix with the barrier's short name so ops are easy to attribute.
    name = "%s_BarrierInsertMany" % self._name
  return gen_data_flow_ops._barrier_insert_many(
      self._barrier_ref, keys, values, component_index, name=name)
def take_many(self,
              num_elements,
              allow_small_batch=False,
              timeout=None,
              name=None):
  """Takes the given number of completed elements from this barrier.

  This operation concatenates completed-element component tensors along
  the 0th dimension to make a single component tensor.

  If the barrier has no completed elements, this operation will block
  until there are `num_elements` elements to take.

  TODO(b/25743580): the semantics of `allow_small_batch` are experimental
  and may be extended to other cases in the future.

  TODO(ebrevdo): If a take_many(allow_small_batch=True) is already blocking
  when the barrier is closed, it will block forever. Fix this
  by using asynchronous operations.

  Args:
    num_elements: The number of elements to take.
    allow_small_batch: If the barrier is closed, don't block if there are
      fewer completed elements than requested; instead return all available
      completed elements.
    timeout: This specifies the number of milliseconds to block
      before returning with DEADLINE_EXCEEDED. (This option is not
      supported yet.) Forwarded to the op's `timeout_ms` attr.
    name: A name for the operation (optional). Defaults to
      "<barrier name>_BarrierTakeMany".

  Returns:
    A tuple of (index, key, value_list).
    "index" is an int64 tensor of length num_elements containing the
      index of the insert_many call for which the very first component of
      the given element was inserted into the Barrier, starting with
      the value -2**63. Note, this value is different from the
      index of the insert_many call for which the element was completed.
    "key" is a string tensor of length num_elements containing the keys.
    "value_list" is a tuple of tensors, each one with size num_elements
      in the 0th dimension for each component in the barrier's values.
  """
  if name is None:
    name = "%s_BarrierTakeMany" % self._name
  # Pass the attrs by keyword so `timeout` always reaches `timeout_ms`.
  # Passing them positionally depends on the generated wrapper's attr order,
  # in which `wait_for_incomplete` precedes `timeout_ms`.
  ret = gen_data_flow_ops._barrier_take_many(
      self._barrier_ref,
      num_elements,
      self._types,
      allow_small_batch=allow_small_batch,
      timeout_ms=timeout,
      name=name)

  # NOTE(mrry): Not using a shape function because we need access to
  # the Barrier object.
  op = ret[0].op
  if allow_small_batch:
    # A closed barrier may hand back fewer than `num_elements` rows.
    batch_dim = None
  else:
    # op.inputs[1] is `num_elements`; constant_value() returns None when it
    # is not statically known, which leaves the dimension unknown.
    batch_dim = tensor_shape.Dimension(
        tensor_util.constant_value(op.inputs[1]))
  op.outputs[0].set_shape(tensor_shape.vector(batch_dim))  # indices
  op.outputs[1].set_shape(tensor_shape.vector(batch_dim))  # keys
  for output, shape in zip(op.outputs[2:], self._shapes):  # value_list
    output.set_shape(tensor_shape.TensorShape([batch_dim]).concatenate(shape))

  return ret
def close(self, cancel_pending_enqueues=False, name=None):
"""Closes this barrier.
This operation signals that no more new key values will be inserted in the
given barrier. Subsequent InsertMany operations with new keys will fail.
InsertMany operations that just complement already existing keys with other
components, will continue to succeed. Subsequent TakeMany operations will
continue to succeed if sufficient elements remain in the barrier. Subsequent
TakeMany operations that would block will fail immediately.
If `cancel_pending_enqueues` is `True`, all pending requests to the
underlying queue will also be cancelled, and completing of already
started values is also not acceptable anymore.
Args:
cancel_pending_enqueues: (Optional.) A boolean, defaulting to
`False` (described above).
name: Optional name for the op.
Returns:
The operation that closes the barrier.