/
test_util.py
2027 lines (1715 loc) · 71.3 KB
/
test_util.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=invalid-name
"""Test utils for tensorflow."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
from collections import OrderedDict
import contextlib
import functools
import gc
import itertools
import math
import random
import re
import tempfile
import threading
import unittest

import numpy as np
import six

_portpicker_import_error = None
try:
    import portpicker  # pylint: disable=g-import-not-at-top
except ImportError as _error:
    _portpicker_import_error = _error
    portpicker = None

# pylint: disable=g-import-not-at-top
from google.protobuf import descriptor_pool
from google.protobuf import text_format
from tensorflow.core.framework import graph_pb2
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import rewriter_config_pb2
from tensorflow.python import pywrap_tensorflow
from tensorflow.python.client import device_lib
from tensorflow.python.client import session
from tensorflow.python.eager import context
from tensorflow.python.eager import tape  # pylint: disable=unused-import
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import errors_impl
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.framework import random_seed
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import versions
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import variables
from tensorflow.python.platform import googletest
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training import server_lib
from tensorflow.python.util import compat
from tensorflow.python.util import nest
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.protobuf import compare
from tensorflow.python.util.tf_export import tf_export
@tf_export("test.gpu_device_name")
def gpu_device_name():
    """Returns the name of the first GPU or SYCL device, or "" if there is none.

    Devices are examined in the order `device_lib.list_local_devices()` returns
    them; the first one whose `device_type` is "GPU" or "SYCL" is reported.
    """
    accelerator_types = ("GPU", "SYCL")
    accelerators = (dev for dev in device_lib.list_local_devices()
                    if dev.device_type in accelerator_types)
    first = next(accelerators, None)
    return "" if first is None else compat.as_str(first.name)
def assert_ops_in_graph(expected_ops, graph):
    """Asserts that every expected op exists in `graph` with the expected type.

    Args:
      expected_ops: `dict<string, string>` mapping op name to op type.
      graph: Graph to check; only its `as_graph_def()` output is inspected.

    Returns:
      `dict<string, node>` mapping each matched node name to its node.

    Raises:
      ValueError: If a node named in `expected_ops` has a different op type, or
        if any expected op name is missing from the graph.
    """
    found = {}
    for node_def in graph.as_graph_def().node:
        if node_def.name not in expected_ops:
            continue
        wanted_type = expected_ops[node_def.name]
        if wanted_type != node_def.op:
            raise ValueError("Expected op for node %s is different. %s vs %s" %
                             (node_def.name, wanted_type, node_def.op))
        found[node_def.name] = node_def
    if set(expected_ops.keys()) != set(found.keys()):
        raise ValueError("Not all expected ops are present. Expected %s, found %s" %
                         (expected_ops.keys(), found.keys()))
    return found
@tf_export("test.assert_equal_graph_def")
def assert_equal_graph_def(actual, expected, checkpoint_v2=False):
    """Asserts that two `GraphDef`s are (mostly) the same.

    Compares two `GraphDef` protos for equality, ignoring versions and ordering
    of nodes, attrs, and control inputs. Node names are used to match up nodes
    between the graphs, so the naming of nodes must be consistent.

    Args:
      actual: The `GraphDef` we have.
      expected: The `GraphDef` we expected.
      checkpoint_v2: boolean determining whether to ignore randomized attribute
        values that appear in V2 checkpoints. When True, those attributes are
        removed in place from BOTH `actual` and `expected` before comparing.

    Raises:
      AssertionError: If the `GraphDef`s do not match.
      TypeError: If either argument is not a `GraphDef`.
    """
    if not isinstance(actual, graph_pb2.GraphDef):
        raise TypeError(
            "Expected tf.GraphDef for actual, got %s" % type(actual).__name__)
    if not isinstance(expected, graph_pb2.GraphDef):
        raise TypeError(
            "Expected tf.GraphDef for expected, got %s" % type(expected).__name__)
    if checkpoint_v2:
        for graph_def in (actual, expected):
            _strip_checkpoint_v2_randomized(graph_def)
    # The comparison itself is done natively on the serialized protos; an empty
    # result means "equal", anything else is a human-readable diff.
    actual_bytes = actual.SerializeToString()
    expected_bytes = expected.SerializeToString()
    diff = pywrap_tensorflow.EqualGraphDefWrapper(actual_bytes, expected_bytes)
    if diff:
        raise AssertionError(compat.as_str(diff))
def assert_meta_graph_protos_equal(tester, a, b):
    """Compares MetaGraphDefs `a` and `b` in unit test class `tester`.

    Collection defs are compared key by key. Collections with a registered proto
    type are parsed and compared structurally; others are compared directly.
    The graph defs are compared with `assert_equal_graph_def(...,
    checkpoint_v2=True)` plus an explicit check of their `versions`. All
    remaining fields are compared with `tester.assertProtoEquals`.

    Note: this MUTATES both `a` and `b`. Their `collection_def` and `graph_def`
    fields are cleared as they are checked, and randomized checkpoint-V2
    attributes are stripped from the graph defs.

    Args:
      tester: A test case providing `assertEqual` and `assertProtoEquals`.
      a: The first `MetaGraphDef`.
      b: The second `MetaGraphDef`.
    """
    # Carefully check the collection_defs.
    tester.assertEqual(set(a.collection_def), set(b.collection_def))
    for key in a.collection_def.keys():
        a_value = a.collection_def[key]
        b_value = b.collection_def[key]
        proto_type = ops.get_collection_proto_type(key)
        if proto_type:
            a_proto = proto_type()
            b_proto = proto_type()
            # Number of entries in the collections is the same.
            tester.assertEqual(
                len(a_value.bytes_list.value), len(b_value.bytes_list.value))
            for a_item, b_item in zip(a_value.bytes_list.value,
                                      b_value.bytes_list.value):
                a_proto.ParseFromString(a_item)
                b_proto.ParseFromString(b_item)
                tester.assertProtoEquals(a_proto, b_proto)
        else:
            # assertEqual, not the deprecated `assertEquals` alias (removed in
            # Python 3.12).
            tester.assertEqual(a_value, b_value)
    # Compared the fields directly, remove their raw values from the
    # proto comparison below.
    a.ClearField("collection_def")
    b.ClearField("collection_def")

    # Check the graph_defs.
    assert_equal_graph_def(a.graph_def, b.graph_def, checkpoint_v2=True)
    # Check graph_def versions (ignored by assert_equal_graph_def).
    tester.assertProtoEquals(a.graph_def.versions, b.graph_def.versions)
    # Compared the fields directly, remove their raw values from the
    # proto comparison below.
    a.ClearField("graph_def")
    b.ClearField("graph_def")

    tester.assertProtoEquals(a, b)
# Matches attributes named via _SHARDED_SUFFIX in
# tensorflow/python/training/saver.py
_SHARDED_SAVE_OP_PATTERN = "_temp_[0-9a-z]{32}/part"
def _strip_checkpoint_v2_randomized(graph_def):
for node in graph_def.node:
delete_keys = []
for attr_key in node.attr:
attr_tensor_value = node.attr[attr_key].tensor
if attr_tensor_value and len(attr_tensor_value.string_val) == 1:
attr_tensor_string_value = attr_tensor_value.string_val[0]
if (attr_tensor_string_value and
re.match(_SHARDED_SAVE_OP_PATTERN, str(attr_tensor_string_value))):
delete_keys.append(attr_key)
for attr_key in delete_keys:
del node.attr[attr_key]
def IsGoogleCudaEnabled():
    """Returns the result of `pywrap_tensorflow.IsGoogleCudaEnabled()`.

    Presumably whether this TensorFlow build has CUDA support (inferred from
    the name -- confirm against the native implementation).
    """
    cuda_enabled = pywrap_tensorflow.IsGoogleCudaEnabled()
    return cuda_enabled
def CudaSupportsHalfMatMulAndConv():
    """Returns `pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()` unchanged.

    Presumably reports whether half-precision matmul/conv kernels are
    supported on CUDA (inferred from the name -- TODO confirm).
    """
    half_supported = pywrap_tensorflow.CudaSupportsHalfMatMulAndConv()
    return half_supported
def IsMklEnabled():
    """Returns the result of `pywrap_tensorflow.IsMklEnabled()`.

    Presumably whether this build uses MKL (inferred from the name -- confirm).
    """
    mkl_enabled = pywrap_tensorflow.IsMklEnabled()
    return mkl_enabled
def InstallStackTraceHandler():
    """Calls `pywrap_tensorflow.InstallStacktraceHandler()`; returns None."""
    install = pywrap_tensorflow.InstallStacktraceHandler
    install()
def NHWCToNCHW(input_tensor):
    """Converts the input from the NHWC format to NCHW.

    Args:
      input_tensor: a 4- or 5-D tensor, or an array representing shape

    Returns:
      converted tensor or shape array
    """
    # Rank -> axis permutation that moves the last (channel) axis to position 1.
    permutation_by_rank = {4: [0, 3, 1, 2], 5: [0, 4, 1, 2, 3]}
    if not isinstance(input_tensor, ops.Tensor):
        order = permutation_by_rank[len(input_tensor)]
        return [input_tensor[axis] for axis in order]
    order = permutation_by_rank[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, order)
def NHWCToNCHW_VECT_C(input_shape_or_tensor):
    """Transforms the input from the NHWC layout to NCHW_VECT_C layout.

    Note: Does not include quantization or type conversion steps, which should
    be applied afterwards.

    Args:
      input_shape_or_tensor: a 4- or 5-D tensor, or an array representing shape.
        A shape array is not modified.

    Returns:
      tensor or shape array transformed into NCHW_VECT_C

    Raises:
      ValueError: if last dimension of `input_shape_or_tensor` is not evenly
        divisible by 4.
    """
    permutations = {5: [0, 3, 1, 2, 4], 6: [0, 4, 1, 2, 3, 5]}
    is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
    # Work on a copy: the steps below edit the shape in place. Without the copy,
    # a shape-array argument would mutate the caller's list (and a tuple would
    # fail outright).
    temp_shape = (
        input_shape_or_tensor.shape.as_list()
        if is_tensor else list(input_shape_or_tensor))
    if temp_shape[-1] % 4 != 0:
        raise ValueError(
            "Last dimension of input must be evenly divisible by 4 to convert to "
            "NCHW_VECT_C.")
    # Split the channel dimension C into (C // 4, 4).
    temp_shape[-1] //= 4
    temp_shape.append(4)
    permutation = permutations[len(temp_shape)]
    if is_tensor:
        t = array_ops.reshape(input_shape_or_tensor, temp_shape)
        return array_ops.transpose(t, permutation)
    return [temp_shape[a] for a in permutation]
def NCHW_VECT_CToNHWC(input_shape_or_tensor):
    """Transforms the input from the NCHW_VECT_C layout to NHWC layout.

    Note: Does not include de-quantization or type conversion steps, which
    should be applied beforehand.

    Args:
      input_shape_or_tensor: a 5- or 6-D tensor, or an array representing shape

    Returns:
      tensor or shape array transformed into NHWC

    Raises:
      ValueError: if last dimension of `input_shape_or_tensor` is not 4.
    """
    permutations = {5: [0, 2, 3, 1, 4], 6: [0, 2, 3, 4, 1, 5]}
    is_tensor = isinstance(input_shape_or_tensor, ops.Tensor)
    if is_tensor:
        input_shape = input_shape_or_tensor.shape.as_list()
    else:
        input_shape = input_shape_or_tensor
    vect_size = input_shape[-1]
    if vect_size != 4:
        raise ValueError("Last dimension of NCHW_VECT_C must be 4.")
    permutation = permutations[len(input_shape)]
    # Reorder every axis except the trailing vector axis ...
    nhwc_shape = [input_shape[axis] for axis in permutation[:-1]]
    # ... then fold the vector axis back into the (now last) channel axis.
    nhwc_shape[-1] = nhwc_shape[-1] * vect_size
    if not is_tensor:
        return nhwc_shape
    transposed = array_ops.transpose(input_shape_or_tensor, permutation)
    return array_ops.reshape(transposed, nhwc_shape)
def NCHWToNHWC(input_tensor):
    """Converts the input from the NCHW format to NHWC.

    Args:
      input_tensor: a 4- or 5-D tensor, or an array representing shape

    Returns:
      converted tensor or shape array
    """
    # Rank -> axis permutation that moves axis 1 (channels) to the last position.
    permutation_by_rank = {4: [0, 2, 3, 1], 5: [0, 2, 3, 4, 1]}
    if not isinstance(input_tensor, ops.Tensor):
        order = permutation_by_rank[len(input_tensor)]
        return [input_tensor[axis] for axis in order]
    order = permutation_by_rank[input_tensor.shape.ndims]
    return array_ops.transpose(input_tensor, order)
def skip_if(condition):
    """Skips the decorated function if condition is or evaluates to True.

    The condition is checked on every call, not at decoration time. A skipped
    call does nothing and returns None; it does not raise `unittest.SkipTest`.

    Args:
      condition: Either an expression that can be used in "if not condition"
        statement, or a callable whose result should be a boolean.

    Returns:
      A decorator that wraps a function, preserving its name and docstring.
    """

    def real_skip_if(fn):
        @functools.wraps(fn)
        def wrapper(*args, **kwargs):
            skip = condition() if callable(condition) else condition
            if not skip:
                # Propagate the result so non-test helpers can be wrapped too.
                return fn(*args, **kwargs)
            return None

        return wrapper

    return real_skip_if
def enable_c_shapes(fn):
    """Decorator for enabling C shapes on a test.

    Note this enables the C shapes after running the test class's setup/teardown
    methods.

    Args:
      fn: the function to be wrapped

    Returns:
      The wrapped function
    """

    def wrapper(*args, **kwargs):
        # pylint: disable=protected-access
        saved_setting = ops._USE_C_SHAPES
        ops._USE_C_SHAPES = True
        try:
            fn(*args, **kwargs)
        finally:
            # Restore the module-level flag even if the test fails.
            ops._USE_C_SHAPES = saved_setting
        # pylint: enable=protected-access

    return wrapper
def with_c_shapes(cls):
    """Adds methods that call original methods but with C API shapes enabled.

    Note this enables C shapes in new methods after running the test class's
    setup method.

    For every callable attribute defined directly on `cls` whose name starts
    with "test", a copy named `<name>WithCShapes` is added.

    Args:
      cls: class to decorate

    Returns:
      cls with new test methods added
    """
    # If C shapes are already enabled, don't do anything. Some tests break if the
    # same test is run twice, so this allows us to turn on the C shapes by default
    # without breaking these tests.
    if ops._USE_C_SHAPES:  # pylint: disable=protected-access
        return cls
    # Snapshot first: setattr below adds entries to cls.__dict__.
    test_methods = [(attr_name, attr_value)
                    for attr_name, attr_value in cls.__dict__.copy().items()
                    if callable(attr_value) and attr_name.startswith("test")]
    for attr_name, attr_value in test_methods:
        setattr(cls, attr_name + "WithCShapes", enable_c_shapes(attr_value))
    return cls
def enable_cond_v2(fn):
    """Decorator for enabling CondV2 on a test.

    Note this enables using CondV2 after running the test class's setup/teardown
    methods.

    Args:
      fn: the function to be wrapped

    Returns:
      The wrapped function
    """

    def wrapper(*args, **kwargs):
        # pylint: disable=protected-access
        saved_setting = control_flow_ops._ENABLE_COND_V2
        control_flow_ops._ENABLE_COND_V2 = True
        try:
            fn(*args, **kwargs)
        finally:
            # Restore the module-level flag even if the test fails.
            control_flow_ops._ENABLE_COND_V2 = saved_setting
        # pylint: enable=protected-access

    return wrapper
def with_cond_v2(cls):
    """Adds methods that call original methods but with CondV2 enabled.

    Note this enables CondV2 in new methods after running the test class's
    setup method.

    For every callable attribute defined directly on `cls` whose name starts
    with "test", a copy named `<name>WithCondV2` is added. Nothing is added if
    CondV2 is already enabled globally.

    Args:
      cls: class to decorate

    Returns:
      cls with new test methods added
    """
    if control_flow_ops._ENABLE_COND_V2:  # pylint: disable=protected-access
        return cls
    # Snapshot first: setattr below adds entries to cls.__dict__.
    test_methods = [(attr_name, attr_value)
                    for attr_name, attr_value in cls.__dict__.copy().items()
                    if callable(attr_value) and attr_name.startswith("test")]
    for attr_name, attr_value in test_methods:
        setattr(cls, attr_name + "WithCondV2", enable_cond_v2(attr_value))
    return cls
def assert_no_new_pyobjects_executing_eagerly(f):
    """Decorator for asserting that no new Python objects persist after a test.

    Runs the test multiple times executing eagerly, first as a warmup and then
    several times to let objects accumulate. The warmup helps ignore caches which
    do not grow as the test is run repeatedly.

    Useful for checking that there are no missing Py_DECREFs in the C exercised by
    a bit of Python.

    Garbage collection is disabled while the test runs. It is restored to its
    prior state afterwards, even when the test or the leak check fails.

    Args:
      f: The test method to decorate; called as `f(self, **kwargs)`.

    Returns:
      The decorated test method.
    """

    def decorator(self, **kwargs):
        """Warms up, gets an object count, runs the test, checks for new objects."""
        with context.eager_mode():
            gc_was_enabled = gc.isenabled()
            gc.disable()
            try:
                # Warmup run, so caches that don't grow are already populated.
                f(self, **kwargs)
                gc.collect()
                previous_count = len(gc.get_objects())
                if ops.has_default_graph():
                    collection_sizes_before = {
                        collection: len(ops.get_collection(collection))
                        for collection in ops.get_default_graph().collections
                    }
                for _ in range(3):
                    f(self, **kwargs)
                # Note that gc.get_objects misses anything that isn't subject to
                # garbage collection (C types). Collections are a common source of
                # leaks, so we test for collection sizes explicitly.
                if ops.has_default_graph():
                    for collection_key in ops.get_default_graph().collections:
                        collection = ops.get_collection(collection_key)
                        size_before = collection_sizes_before.get(collection_key, 0)
                        if len(collection) > size_before:
                            raise AssertionError(
                                ("Collection %s increased in size from "
                                 "%d to %d (current items %s).") %
                                (collection_key, size_before, len(collection),
                                 collection))
                        # Make sure our collection checks don't show up as leaked
                        # memory by removing references to temporary variables.
                        del collection
                        del collection_key
                        del size_before
                    del collection_sizes_before
                gc.collect()
                # There should be no new Python objects hanging around.
                new_count = len(gc.get_objects())
                # In some cases (specifically on MacOS), new_count is somehow
                # smaller than previous_count.
                # Using plain assert because not all classes using this decorator
                # have assertLessEqual
                assert new_count <= previous_count, (
                    "new_count(%d) is not less than or equal to previous_count(%d)" %
                    (new_count, previous_count))
            finally:
                # Restore the collector even when the test or a check fails;
                # otherwise gc would stay off for every later test in the process.
                if gc_was_enabled:
                    gc.enable()

    return decorator
def assert_no_new_tensors(f):
    """Decorator for asserting that no new Tensors persist after a test.

    Mainly useful for checking that code using the Python C API has correctly
    manipulated reference counts.

    Clears the caches that it knows about, runs the garbage collector, then checks
    that there are no Tensor or Tensor-like objects still around. This includes
    Tensors to which something still has a reference (e.g. from missing
    Py_DECREFs) and uncollectable cycles (i.e. Python reference cycles where one
    of the objects has __del__ defined).

    Args:
      f: The test case to run.

    Returns:
      The decorated test case.
    """

    def decorator(self, **kwargs):
        """Finds existing Tensors, runs the test, checks for new Tensors."""
        tracked_types = (ops.Tensor, variables.Variable,
                         tensor_shape.Dimension, tensor_shape.TensorShape)

        def _is_tensorflow_object(obj):
            try:
                return isinstance(obj, tracked_types)
            except ReferenceError:
                # If the object no longer exists, we don't care about it.
                return False

        ids_before = {id(obj) for obj in gc.get_objects()
                      if _is_tensorflow_object(obj)}
        if not context.executing_eagerly():
            # Run the test in a new graph so that collections get cleared when
            # it's done, but inherit the graph key so optimizers behave.
            outer_graph_key = ops.get_default_graph()._graph_key  # pylint: disable=protected-access
            with ops.Graph().as_default():
                ops.get_default_graph()._graph_key = outer_graph_key  # pylint: disable=protected-access
                f(self, **kwargs)
        else:
            f(self, **kwargs)
            ops.reset_default_graph()
        # Make an effort to clear caches, which would otherwise look like leaked
        # Tensors.
        context.context()._clear_caches()  # pylint: disable=protected-access
        gc.collect()
        leaked = [obj for obj in gc.get_objects()
                  if _is_tensorflow_object(obj) and id(obj) not in ids_before]
        if leaked:
            raise AssertionError(("%d Tensors not deallocated after test: %s" % (
                len(leaked),
                str(leaked),
            )))

    return decorator
def assert_no_garbage_created(f):
    """Test method decorator to assert that no garbage has been created.

    Note that this decorator sets DEBUG_SAVEALL, which in some Python interpreters
    cannot be un-set (i.e. will disable garbage collection for any other unit
    tests in the same file/shard).

    The collector's enabled state and debug flags are restored after the test,
    even when the test itself or the garbage check fails.

    Args:
      f: The function to decorate; called as `f(self, **kwargs)`.

    Returns:
      The decorated function.
    """

    def _safe_object_str(obj):
        """Describes `obj` by class name and id, without calling its __str__."""
        return "<%s %d>" % (obj.__class__.__name__, id(obj))

    def _log_new_garbage(first_new_index):
        """Logs each object appended to gc.garbage at or after the index."""
        logging.error(
            "The decorated test created work for Python's garbage collector, "
            "likely due to a reference cycle. New objects in cycle(s):")
        new_objects = gc.garbage[first_new_index:]
        for i, obj in enumerate(new_objects):
            try:
                logging.error("Object %d of %d", i, len(new_objects))
                logging.error("  Object type: %s", _safe_object_str(obj))
                logging.error(
                    "  Referrer types: %s", ", ".join(
                        [_safe_object_str(ref) for ref in gc.get_referrers(obj)]))
                logging.error(
                    "  Referent types: %s", ", ".join(
                        [_safe_object_str(ref) for ref in gc.get_referents(obj)]))
                logging.error("  Object attribute names: %s", dir(obj))
                logging.error("  Object __str__:")
                logging.error(obj)
                logging.error("  Object __repr__:")
                logging.error(repr(obj))
            except Exception:  # pylint: disable=broad-except
                # Best-effort diagnostics: a broken __str__/__repr__ must not
                # hide the real assertion below.
                logging.error("(Exception while printing object)")

    def decorator(self, **kwargs):
        """Sets DEBUG_SAVEALL, runs the test, and checks for new garbage."""
        gc_was_enabled = gc.isenabled()
        gc.disable()
        previous_debug_flags = gc.get_debug()
        gc.set_debug(gc.DEBUG_SAVEALL)
        try:
            gc.collect()
            previous_garbage = len(gc.garbage)
            f(self, **kwargs)
            gc.collect()
            if len(gc.garbage) > previous_garbage:
                _log_new_garbage(previous_garbage)
            # This will fail if any garbage has been created, typically because
            # of a reference cycle.
            self.assertEqual(previous_garbage, len(gc.garbage))
        finally:
            # Restore even on failure; otherwise gc would stay disabled (with
            # DEBUG_SAVEALL) for every later test in the process.
            # TODO(allenl): Figure out why this debug flag reset doesn't work. It
            # would be nice to be able to decorate arbitrary tests in a large test
            # suite and not hold on to every object in other tests.
            gc.set_debug(previous_debug_flags)
            if gc_was_enabled:
                gc.enable()

    return decorator
def _combine_named_parameters(**kwargs):
"""Generate combinations based on its keyword arguments.
Two sets of returned combinations can be concatenated using +. Their product
can be computed using `times()`.
Args:
**kwargs: keyword arguments of form `option=[possibilities, ...]`
or `option=the_only_possibility`.
Returns:
a list of dictionaries for each combination. Keys in the dictionaries are
the keyword argument names. Each key has one value - one of the
corresponding keyword argument values.
"""
if not kwargs:
return [OrderedDict()]
sort_by_key = lambda k: k[0][0]
kwargs = OrderedDict(sorted(kwargs.items(), key=sort_by_key))
first = list(kwargs.items())[0]
rest = dict(list(kwargs.items())[1:])
rest_combined = _combine_named_parameters(**rest)
key = first[0]
values = first[1]
if not isinstance(values, list):
values = [values]
combinations = [
OrderedDict(sorted(list(combined.items()) + [(key, v)], key=sort_by_key))
for v in values
for combined in rest_combined
]
return combinations
def generate_combinations_with_testcase_name(**kwargs):
    """Generate combinations based on its keyword arguments, with test names.

    This function calls `_combine_named_parameters()` and adds a
    'testcase_name' entry to each returned dictionary. The 'testcase_name' key
    is required for named parameterized tests. The name is "_test" followed by
    "_<key>_<value>" for each entry, with non-alphanumeric characters removed.

    Args:
      **kwargs: keyword arguments of form `option=[possibilities, ...]`
        or `option=the_only_possibility`.

    Returns:
      a list of `OrderedDict`s, one per combination. Keys are the keyword
      argument names, each mapping to one of its possible values, followed by
      a final 'testcase_name' key.
    """
    named_combinations = []
    for combination in _combine_named_parameters(**kwargs):
        assert isinstance(combination, OrderedDict)
        name_parts = []
        for key, value in combination.items():
            clean_key = "".join(filter(str.isalnum, key))
            clean_value = "".join(filter(str.isalnum, str(value)))
            name_parts.append("_{}_{}".format(clean_key, clean_value))
        entries = list(combination.items())
        entries.append(("testcase_name", "_test{}".format("".join(name_parts))))
        named_combinations.append(OrderedDict(entries))
    return named_combinations
def run_all_in_graph_and_eager_modes(cls):
    """Execute all test methods in the given class with and without eager.

    Every callable attribute defined directly on `cls` whose name starts with
    "test" is replaced by `run_in_graph_and_eager_modes(method)`.

    Args:
      cls: The test class to decorate.

    Returns:
      `cls`, modified in place.
    """
    # Iterate a snapshot because setattr below rewrites cls.__dict__.
    snapshot = cls.__dict__.copy()
    for attr_name, attr_value in snapshot.items():
        if attr_name.startswith("test") and callable(attr_value):
            setattr(cls, attr_name, run_in_graph_and_eager_modes(attr_value))
    return cls
def run_in_graph_and_eager_modes(func=None,
                                 config=None,
                                 use_gpu=True,
                                 reset_test=True,
                                 assert_no_eager_garbage=False):
    """Execute the decorated test with and without enabling eager execution.

    This function returns a decorator intended to be applied to test methods in
    a `tf.test.TestCase` class. Doing so will cause the contents of the test
    method to be executed twice - once normally, and once with eager execution
    enabled. This allows unittests to confirm the equivalence between eager
    and graph execution (see `tf.enable_eager_execution`).

    For example, consider the following unittest:

    ```python
    class MyTests(tf.test.TestCase):

      @run_in_graph_and_eager_modes
      def test_foo(self):
        x = tf.constant([1, 2])
        y = tf.constant([3, 4])
        z = tf.add(x, y)
        self.assertAllEqual([4, 6], self.evaluate(z))

    if __name__ == "__main__":
      tf.test.main()
    ```

    This test validates that `tf.add()` has the same behavior when computed with
    eager execution enabled as it does when constructing a TensorFlow graph and
    executing the `z` tensor in a session.

    A `unittest.SkipTest` raised during the graph-mode run is swallowed, and
    the eager-mode run still happens.

    Args:
      func: function to be annotated. If `func` is None, this method returns a
        decorator that can be applied to a function. If `func` is not None this
        returns the decorator applied to `func`.
      config: An optional config_pb2.ConfigProto to use to configure the
        session when executing graphs.
      use_gpu: If True, attempt to run as many operations as possible on GPU.
        If False, the eager run is placed under `ops.device("/cpu:0")`.
      reset_test: If True, tearDown and setUp the test case between the two
        executions of the test (once with and once without eager execution).
      assert_no_eager_garbage: If True, sets DEBUG_SAVEALL on the garbage
        collector and asserts that no extra garbage has been created when running
        the test with eager execution enabled. This will fail if there are
        reference cycles (e.g. a = []; a.append(a)). Off by default because some
        tests may create garbage for legitimate reasons (e.g. they define a class
        which inherits from `object`), and because DEBUG_SAVEALL is sticky in some
        Python interpreters (meaning that tests which rely on objects being
        collected elsewhere in the unit test file will not work). Additionally,
        checks that nothing still has a reference to Tensors that the test
        allocated.

    Returns:
      Returns a decorator that will run the decorated test method twice:
      once by constructing and executing a graph in a session and once with
      eager execution enabled.

    Raises:
      ValueError: If the decorator is applied to a class rather than a test
        method.
    """

    def decorator(f):
        if tf_inspect.isclass(f):
            raise ValueError(
                "`run_in_graph_and_eager_modes` only supports test methods. "
                "Did you mean to use `run_all_in_graph_and_eager_modes`?")

        def decorated(self, **kwargs):
            try:
                with context.graph_mode():
                    with self.test_session(use_gpu=use_gpu, config=config):
                        f(self, **kwargs)
            except unittest.case.SkipTest:
                # A skip in the graph-mode run does not prevent the eager run.
                pass

            def run_eagerly(self, **kwargs):
                if not use_gpu:
                    with ops.device("/cpu:0"):
                        f(self, **kwargs)
                else:
                    f(self, **kwargs)

            if assert_no_eager_garbage:
                ops.reset_default_graph()
                run_eagerly = assert_no_new_tensors(
                    assert_no_garbage_created(run_eagerly))

            if reset_test:
                # This decorator runs the wrapped test twice.
                # Reset the test environment between runs.
                self.tearDown()
                self._tempdir = None
            # Create a new graph for the eagerly executed version of this test for
            # better isolation.
            graph_for_eager_test = ops.Graph()
            with graph_for_eager_test.as_default(), context.eager_mode():
                if reset_test:
                    self.setUp()
                run_eagerly(self, **kwargs)

            ops.dismantle_graph(graph_for_eager_test)

        return decorated

    if func is not None:
        return decorator(func)

    return decorator
@tf_export("test.is_gpu_available")
def is_gpu_available(cuda_only=False, min_cuda_compute_capability=None):
    """Returns whether TensorFlow can access a GPU.

    A "GPU" device qualifies when `min_cuda_compute_capability` is None or the
    compute capability parsed from its description is at least that value. A
    "SYCL" device qualifies unless `cuda_only` is set. A GPU whose description
    has no parseable compute capability is treated as capability (0, 0).

    Args:
      cuda_only: limit the search to CUDA gpus.
      min_cuda_compute_capability: a (major,minor) pair (tuple or list) that
        indicates the minimum CUDA compute capability required, or None if no
        requirement.

    Returns:
      True iff a gpu device of the requested kind is available.

    Raises:
      errors_impl.NotFoundError: re-raised from device enumeration unless its
        message mentions both "CUDA" and "not find" (then False is returned).
    """

    def compute_capability_from_device_desc(device_desc):
        # TODO(jingyue): The device description generator has to be in sync with
        # this file. Another option is to put compute capability in
        # DeviceAttributes, but I avoided that to keep DeviceAttributes
        # target-independent. Reconsider this option when we have more things like
        # this to keep in sync.
        # LINT.IfChange
        match = re.search(r"compute capability: (\d+)\.(\d+)", device_desc)
        # LINT.ThenChange(//tensorflow/core/\
        #                 common_runtime/gpu/gpu_device.cc)
        if not match:
            return 0, 0
        return int(match.group(1)), int(match.group(2))

    try:
        for local_device in device_lib.list_local_devices():
            if local_device.device_type == "GPU":
                # tuple(): a list-valued requirement would otherwise fail the
                # tuple >= list comparison on Python 3.
                if (min_cuda_compute_capability is None or
                        compute_capability_from_device_desc(
                            local_device.physical_device_desc) >=
                        tuple(min_cuda_compute_capability)):
                    return True
            if local_device.device_type == "SYCL" and not cuda_only:
                return True
        return False
    except errors_impl.NotFoundError as e:
        if not all(x in str(e) for x in ["CUDA", "not find"]):
            # Bare raise keeps the original traceback (`raise e` loses it on
            # Python 2).
            raise
        logging.error(str(e))
        return False
@contextlib.contextmanager
def device(use_gpu):
    """Uses gpu when requested and available.

    Args:
      use_gpu: If truthy and `is_gpu_available()` returns True, the body runs
        under `ops.device("/device:GPU:0")`; otherwise under
        `ops.device("/device:CPU:0")`. `is_gpu_available()` is not called when
        `use_gpu` is falsy.

    Yields:
      None.
    """
    dev = "/device:GPU:0" if use_gpu and is_gpu_available() else "/device:CPU:0"
    with ops.device(dev):
        yield
class ErrorLoggingSession(session.Session):
    """A `session.Session` whose `run()` logs any exception before re-raising it.
    """

    def run(self, *args, **kwargs):
        """Delegates to `Session.run`; on failure logs `str(error)` and re-raises."""
        try:
            result = super(ErrorLoggingSession, self).run(*args, **kwargs)
        except Exception as e:  # pylint: disable=broad-except
            logging.error(str(e))
            raise
        return result
@tf_export("test.TestCase")
class TensorFlowTestCase(googletest.TestCase):
"""Base class for tests that need to test TensorFlow.
"""
def __init__(self, methodName="runTest"):  # pylint: disable=invalid-name
    """Initializes per-test bookkeeping after the base TestCase setup."""
    super(TensorFlowTestCase, self).__init__(methodName)
    # Closed and reset by _ClearCachedSession().
    self._cached_session = None
    # Created lazily by get_temp_dir().
    self._tempdir = None
    # Each entry has check_termination() called on it in tearDown().
    self._threads = []
def setUp(self):
    """Drops the cached session, resets the default graph and re-seeds RNGs.

    Python's `random`, NumPy's global RNG and TensorFlow's graph-level seed
    are all seeded with `random_seed.DEFAULT_GRAPH_SEED`.
    """
    self._ClearCachedSession()
    seed = random_seed.DEFAULT_GRAPH_SEED
    random.seed(seed)
    np.random.seed(seed)
    # Note: The following line is necessary because some test methods may error
    # out from within nested graph contexts (e.g., via assertRaises and
    # assertRaisesRegexp), which may leave ops._default_graph_stack non-empty
    # under certain versions of Python. That would cause
    # ops.reset_default_graph() to throw an exception if the stack were not
    # cleared first.
    ops._default_graph_stack.reset()  # pylint: disable=protected-access
    ops.reset_default_graph()
    random_seed.set_random_seed(seed)
def tearDown(self):
    """Calls check_termination() on each registered thread, then drops the session."""
    for registered_thread in self._threads:
        registered_thread.check_termination()
    self._ClearCachedSession()
def _ClearCachedSession(self):
if self._cached_session is not None:
self._cached_session.close()
self._cached_session = None
def get_temp_dir(self):
    """Returns a unique temporary directory for the test to use.

    If you call this method multiple times during a test, it will return the
    same folder. However, across different runs the directories will be
    different. This will ensure that across different runs tests will not be
    able to pollute each others environment.
    If you need multiple unique directories within a single test, you should
    use tempfile.mkdtemp as follows:
      tempfile.mkdtemp(dir=self.get_temp_dir())

    Returns:
      string, the path to the unique temporary directory created for this test.
    """
    if self._tempdir:
        return self._tempdir
    self._tempdir = tempfile.mkdtemp(dir=googletest.GetTempDir())
    return self._tempdir
def _AssertProtoEquals(self, a, b, msg=None):
    """Asserts that a and b are the same proto.

    Equality is decided by `compare.ProtoEq()`, which handles floating point
    attributes correctly. Only on a mismatch is `compare.assertProtoEqual()`
    called, because it produces a readable failure message.

    Args:
      a: a proto.
      b: another proto.
      msg: Optional message to report on failure.
    """
    if compare.ProtoEq(a, b):
        return
    compare.assertProtoEqual(self, a, b, normalize_numbers=True, msg=msg)
def assertProtoEquals(self, expected_message_maybe_ascii, message, msg=None):
    """Asserts that message is same as parsed expected_message_ascii.

    If `expected_message_maybe_ascii` is already a message of `message`'s type,
    the two are compared directly. If it is a `str`, it is first parsed (text
    format) into a new message of that type. Both cases use
    `self._AssertProtoEquals()` and forward `msg`.

    Args:
      expected_message_maybe_ascii: proto message in original or ascii form.
      message: the message to validate.
      msg: Optional message to report on failure.

    Raises:
      AssertionError: If the protos differ, or if
        `expected_message_maybe_ascii` is neither a message of `message`'s type
        nor a `str`.
    """
    msg = msg if msg else ""
    if isinstance(expected_message_maybe_ascii, type(message)):
        expected_message = expected_message_maybe_ascii
        # Forward msg here too, as the text-format branch below does.
        self._AssertProtoEquals(expected_message, message, msg=msg)
    elif isinstance(expected_message_maybe_ascii, str):
        expected_message = type(message)()
        text_format.Merge(
            expected_message_maybe_ascii,
            expected_message,
            descriptor_pool=descriptor_pool.Default())
        self._AssertProtoEquals(expected_message, message, msg=msg)
    else:
        # self.fail, unlike `assert False`, is not stripped under `python -O`.
        self.fail("Can't compare protos of type %s and %s. %s" %
                  (type(expected_message_maybe_ascii), type(message), msg))
def assertProtoEqualsVersion(
        self,
        expected,
        actual,
        producer=versions.GRAPH_DEF_VERSION,
        min_consumer=versions.GRAPH_DEF_VERSION_MIN_CONSUMER,
        msg=None):
    """Asserts `actual` equals text proto `expected` plus a `versions` header.

    Prepends "versions { producer: ... min_consumer: ... };" to the text-format
    `expected` and delegates to `assertProtoEquals`.

    Args:
      expected: text-format proto without a `versions` field.
      actual: the message to validate.
      producer: value for `versions.producer`.
      min_consumer: value for `versions.min_consumer`.
      msg: Optional message to report on failure.
    """
    versioned_expected = "versions { producer: %d min_consumer: %d };\n%s" % (
        producer, min_consumer, expected)
    self.assertProtoEquals(versioned_expected, actual, msg=msg)
def assertStartsWith(self, actual, expected_start, msg=None):
    """Assert that actual.startswith(expected_start) is True.

    Args:
      actual: str
      expected_start: str
      msg: Optional message to report on failure; appended as " : %r".
    """
    if actual.startswith(expected_start):
        return
    fail_msg = "%r does not start with %r" % (actual, expected_start)
    if msg:
        fail_msg += " : %r" % (msg)
    self.fail(fail_msg)