/
variables.py
2733 lines (2237 loc) · 100 KB
/
variables.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Variable class."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import enum # pylint: disable=g-bad-import-order
import six
from tensorflow.core.framework import attr_value_pb2
from tensorflow.core.framework import variable_pb2
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import gen_state_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import compat
from tensorflow.python.util import tf_should_use
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export
def default_variable_creator(_, **kwds):
  """Fallback variable creator that always fails.

  Sits at the bottom of the creator chain built by
  `VariableMetaclass._variable_call`. The error message indicates the working
  implementation comes from `variable_scope` (presumably rebinding this name
  when imported -- not visible here).

  Raises:
    NotImplementedError: Always.
  """
  message = "variable_scope needs to be imported"
  del kwds  # All creation arguments are intentionally ignored.
  raise NotImplementedError(message)
def _make_getter(captured_getter, captured_previous):
"""To avoid capturing loop variables."""
def getter(**kwargs):
return captured_getter(captured_previous, **kwargs)
return getter
@tf_export("VariableSynchronization")
class VariableSynchronization(enum.Enum):
  """Indicates when a distributed variable will be synced.

  * `AUTO`: The synchronization is determined by the current
    `DistributionStrategy` (e.g. with `MirroredStrategy` this would be
    `ON_WRITE`).
  * `NONE`: There will only be one copy of the variable, so there is no need
    to sync.
  * `ON_WRITE`: The variable will be updated across devices every time it is
    written.
  * `ON_READ`: The variable will be aggregated across devices when it is read
    (e.g. when checkpointing or when evaluating an op that uses the
    variable).
  """
  # Member values are plain ints; keep them stable since callers may compare
  # or persist them.
  AUTO = 0  # Let the current DistributionStrategy decide.
  NONE = 1  # Single copy; nothing to sync.
  ON_WRITE = 2  # Sync on every write.
  ON_READ = 3  # Aggregate on read.
@tf_export("VariableAggregation")
class VariableAggregation(enum.Enum):
  """Indicates how a distributed variable will be aggregated.

  `tf.contrib.distribute.DistributionStrategy` distributes a model by making
  multiple copies (called "towers") acting data-parallel on different elements
  of the input batch. When performing some variable-update operation, say
  `var.assign_add(x)`, in a model, we need to resolve how to combine the
  different values for `x` computed in the different towers.

  * `NONE`: This is the default, giving an error if you use a
    variable-update operation with multiple towers.
  * `SUM`: Add the updates across towers.
  * `MEAN`: Take the arithmetic mean ("average") of the updates across towers.
  * `ONLY_FIRST_TOWER`: This is for when every tower is performing the same
    update, but we only want to perform the update once. Used, e.g., for the
    global step counter.
  """
  # `NONE` is also what `VariableMetaclass._variable_call` substitutes when a
  # caller passes `aggregation=None` explicitly.
  NONE = 0
  SUM = 1
  MEAN = 2
  ONLY_FIRST_TOWER = 3
class VariableMetaclass(type):
  """Metaclass that lets construction of `tf.Variable` be intercepted.

  Calling `Variable(...)` directly is routed through `_variable_call`, which
  dispatches to the default graph's variable-creator stack. Calling any
  subclass uses the ordinary `type.__call__` construction path.
  """

  def _variable_call(cls,
                     initial_value=None,
                     trainable=None,
                     collections=None,
                     validate_shape=True,
                     caching_device=None,
                     name=None,
                     variable_def=None,
                     dtype=None,
                     expected_shape=None,
                     import_scope=None,
                     constraint=None,
                     use_resource=None,
                     synchronization=VariableSynchronization.AUTO,
                     aggregation=VariableAggregation.NONE):
    """Build a variable by chaining the graph's registered creators.

    The explicit keyword list pins the accepted signature. The innermost
    creator forwards to `default_variable_creator`. Each entry of the default
    graph's `_variable_creator_stack` wraps the previous creator through
    `_make_getter`, so the last entry in the stack is the first one invoked.

    Returns:
      Whatever the outermost creator returns.
    """
    # Callers may pass `aggregation=None` explicitly; map it to the enum.
    if aggregation is None:
      aggregation = VariableAggregation.NONE

    creator = lambda **kwargs: default_variable_creator(None, **kwargs)
    creator_stack = ops.get_default_graph()._variable_creator_stack  # pylint: disable=protected-access
    for wrapping_creator in creator_stack:
      creator = _make_getter(wrapping_creator, creator)

    return creator(
        initial_value=initial_value,
        trainable=trainable,
        collections=collections,
        validate_shape=validate_shape,
        caching_device=caching_device,
        name=name,
        variable_def=variable_def,
        dtype=dtype,
        expected_shape=expected_shape,
        import_scope=import_scope,
        constraint=constraint,
        use_resource=use_resource,
        synchronization=synchronization,
        aggregation=aggregation)

  def __call__(cls, *args, **kwargs):
    if cls is not Variable:
      # Subclasses are constructed normally.
      return super(VariableMetaclass, cls).__call__(*args, **kwargs)
    # `Variable(...)` itself goes through the creator stack.
    return cls._variable_call(*args, **kwargs)
@tf_export("Variable")
class Variable(six.with_metaclass(VariableMetaclass,
checkpointable.CheckpointableBase)):
"""See the [Variables Guide](https://tensorflow.org/guide/variables).
A variable maintains state in the graph across calls to `run()`. You add a
variable to the graph by constructing an instance of the class `Variable`.
The `Variable()` constructor requires an initial value for the variable,
which can be a `Tensor` of any type and shape. The initial value defines the
type and shape of the variable. After construction, the type and shape of
the variable are fixed. The value can be changed using one of the assign
methods.
If you want to change the shape of a variable later you have to use an
`assign` Op with `validate_shape=False`.
Just like any `Tensor`, variables created with `Variable()` can be used as
inputs for other Ops in the graph. Additionally, all the operators
overloaded for the `Tensor` class are carried over to variables, so you can
also add nodes to the graph by just doing arithmetic on variables.
```python
import tensorflow as tf
# Create a variable.
w = tf.Variable(<initial-value>, name=<optional-name>)
# Use the variable in the graph like any Tensor.
y = tf.matmul(w, ...another variable or tensor...)
# The overloaded operators are available too.
z = tf.sigmoid(w + y)
# Assign a new value to the variable with `assign()` or a related method.
w.assign(w + 1.0)
w.assign_add(1.0)
```
When you launch the graph, variables have to be explicitly initialized before
you can run Ops that use their value. You can initialize a variable by
running its *initializer op*, restoring the variable from a save file, or
simply running an `assign` Op that assigns a value to the variable. In fact,
the variable *initializer op* is just an `assign` Op that assigns the
variable's initial value to the variable itself.
```python
# Launch the graph in a session.
with tf.Session() as sess:
# Run the variable initializer.
sess.run(w.initializer)
# ...you now can run ops that use the value of 'w'...
```
The most common initialization pattern is to use the convenience function
`global_variables_initializer()` to add an Op to the graph that initializes
all the variables. You then run that Op after launching the graph.
```python
# Add an Op to initialize global variables.
init_op = tf.global_variables_initializer()
# Launch the graph in a session.
with tf.Session() as sess:
# Run the Op that initializes global variables.
sess.run(init_op)
# ...you can now run any Op that uses variable values...
```
If you need to create a variable with an initial value dependent on another
variable, use the other variable's `initialized_value()`. This ensures that
variables are initialized in the right order.
All variables are automatically collected in the graph where they are
created. By default, the constructor adds the new variable to the graph
collection `GraphKeys.GLOBAL_VARIABLES`. The convenience function
`global_variables()` returns the contents of that collection.
When building a machine learning model it is often convenient to distinguish
between variables holding the trainable model parameters and other variables
such as a `global step` variable used to count training steps. To make this
easier, the variable constructor supports a `trainable=<bool>` parameter. If
`True`, the new variable is also added to the graph collection
`GraphKeys.TRAINABLE_VARIABLES`. The convenience function
`trainable_variables()` returns the contents of this collection. The
various `Optimizer` classes use this collection as the default list of
variables to optimize.
WARNING: tf.Variable objects by default have a non-intuitive memory model. A
Variable is represented internally as a mutable Tensor which can
non-deterministically alias other Tensors in a graph. The set of operations
which consume a Variable and can lead to aliasing is undetermined and can
change across TensorFlow versions. Avoid writing code which relies on the
value of a Variable either changing or not changing as other operations
happen. For example, using Variable objects or simple functions thereof as
predicates in a `tf.cond` is dangerous and error-prone:
```
v = tf.Variable(True)
tf.cond(v, lambda: v.assign(False), my_false_fn) # Note: this is broken.
```
Here, adding `use_resource=True` when constructing the variable will
fix any nondeterminism issues:
```
v = tf.Variable(True, use_resource=True)
tf.cond(v, lambda: v.assign(False), my_false_fn)
```
To use the replacement for variables which does
not have these issues:
* Add `use_resource=True` when constructing `tf.Variable`;
* Call `tf.get_variable_scope().set_use_resource(True)` inside a
`tf.variable_scope` before the `tf.get_variable()` call.
"""
def __init__(self,
             initial_value=None,
             trainable=True,
             collections=None,
             validate_shape=True,
             caching_device=None,
             name=None,
             variable_def=None,
             dtype=None,
             expected_shape=None,
             import_scope=None,
             constraint=None,
             use_resource=None,
             synchronization=VariableSynchronization.AUTO,
             aggregation=VariableAggregation.NONE):
  """Creates a new variable with value `initial_value`.

  The new variable is added to the graph collections listed in `collections`,
  which defaults to `[GraphKeys.GLOBAL_VARIABLES]`.

  If `trainable` is `True` the variable is also added to the graph collection
  `GraphKeys.TRAINABLE_VARIABLES`.

  This constructor creates both a `variable` Op and an `assign` Op to set the
  variable to its initial value.

  NOTE: in this base class the constructor body only raises
  `NotImplementedError`. Calling `Variable(...)` directly is routed by
  `VariableMetaclass` through the variable-creator stack instead of reaching
  this method.

  Args:
    initial_value: A `Tensor`, or Python object convertible to a `Tensor`,
      which is the initial value for the Variable. The initial value must have
      a shape specified unless `validate_shape` is set to False. Can also be a
      callable with no argument that returns the initial value when called. In
      that case, `dtype` must be specified. (Note that initializer functions
      from init_ops.py must first be bound to a shape before being used here.)
    trainable: If `True`, the default, also adds the variable to the graph
      collection `GraphKeys.TRAINABLE_VARIABLES`. This collection is used as
      the default list of variables to use by the `Optimizer` classes.
    collections: List of graph collections keys. The new variable is added to
      these collections. Defaults to `[GraphKeys.GLOBAL_VARIABLES]`.
    validate_shape: If `False`, allows the variable to be initialized with a
      value of unknown shape. If `True`, the default, the shape of
      `initial_value` must be known.
    caching_device: Optional device string describing where the Variable
      should be cached for reading. Defaults to the Variable's device.
      If not `None`, caches on another device. Typical use is to cache
      on the device where the Ops using the Variable reside, to deduplicate
      copying through `Switch` and other conditional statements.
    name: Optional name for the variable. Defaults to `'Variable'` and gets
      uniquified automatically.
    variable_def: `VariableDef` protocol buffer. If not `None`, recreates
      the Variable object with its contents, referencing the variable's nodes
      in the graph, which must already exist. The graph is not changed.
      `variable_def` and the other arguments are mutually exclusive.
    dtype: If set, initial_value will be converted to the given type.
      If `None`, either the datatype will be kept (if `initial_value` is
      a Tensor), or `convert_to_tensor` will decide.
    expected_shape: A TensorShape. If set, initial_value is expected
      to have this shape.
    import_scope: Optional `string`. Name scope to add to the
      `Variable`. Only used when initializing from protocol buffer.
    constraint: An optional projection function to be applied to the variable
      after being updated by an `Optimizer` (e.g. used to implement norm
      constraints or value constraints for layer weights). The function must
      take as input the unprojected Tensor representing the value of the
      variable and return the Tensor for the projected value
      (which must have the same shape). Constraints are not safe to
      use when doing asynchronous distributed training.
    use_resource: If True, a ResourceVariable is created; otherwise an
      old-style ref-based variable is created. When eager execution is enabled
      a resource variable is always created.
    synchronization: Indicates when a distributed variable will be
      synchronized. Accepted values are constants defined in the class
      `tf.VariableSynchronization`. By default the synchronization is set to
      `AUTO` and the current `DistributionStrategy` chooses
      when to synchronize. If `synchronization` is set to `ON_READ`,
      `trainable` must not be set to `True`.
    aggregation: Indicates how a distributed variable will be aggregated.
      Accepted values are constants defined in the class
      `tf.VariableAggregation`.

  Raises:
    ValueError: If both `variable_def` and initial_value are specified.
    ValueError: If the initial value is not specified, or does not have a
      shape and `validate_shape` is `True`.
    RuntimeError: If eager execution is enabled.
  """
  raise NotImplementedError
def __repr__(self):
  """Abstract: concrete variable classes supply their own representation."""
  raise NotImplementedError()
def value(self):
  """Returns the most recent snapshot of this variable.

  Ops that need the variable's value normally obtain it implicitly through
  `convert_to_tensor()`, so calling this directly is rarely necessary.

  The result is a `Tensor` holding the variable's value. It is not a reference
  to the variable, so it cannot be assigned to.

  To avoid copies, a consumer on the same device as the variable receives the
  live value rather than a copy, and sees later updates to the variable. A
  consumer on a different device receives a copy.

  Returns:
    A `Tensor` containing the value of the variable.
  """
  raise NotImplementedError()
def read_value(self):
  """Reads this variable's value in the current context.

  The result may differ from `value()`, for example when the read happens on
  another device or under control dependencies.

  Returns:
    A `Tensor` containing the value of the variable.
  """
  raise NotImplementedError()
def set_shape(self, shape):
  """Replaces the recorded shape of this variable.

  Args:
    shape: A `TensorShape` to use as the overriding shape.
  """
  raise NotImplementedError()
@property
def trainable(self):
  """Abstract accessor for the variable's `trainable` setting.

  Subclasses must override this property.
  """
  raise NotImplementedError()
def eval(self, session=None):
  """Computes and returns this variable's value inside a session.

  No ops are added to the graph; this only evaluates the variable. It requires
  a session in which the graph containing this variable has been launched.
  When `session` is omitted, the default session is used. See `tf.Session`
  for details on launching graphs and on sessions.

  ```python
  v = tf.Variable([1, 2])
  init = tf.global_variables_initializer()

  with tf.Session() as sess:
    sess.run(init)
    # Pass the session explicitly.
    print(v.eval(sess))
    # Or rely on the default session: the 'with' block above made 'sess'
    # the default.
    print(v.eval())
  ```

  Args:
    session: The session used to evaluate this variable. If `None`, the
      default session is used.

  Returns:
    A numpy `ndarray` holding a copy of this variable's value.
  """
  raise NotImplementedError()
def initialized_value(self):
  """Returns this variable's value once its initializer has run.

  Use this, rather than the variable itself, when the initial value of
  another variable depends on this one.

  ```python
  # Initialize 'v' with a random tensor.
  v = tf.Variable(tf.truncated_normal([10, 40]))
  # `initialized_value` guarantees `v` is initialized before its value is
  # used to initialize `w`; the random values are drawn only once.
  w = tf.Variable(v.initialized_value() * 2.0)
  ```

  Returns:
    A `Tensor` holding this variable's value after its initializer has run.
  """
  raise NotImplementedError()
@property
def initial_value(self):
  """The `Tensor` that supplies this variable's initial value.

  Unlike `initialized_value()`, which runs the initializing op before
  returning the value, this returns the tensor consumed by that op.

  Returns:
    A `Tensor`.
  """
  raise NotImplementedError()
@property
def constraint(self):
  """The constraint function attached to this variable.

  Returns:
    The constraint function given to the variable constructor, or `None` if
    no constraint was supplied.
  """
  raise NotImplementedError()
def assign(self, value, use_locking=False, name=None, read_value=True):
  """Assigns a new value to the variable.

  This is essentially a shortcut for `assign(self, value)`.

  Args:
    value: A `Tensor`. The new value for this variable.
    use_locking: If `True`, use locking during the assignment.
    name: The name of the operation to be created.
    read_value: If True, will return something which evaluates to the
      new value of the variable; if False will return the assign op.

  Returns:
    If `read_value` is True, a `Tensor` that will hold the new value of this
    variable after the assignment has completed. Otherwise, the assign op.
  """
  raise NotImplementedError
def assign_add(self, delta, use_locking=False, name=None, read_value=True):
  """Adds a value to this variable.

  This is essentially a shortcut for `assign_add(self, delta)`.

  Args:
    delta: A `Tensor`. The value to add to this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation to be created.
    read_value: If True, will return something which evaluates to the
      new value of the variable; if False will return the assign op.

  Returns:
    If `read_value` is True, a `Tensor` that will hold the new value of this
    variable after the addition has completed. Otherwise, the assign op.
  """
  raise NotImplementedError
def assign_sub(self, delta, use_locking=False, name=None, read_value=True):
  """Subtracts a value from this variable.

  This is essentially a shortcut for `assign_sub(self, delta)`.

  Args:
    delta: A `Tensor`. The value to subtract from this variable.
    use_locking: If `True`, use locking during the operation.
    name: The name of the operation to be created.
    read_value: If True, will return something which evaluates to the
      new value of the variable; if False will return the assign op.

  Returns:
    If `read_value` is True, a `Tensor` that will hold the new value of this
    variable after the subtraction has completed. Otherwise, the assign op.
  """
  raise NotImplementedError
def scatter_sub(self, sparse_delta, use_locking=False, name=None):
  """Subtracts an `IndexedSlices` value from this variable.

  Args:
    sparse_delta: The `IndexedSlices` to subtract from this variable.
    use_locking: If `True`, the operation uses locking.
    name: Name for the operation.

  Returns:
    A `Tensor` that will hold this variable's new value once the scattered
    subtraction has completed.

  Raises:
    ValueError: If `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError()
def scatter_add(self, sparse_delta, use_locking=False, name=None):
  """Adds `IndexedSlices` to this variable.

  Args:
    sparse_delta: `IndexedSlices` to be added to this variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered addition has completed.

  Raises:
    ValueError: if `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError
def scatter_update(self, sparse_delta, use_locking=False, name=None):
  """Assigns `IndexedSlices` to this variable.

  Args:
    sparse_delta: `IndexedSlices` to be assigned to this variable.
    use_locking: If `True`, use locking during the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered assignment has completed.

  Raises:
    ValueError: if `sparse_delta` is not an `IndexedSlices`.
  """
  raise NotImplementedError
def scatter_nd_sub(self, indices, updates, name=None):
  """Applies sparse subtraction to individual values or slices in a Variable.

  Below, `ref` denotes this variable: a `Tensor` with rank `P`. `indices` is
  an integer `Tensor` of rank `Q` containing indices into `ref`. It must have
  shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to subtract 4 scattered elements from a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = ref.scatter_nd_sub(indices, updates)
  with tf.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, -9, 3, -6, -4, 6, 7, -4]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered subtraction has completed.
  """
  raise NotImplementedError
def scatter_nd_add(self, indices, updates, name=None):
  """Applies sparse addition to individual values or slices in a Variable.

  Below, `ref` denotes this variable: a `Tensor` with rank `P`. `indices` is
  an integer `Tensor` of rank `Q` containing indices into `ref`. It must have
  shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to add 4 scattered elements to a rank-1 tensor
  with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  add = ref.scatter_nd_add(indices, updates)
  with tf.Session() as sess:
    print(sess.run(add))
  ```

  The resulting update to ref would look like this:

      [1, 13, 3, 14, 14, 6, 7, 20]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered addition has completed.
  """
  raise NotImplementedError
def scatter_nd_update(self, indices, updates, name=None):
  """Applies sparse assignment to individual values or slices in a Variable.

  Below, `ref` denotes this variable: a `Tensor` with rank `P`. `indices` is
  an integer `Tensor` of rank `Q` containing indices into `ref`. It must have
  shape `[d_0, ..., d_{Q-2}, K]` where `0 < K <= P`.

  The innermost dimension of `indices` (with length `K`) corresponds to
  indices into elements (if `K = P`) or slices (if `K < P`) along the `K`th
  dimension of `ref`.

  `updates` is a `Tensor` of rank `Q-1+P-K` with shape:

  ```
  [d_0, ..., d_{Q-2}, ref.shape[K], ..., ref.shape[P-1]].
  ```

  For example, say we want to overwrite 4 scattered elements of a rank-1
  tensor with 8 elements. In Python, that update would look like this:

  ```python
  ref = tf.Variable([1, 2, 3, 4, 5, 6, 7, 8])
  indices = tf.constant([[4], [3], [1], [7]])
  updates = tf.constant([9, 10, 11, 12])
  op = ref.scatter_nd_update(indices, updates)
  with tf.Session() as sess:
    print(sess.run(op))
  ```

  The resulting update to ref would look like this:

      [1, 11, 3, 10, 9, 6, 7, 12]

  See `tf.scatter_nd` for more details about how to make updates to
  slices.

  Args:
    indices: The indices to be used in the operation.
    updates: The values to be used in the operation.
    name: the name of the operation.

  Returns:
    A `Tensor` that will hold the new value of this variable after
    the scattered assignment has completed.
  """
  raise NotImplementedError
def count_up_to(self, limit):
  """Increments this variable until it reaches `limit`.

  Each run of the returned op tries to add `1` to the variable. If the
  increment would take the variable above `limit`, the op raises
  `OutOfRangeError` instead. Otherwise, the op outputs the value the variable
  held before the increment.

  This is essentially a shortcut for `count_up_to(self, limit)`.

  Args:
    limit: The value at which incrementing the variable raises an error.

  Returns:
    A `Tensor` that will hold the pre-increment value. As long as no other op
    modifies this variable, every value it produces is distinct.
  """
  raise NotImplementedError()
def load(self, value, session=None):
  """Load new value into this variable.

  Writes new value to variable's memory. Doesn't add ops to the graph.

  This convenience method requires a session where the graph
  containing this variable has been launched. If no session is
  passed, the default session is used. See `tf.Session` for more
  information on launching a graph and on sessions.

  ```python
  v = tf.Variable([1, 2])
  init = tf.global_variables_initializer()

  with tf.Session() as sess:
    sess.run(init)
    # Usage passing the session explicitly.
    v.load([2, 3], sess)
    print(v.eval(sess))  # prints [2 3]
    # Usage with the default session. The 'with' block
    # above makes 'sess' the default session.
    v.load([3, 4])
    print(v.eval())  # prints [3 4]
  ```

  Args:
    value: New variable value.
    session: The session to use to evaluate this variable. If
      none, the default session is used.

  Raises:
    ValueError: If no session is passed and there is no default session.
  """
  raise NotImplementedError
# Conversion to tensor.
@staticmethod
def _TensorConversionFunction(v, dtype=None, name=None, as_ref=False):  # pylint: disable=invalid-name
  """Converts variable `v` to a tensor.

  Args:
    v: The variable to convert.
    dtype: Optional requested `DType`; it must be compatible with `v.dtype`.
    name: Unused; accepted for signature compatibility.
    as_ref: If True, return `v._ref()`; otherwise return `v.value()`.

  Raises:
    ValueError: If `dtype` is given and is incompatible with `v.dtype`.
  """
  del name  # Unused.
  if dtype and not dtype.is_compatible_with(v.dtype):
    raise ValueError(
        "Incompatible type conversion requested to type '%s' for variable "
        "of type '%s'" % (dtype.name, v.dtype.name))
  return v._ref() if as_ref else v.value()  # pylint: disable=protected-access
@staticmethod
def _OverloadAllOperators():  # pylint: disable=invalid-name
  """Installs every `ops.Tensor` operator overload on `Variable`."""
  for operator_name in ops.Tensor.OVERLOADABLE_OPERATORS:
    Variable._OverloadOperator(operator_name)
  # Slicing is bound to array_ops._SliceHelperVar rather than the Tensor
  # slicing helper, so it is installed separately.
  Variable.__getitem__ = array_ops._SliceHelperVar  # pylint: disable=protected-access
@staticmethod
def _OverloadOperator(operator):  # pylint: disable=invalid-name
  """Installs on `Variable` a wrapper that defers `operator` to `ops.Tensor`.

  The `ops.Tensor` method is looked up at call time rather than at install
  time, which sidesteps ordering problems during module initialization.

  Args:
    operator: string. The operator name.
  """
  def _run_op(a, *args):
    # pylint: disable=protected-access
    return getattr(ops.Tensor, operator)(a._AsTensor(), *args)

  # Copy the Tensor method's docstring onto the wrapper when available.
  try:
    tensor_doc = getattr(ops.Tensor, operator).__doc__
  except AttributeError:
    pass
  else:
    _run_op.__doc__ = tensor_doc

  setattr(Variable, operator, _run_op)
# NOTE(mrry): This enables the Variable's overloaded "right" binary
# operators to run when the left operand is an ndarray, because it
# accords the Variable class higher priority than an ndarray, or a
# numpy matrix.
# TODO(mrry): Convert this to using numpy's __numpy_ufunc__
# mechanism, which allows more control over how Variables interact
# with ndarrays.
# NOTE(review): NumPy appears to expose that mechanism as `__array_ufunc__`;
# confirm against the NumPy docs before implementing the TODO above.
__array_priority__ = 100
@property
def name(self):
  """Name of this variable.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def initializer(self):
  """Initializer operation for this variable.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def device(self):
  """Device this variable is placed on.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def dtype(self):
  """`DType` of this variable.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def op(self):
  """`Operation` of this variable.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def graph(self):
  """`Graph` this variable belongs to.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
@property
def shape(self):
  """Static shape of this variable.

  Returns:
    A `TensorShape`.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this property.
  """
  raise NotImplementedError()
def get_shape(self):
  """Alias of `Variable.shape`.

  Delegates to the `shape` property, so a subclass that implements `shape`
  gets a working `get_shape` for free. On this base class `shape` itself
  raises `NotImplementedError`, and that exception propagates unchanged.

  Returns:
    Whatever `self.shape` returns (documented there as a `TensorShape`).
  """
  return self.shape
def to_proto(self, export_scope=None):
  """Serializes this `Variable` into a `VariableDef` protocol buffer.

  Args:
    export_scope: Optional `string`. Name scope to remove.

  Returns:
    A `VariableDef` protocol buffer, or `None` if the `Variable` is not
    in the specified name scope.

  Raises:
    NotImplementedError: Always on this base class; subclasses are expected
      to override this method.
  """
  raise NotImplementedError()
@staticmethod
def from_proto(variable_def, import_scope=None):
  """Builds a `Variable` from a `VariableDef` protocol buffer.

  Args:
    variable_def: The `VariableDef` to rebuild from.
    import_scope: Optional scope, forwarded unchanged to `RefVariable`.

  Returns:
    A `RefVariable` constructed from `variable_def`.
  """
  rebuilt = RefVariable(import_scope=import_scope, variable_def=variable_def)
  return rebuilt
class SaveSliceInfo(object):
  """Information on how to save this Variable as a slice.

  Provides internal support for saving variables as slices of a larger
  variable. This API is not public and is subject to change.

  Available properties:

  * full_name
  * full_shape
  * var_offset
  * var_shape
  """

  def __init__(self,
               full_name=None,
               full_shape=None,
               var_offset=None,
               var_shape=None,
               save_slice_info_def=None,
               import_scope=None):
    """Create a `SaveSliceInfo`.

    Args:
      full_name: Name of the full variable of which this `Variable` is a
        slice.
      full_shape: Shape of the full variable, as a list of int.
      var_offset: Offset of this `Variable` into the full variable, as a
        list of int.
      var_shape: Shape of this `Variable`, as a list of int.
      save_slice_info_def: `SaveSliceInfoDef` protocol buffer. If not `None`,
        recreates the SaveSliceInfo object from its contents.
        `save_slice_info_def` and other arguments are mutually
        exclusive.
      import_scope: Optional `string`. Name scope to add. Only used
        when initializing from protocol buffer.
    """
    if save_slice_info_def:
      # NOTE(review): `assert` is stripped under `python -O`, so this type
      # check is debug-only.
      assert isinstance(save_slice_info_def, variable_pb2.SaveSliceInfoDef)
      self.full_name = ops.prepend_name_scope(
          save_slice_info_def.full_name, import_scope=import_scope)
      # Copy the repeated proto fields into plain Python lists.
      self.full_shape = list(save_slice_info_def.full_shape)
      self.var_offset = list(save_slice_info_def.var_offset)
      self.var_shape = list(save_slice_info_def.var_shape)
    else:
      self.full_name = full_name
      self.full_shape = full_shape
      self.var_offset = var_offset
      self.var_shape = var_shape

  @property
  def spec(self):
    """Computes the spec string used for saving.

    The string is the space-separated `full_shape`, a trailing space, then
    one `offset,length` pair per dimension joined by `:`. For example,
    full shape `[4, 6]`, offset `[0, 2]` and shape `[4, 3]` give
    `"4 6 0,4:2,3"`.
    """
    full_shape_str = " ".join("%d" % d for d in self.full_shape) + " "
    sl_spec = ":".join(
        "%d,%d" % (o, s) for o, s in zip(self.var_offset, self.var_shape))
    return full_shape_str + sl_spec

  def to_proto(self, export_scope=None):
    """Returns a SaveSliceInfoDef() proto.

    Args:
      export_scope: Optional `string`. Name scope to remove.

    Returns:
      A `SaveSliceInfoDef` protocol buffer, or None if the `Variable` is not
      in the specified name scope.
    """
    # NOTE(review): this is a plain string-prefix test, so a scope "a" also
    # matches a full_name beginning with "ab/"; confirm that is intended.
    if (export_scope is not None and
        not self.full_name.startswith(export_scope)):
      return None
    save_slice_info_def = variable_pb2.SaveSliceInfoDef()
    save_slice_info_def.full_name = ops.strip_name_scope(
        self.full_name, export_scope)
    # Element-wise appends (rather than `extend`) keep the original failure
    # mode: a `None` shape list raises instead of being silently skipped.
    for dim in self.full_shape:
      save_slice_info_def.full_shape.append(dim)
    for offset in self.var_offset:
      save_slice_info_def.var_offset.append(offset)
    for dim in self.var_shape:
      save_slice_info_def.var_shape.append(dim)
    return save_slice_info_def
def __iadd__(self, other):
  """In-place addition (`+=`); unsupported on this base class.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError()
def __isub__(self, other):
  """In-place subtraction (`-=`); unsupported on this base class.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError()
def __imul__(self, other):
  """In-place multiplication (`*=`); unsupported on this base class.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError()
def __idiv__(self, other):
  """In-place classic division; unsupported on this base class.

  Python 3 routes `/=` to `__itruediv__`, so this hook is only reached by
  Python 2's classic `/=` or by direct calls.

  Raises:
    NotImplementedError: Always.
  """
  raise NotImplementedError()