base_layer.py
# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=protected-access
"""Contains the base Layer class, from which all layers inherit."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
import collections
import enum # pylint: disable=g-bad-import-order
import inspect # Necessary supplement to tf_inspect to deal with variadic args.
import numpy as np
from six.moves import zip # pylint: disable=redefined-builtin
from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.keras import backend
from tensorflow.python.keras import constraints
from tensorflow.python.keras import initializers
from tensorflow.python.keras import regularizers
from tensorflow.python.keras.utils import generic_utils
from tensorflow.python.keras.utils import tf_utils
# Modules that only depend on `keras.layers` import these from here.
from tensorflow.python.keras.utils.generic_utils import to_snake_case # pylint: disable=unused-import
from tensorflow.python.keras.utils.tf_utils import is_tensor_or_tensor_list # pylint: disable=unused-import
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import variable_scope as vs
from tensorflow.python.ops import variables as tf_variables
from tensorflow.python.training.checkpointable import base as checkpointable
from tensorflow.python.util import function_utils
from tensorflow.python.util import nest
from tensorflow.python.util import tf_decorator
from tensorflow.python.util import tf_inspect
from tensorflow.python.util.tf_export import tf_export
class CallConvention(enum.Enum):
"""Calling conventions for passing `Layer` inputs to `Layer.call`."""
# The Layer takes inputs as its first argument, named "inputs" for
# compatibility with the signature of Layer.__call__. This is the mode assumed
# for Layers which are not subclassed Models.
EXPLICIT_INPUTS_ARGUMENT = 1
# The Layer takes a single positional argument, not named "inputs". It's
# treated like an "inputs" argument.
SINGLE_POSITIONAL_ARGUMENT = 2
# The Layer has multiple positional arguments to which its inputs should be
# bound.
POSITIONAL_ARGUMENTS_ARE_INPUTS = 3
@tf_export('keras.layers.Layer')
class Layer(checkpointable.CheckpointableBase):
"""Base layer class.
This is the class from which all layers inherit.
A layer is a class implementing common neural networks operations, such
as convolution, batch norm, etc. These operations require managing weights,
losses, updates, and inter-layer connectivity.
Users will just instantiate a layer and then treat it as a callable.
We recommend that descendants of `Layer` implement the following methods:
* `__init__()`: Save configuration in member variables
* `build()`: Called once from `__call__`, when we know the shapes of inputs
and `dtype`. Should have the calls to `add_weight()`, and then
call the super's `build()` (which sets `self.built = True`, which is
nice in case the user wants to call `build()` manually before the
first `__call__`).
* `call()`: Called in `__call__` after making sure `build()` has been called
once. Should actually perform the logic of applying the layer to the
input tensors (which should be passed in as the first argument).
Arguments:
trainable: Boolean, whether the layer's variables should be trainable.
name: String name of the layer.
dtype: Default dtype of the layer's weights (default of `None` means use the
type of the first input).
Read-only properties:
name: The name of the layer (string).
dtype: Default dtype of the layer's weights (default of `None` means use the
type of the first input).
trainable_variables: List of trainable variables.
non_trainable_variables: List of non-trainable variables.
variables: List of all variables of this layer, trainable and
non-trainable.
updates: List of update ops of this layer.
losses: List of losses added by this layer.
trainable_weights: List of variables to be included in backprop.
non_trainable_weights: List of variables that should not be
included in backprop.
weights: The concatenation of the lists trainable_weights and
non_trainable_weights (in this order).
Mutable properties:
trainable: Whether the layer should be trained (boolean).
input_spec: Optional (list of) `InputSpec` object(s) specifying the
constraints on inputs that can be accepted by the layer.
"""
@checkpointable.no_automatic_dependency_tracking
def __init__(self, trainable=True, name=None, dtype=None, **kwargs):
# These properties should be set by the user via keyword arguments.
# note that 'dtype', 'input_shape' and 'batch_input_shape'
# are only applicable to input layers: do not pass these keywords
# to non-input layers.
allowed_kwargs = {
'input_shape',
'batch_input_shape',
'batch_size',
'weights',
'activity_regularizer',
}
# Validate optional keyword arguments.
for kwarg in kwargs:
if kwarg not in allowed_kwargs:
raise TypeError('Keyword argument not understood:', kwarg)
# Mutable properties
# Indicates whether the layer's weights are updated during training
# and whether the layer's updates are run during training
self.trainable = trainable
# A stateful layer is a layer whose updates are run during inference too,
# for instance stateful RNNs.
self.stateful = False
# Indicates whether `build` needs to be called upon layer call, to create
# the layer's weights.
self.built = False
# Provides information about which inputs are compatible with the layer.
self.input_spec = None
self._init_set_name(name)
activity_regularizer = kwargs.pop('activity_regularizer', None)
if activity_regularizer and context.executing_eagerly():
raise ValueError(
('Activity regularization is not supported when executing eagerly. '
'Got activity_regularizer=%s') % (activity_regularizer,))
self._activity_regularizer = activity_regularizer
self._trainable_weights = []
self._non_trainable_weights = []
self._updates = []
# When executing eagerly, _losses is a list of zero-argument lambdas which
# return tensors. When using graph execution, _losses is a list of ops.
self._losses = []
self._dtype = None if dtype is None else dtypes.as_dtype(dtype).name
self._call_fn_args = function_utils.fn_args(self.call)
self._compute_previous_mask = ('mask' in self._call_fn_args or
hasattr(self, 'compute_mask'))
self._call_convention = CallConvention.EXPLICIT_INPUTS_ARGUMENT
# These lists will be filled via successive calls
# to self._add_inbound_node().
self._inbound_nodes = []
self._outbound_nodes = []
self.supports_masking = False
# Manage input shape information if passed.
if 'input_shape' in kwargs or 'batch_input_shape' in kwargs:
# In this case we will later create an input layer
# to insert before the current layer
if 'batch_input_shape' in kwargs:
batch_input_shape = tuple(kwargs['batch_input_shape'])
elif 'input_shape' in kwargs:
if 'batch_size' in kwargs:
batch_size = kwargs['batch_size']
else:
batch_size = None
batch_input_shape = (batch_size,) + tuple(kwargs['input_shape'])
self._batch_input_shape = batch_input_shape
# Manage initial weight values if passed.
if 'weights' in kwargs:
self._initial_weights = kwargs['weights']
else:
self._initial_weights = None
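  # Illustrative sketch (not part of the original source): a minimal custom
  # layer following the structure recommended in the class docstring above.
  # The class name `MyDense` and the `units` attribute are hypothetical;
  # `tf` refers to `import tensorflow as tf`.
  #
  #   class MyDense(tf.keras.layers.Layer):
  #
  #     def __init__(self, units, **kwargs):
  #       super(MyDense, self).__init__(**kwargs)
  #       self.units = units
  #
  #     def build(self, input_shape):
  #       self.kernel = self.add_weight(
  #           'kernel', shape=[int(input_shape[-1]), self.units])
  #       super(MyDense, self).build(input_shape)  # Sets self.built = True.
  #
  #     def call(self, inputs):
  #       return tf.matmul(inputs, self.kernel)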
def _init_set_name(self, name, zero_based=True):
if not name:
self._name = unique_layer_name(
generic_utils.to_snake_case(self.__class__.__name__),
zero_based=zero_based)
else:
self._name = name
@property
def dtype(self):
return self._dtype
@property
def name(self):
return self._name
@property
def activity_regularizer(self):
"""Optional regularizer function for the output of this layer."""
return self._activity_regularizer
@activity_regularizer.setter
def activity_regularizer(self, regularizer):
"""Optional regularizer function for the output of this layer."""
self._activity_regularizer = self._no_dependency(regularizer)
@property
def trainable_weights(self):
return self._trainable_weights if self.trainable else []
@property
def non_trainable_weights(self):
if self.trainable:
return self._non_trainable_weights
else:
return self._trainable_weights + self._non_trainable_weights
@property
def trainable_variables(self):
return self.trainable_weights
@property
def non_trainable_variables(self):
return self.non_trainable_weights
@property
def weights(self):
"""Returns the list of all layer variables/weights.
Returns:
A list of variables.
"""
return self.trainable_weights + self.non_trainable_weights
@property
def variables(self):
"""Returns the list of all layer variables/weights.
Returns:
A list of variables.
"""
return self.weights
@property
def updates(self):
if context.executing_eagerly():
raise RuntimeError('Layer.updates not supported in Eager mode.')
if not self.trainable and not self.stateful:
return []
return self._updates
def add_update(self, updates, inputs=None):
"""Add update op(s), potentially dependent on layer inputs.
Weight updates (for instance, the updates of the moving mean and variance
in a BatchNormalization layer) may be dependent on the inputs passed
when calling a layer. Hence, when reusing the same layer on
different inputs `a` and `b`, some entries in `layer.updates` may be
dependent on `a` and some on `b`. This method automatically keeps track
of dependencies.
    The `get_updates_for` method allows retrieving the updates relevant to a
specific set of inputs.
This call is ignored when eager execution is enabled (in that case, variable
updates are run on the fly and thus do not need to be tracked for later
execution).
Arguments:
updates: Update op, or list/tuple of update ops.
inputs: If anything other than None is passed, it signals the updates
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
This is the case for BatchNormalization updates, for instance.
If None, the updates will be taken into account unconditionally,
and you are responsible for making sure that any dependency they might
have is available at runtime.
A step counter might fall into this category.
"""
if context.executing_eagerly():
return # Updates already applied when in eager mode.
def process_update(x):
if isinstance(x, ops.Operation):
return x
elif hasattr(x, 'op'):
return x.op
else:
return ops.convert_to_tensor(x)
updates = generic_utils.to_list(updates)
updates = [process_update(x) for x in updates]
self._updates += updates
if inputs is None:
for u in updates:
u._unconditional_update = True # pylint: disable=protected-access
else:
for u in updates:
u._unconditional_update = False # pylint: disable=protected-access
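  # Illustrative sketch (not part of the original source): recording updates
  # from inside a custom layer's `call` in graph mode. `self.step` and
  # `self.moving_mean` are hypothetical variables created in `build`;
  # `tf` is `import tensorflow as tf`.
  #
  #   def call(self, inputs):
  #     self.add_update(tf.assign_add(self.step, 1))  # Unconditional update.
  #     self.add_update(
  #         tf.assign(self.moving_mean, tf.reduce_mean(inputs, axis=0)),
  #         inputs=inputs)  # Conditional on `inputs`.
  #     return inputs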
def get_updates_for(self, inputs):
"""Retrieves updates relevant to a specific set of inputs.
Arguments:
inputs: Input tensor or list/tuple of input tensors.
Returns:
List of update ops of the layer that depend on `inputs`.
Raises:
RuntimeError: If called in Eager mode.
"""
if context.executing_eagerly():
raise RuntimeError('`get_updates_for()` not supported in Eager mode.')
# Updates disabled if layer is not trainable and not explicitly stateful.
if not self.trainable and not self.stateful:
return []
if inputs is None:
# Requesting unconditional updates.
return [x for x in self.updates if x._unconditional_update] # pylint: disable=protected-access
# Requesting input-conditional updates.
inputs = nest.flatten(inputs)
reachable = tf_utils.get_reachable_from_inputs(inputs, self.updates)
updates = []
for update in self.updates:
if update in reachable:
updates.append(update)
return updates
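  # Illustrative sketch (not part of the original source): fetching the update
  # ops that depend on a particular input so they can be run alongside the
  # training op. `bn_layer` and `x` are hypothetical.
  #
  #   y = bn_layer(x, training=True)
  #   conditional_updates = bn_layer.get_updates_for([x])
  #   unconditional_updates = bn_layer.get_updates_for(None)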
@property
def losses(self):
"""Losses which are associated with this `Layer`.
Note that when executing eagerly, getting this property evaluates
regularizers. When using graph execution, variable regularization ops have
already been created and are simply returned here.
Returns:
A list of tensors.
"""
if context.executing_eagerly():
# _losses may only contain variable regularization losses when executing
# eagerly, and they have been saved as lambdas to be executed when
# requested.
return [regularizer() for regularizer in self._losses]
else:
return self._losses
def add_loss(self, losses, inputs=None):
"""Add loss tensor(s), potentially dependent on layer inputs.
Some losses (for instance, activity regularization losses) may be dependent
on the inputs passed when calling a layer. Hence, when reusing the same
layer on different inputs `a` and `b`, some entries in `layer.losses` may
be dependent on `a` and some on `b`. This method automatically keeps track
of dependencies.
    The `get_losses_for` method allows retrieving the losses relevant to a
specific set of inputs.
Note that `add_loss` is not supported when executing eagerly. Instead,
variable regularizers may be added through `add_variable`. Activity
regularization is not supported directly (but such losses may be returned
from `Layer.call()`).
Arguments:
losses: Loss tensor, or list/tuple of tensors.
inputs: If anything other than None is passed, it signals the losses
are conditional on some of the layer's inputs,
and thus they should only be run where these inputs are available.
This is the case for activity regularization losses, for instance.
If `None` is passed, the losses are assumed
to be unconditional, and will apply across all dataflows of the layer
(e.g. weight regularization losses).
Raises:
RuntimeError: If called in Eager mode.
"""
if context.executing_eagerly():
# TODO(fchollet): it should be possible (and highly desirable) to support
# `add_loss` in eager mode. This allows great convenience and flexibility
# in defining custom losses on the fly (e.g. in VAEs).
# Simply appending the loss value to `self._losses`
# is the correct behavior.
# The only caveat is that we need to force the user to only call
# `add_loss` from inside a model or Layer's `call` method
# (otherwise the loss computation cannot be backproped through).
raise RuntimeError('Layer.add_loss not supported in Eager mode.')
losses = generic_utils.to_list(losses)
losses = [ops.convert_to_tensor(loss, dtype=backend.floatx())
if not tensor_util.is_tensor(loss) else loss for loss in losses]
self._losses += losses
if inputs is None:
for loss in losses:
loss._unconditional_loss = True # pylint: disable=protected-access
else:
for loss in losses:
loss._unconditional_loss = False # pylint: disable=protected-access
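  # Illustrative sketch (not part of the original source): adding an
  # input-conditional activity penalty from inside `call` in graph mode.
  # Unconditional weight penalties are more commonly attached via the
  # `regularizer` argument of `add_weight`. `self.kernel` is hypothetical;
  # `tf` is `import tensorflow as tf`.
  #
  #   def call(self, inputs):
  #     outputs = tf.matmul(inputs, self.kernel)
  #     self.add_loss(0.01 * tf.reduce_sum(tf.abs(outputs)), inputs=inputs)
  #     return outputs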
def get_losses_for(self, inputs):
"""Retrieves losses relevant to a specific set of inputs.
Arguments:
inputs: Input tensor or list/tuple of input tensors.
Returns:
List of loss tensors of the layer that depend on `inputs`.
Raises:
RuntimeError: If called in Eager mode.
"""
if context.executing_eagerly():
raise RuntimeError('Layer.get_losses_for not supported in Eager mode.')
if inputs is None:
# Requesting unconditional losses.
return [x for x in self.losses if x._unconditional_loss] # pylint: disable=protected-access
# Requesting input-conditional losses.
inputs = nest.flatten(inputs)
# Retrieve the set of tensors in the TF graph that depend on `inputs`.
# The losses we want to return will be part of this set.
# To avoid unnecessary work, we stop the search in case all of
# `self.losses` have been retrieved.
reachable = tf_utils.get_reachable_from_inputs(inputs, self.losses)
losses = []
for loss in self.losses:
if loss in reachable:
losses.append(loss)
return losses
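  # Illustrative sketch (not part of the original source): retrieving the
  # losses that depend on a given input tensor, e.g. when assembling a total
  # loss by hand. `layer` and `x` are hypothetical.
  #
  #   y = layer(x)
  #   conditional_losses = layer.get_losses_for([x])
  #   unconditional_losses = layer.get_losses_for(None)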
def _name_scope(self):
return self.name
def build(self, input_shape):
"""Creates the variables of the layer."""
self.built = True
def add_variable(self, *args, **kwargs):
"""Alias for `add_weight`."""
return self.add_weight(*args, **kwargs)
def add_weight(self,
name,
shape,
dtype=None,
initializer=None,
regularizer=None,
trainable=None,
constraint=None,
partitioner=None,
use_resource=None,
synchronization=vs.VariableSynchronization.AUTO,
aggregation=vs.VariableAggregation.NONE,
getter=None):
"""Adds a new variable to the layer, or gets an existing one; returns it.
Arguments:
name: variable name.
shape: variable shape.
dtype: The type of the variable. Defaults to `self.dtype` or `float32`.
initializer: initializer instance (callable).
regularizer: regularizer instance (callable).
trainable: whether the variable should be part of the layer's
"trainable_variables" (e.g. variables, biases)
or "non_trainable_variables" (e.g. BatchNorm mean, stddev).
Note, if the current variable scope is marked as non-trainable
then this parameter is ignored and any added variables are also
marked as non-trainable. `trainable` defaults to `True` unless
`synchronization` is set to `ON_READ`.
constraint: constraint instance (callable).
partitioner: Partitioner to be passed to the `Checkpointable` API.
use_resource: Whether to use `ResourceVariable`.
      synchronization: Indicates when a distributed variable will be
aggregated. Accepted values are constants defined in the class
@{tf.VariableSynchronization}. By default the synchronization is set to
`AUTO` and the current `DistributionStrategy` chooses
when to synchronize. If `synchronization` is set to `ON_READ`,
`trainable` must not be set to `True`.
aggregation: Indicates how a distributed variable will be aggregated.
Accepted values are constants defined in the class
@{tf.VariableAggregation}.
getter: Variable getter argument to be passed to the `Checkpointable` API.
Returns:
The created variable. Usually either a `Variable` or `ResourceVariable`
instance. If `partitioner` is not `None`, a `PartitionedVariable`
instance is returned.
Raises:
      RuntimeError: If called with partitioned variable regularization and
eager execution is enabled.
ValueError: When giving unsupported dtype and no initializer or when
trainable has been set to True with synchronization set as `ON_READ`.
"""
if dtype is None:
dtype = self.dtype or backend.floatx()
dtype = dtypes.as_dtype(dtype)
initializer = initializers.get(initializer)
regularizer = regularizers.get(regularizer)
constraint = constraints.get(constraint)
if synchronization == vs.VariableSynchronization.ON_READ:
if trainable:
raise ValueError(
'Synchronization value can be set to '
'VariableSynchronization.ON_READ only for non-trainable variables. '
'You have specified trainable=True and '
'synchronization=VariableSynchronization.ON_READ.')
else:
# Set trainable to be false when variable is to be synced on read.
trainable = False
elif trainable is None:
trainable = True
# Initialize variable when no initializer provided
if initializer is None:
      # If dtype is DT_FLOAT, provide a Glorot uniform initializer
if dtype.is_floating:
initializer = initializers.glorot_uniform()
# If dtype is DT_INT/DT_UINT, provide a default value `zero`
# If dtype is DT_BOOL, provide a default value `FALSE`
elif dtype.is_integer or dtype.is_unsigned or dtype.is_bool:
initializer = initializers.zeros()
      # NOTE: Do we need to support handling DT_STRING and DT_COMPLEX here?
else:
raise ValueError('An initializer for variable %s of type %s is required'
' for layer %s' % (name, dtype.base_dtype, self.name))
variable = self._add_variable_with_custom_getter(
name=name,
shape=shape,
# TODO(allenl): a `make_variable` equivalent should be added as a
# `Checkpointable` method.
getter=getter or make_variable,
# Manage errors in Layer rather than Checkpointable.
overwrite=True,
initializer=initializer,
dtype=dtype,
constraint=constraint,
trainable=trainable and self.trainable,
partitioner=partitioner,
use_resource=use_resource,
synchronization=synchronization,
aggregation=aggregation)
if regularizer is not None:
# TODO(fchollet): in the future, this should be handled at the
# level of variable creation, and weight regularization losses
# should be variable attributes.
self._handle_weight_regularization(name, variable, regularizer)
if trainable:
self._trainable_weights.append(variable)
else:
self._non_trainable_weights.append(variable)
return variable
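  # Illustrative sketch (not part of the original source): a typical `build`
  # implementation using `add_weight`. The class name `MyDense` and the
  # `self.units` attribute are hypothetical; `tf` is `import tensorflow as tf`.
  #
  #   def build(self, input_shape):
  #     last_dim = int(input_shape[-1])
  #     self.kernel = self.add_weight(
  #         'kernel',
  #         shape=[last_dim, self.units],
  #         initializer=tf.keras.initializers.glorot_uniform(),
  #         regularizer=tf.keras.regularizers.l2(1e-4),
  #         trainable=True)
  #     self.bias = self.add_weight('bias', shape=[self.units],
  #                                 initializer='zeros')
  #     super(MyDense, self).build(input_shape)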
def _handle_weight_regularization(self, name, variable, regularizer):
# `init_graph` should point to the graph in which variable initialization
# will occur; it should be None if and only if initialization will take
# place in the eager context.
init_graph = None
if not context.executing_eagerly():
default_graph = ops.get_default_graph()
if default_graph.building_function:
with ops.init_scope():
# Retrieve the variables from the graph into which variables
# will be lifted; if initialization ops will be lifted into
# the eager context, then there is nothing to retrieve, since variable
# collections are not supported when eager execution is enabled.
if not context.executing_eagerly():
init_graph = ops.get_default_graph()
else:
# Initialization ops will not be lifted out of the default graph.
init_graph = default_graph
if init_graph is not None: # pylint: disable=protected-access
# The variable was created and initialized in a graph.
if regularizer:
if isinstance(variable, tf_variables.PartitionedVariable):
for v in variable:
with ops.colocate_with(v.op):
with ops.name_scope(name + '/Regularizer'):
regularization = regularizer(v)
if regularization is not None:
self.add_loss(regularization)
else:
with ops.colocate_with(variable.op):
with ops.name_scope(name + '/Regularizer'):
regularization = regularizer(variable)
if regularization is not None:
self.add_loss(regularization)
elif regularizer: # initialization took place in an eager context
if isinstance(variable, tf_variables.PartitionedVariable):
raise RuntimeError(
'Partitioned variable regularization is not yet '
            'supported when executing eagerly. File a feature request '
'if this is important to you.')
# Save a zero-argument lambda which runs the regularizer on the
# variable, to be executed when `Layer.losses` is requested.
# This makes losses responsive to variable updates when executing
# eagerly.
#
# TODO(akshayka): Do the same for graphs as well, so that losses
# collected in a while_loop can be run outside its control flow
# context and so that losses won't be swallowed up by graph functions
# (i.e., `.losses()` should always create regularizers).
self._losses.append(lambda: regularizer(variable))
def _handle_activity_regularization(self, inputs, outputs):
# Apply activity regularization.
# Note that it should be applied every time the layer creates a new
# output, since it is output-specific.
if self._activity_regularizer:
output_list = nest.flatten(outputs)
for output in output_list:
with ops.name_scope('ActivityRegularizer'):
activity_regularization = self._activity_regularizer(output)
self.add_loss(activity_regularization, inputs=inputs)
def call(self, inputs, **kwargs): # pylint: disable=unused-argument
"""This is where the layer's logic lives.
Arguments:
inputs: Input tensor, or list/tuple of input tensors.
**kwargs: Additional keyword arguments.
Returns:
A tensor or list/tuple of tensors.
"""
return inputs
def __call__(self, inputs, *args, **kwargs):
"""Wraps `call`, applying pre- and post-processing steps.
Arguments:
inputs: input tensor(s).
*args: additional positional arguments to be passed to `self.call`.
**kwargs: additional keyword arguments to be passed to `self.call`.
Returns:
Output tensor(s).
Note:
- The following optional keyword arguments are reserved for specific uses:
* `training`: Boolean scalar tensor of Python boolean indicating
whether the `call` is meant for training or inference.
* `mask`: Boolean input mask.
- If the layer's `call` method takes a `mask` argument (as some Keras
layers do), its default value will be set to the mask generated
        for `inputs` by the previous layer (if `inputs` did come from
        a layer that generated a corresponding mask, i.e. if it came from
        a Keras layer with masking support).
Raises:
ValueError: if the layer's `call` method returns None (an invalid value).
"""
input_list = nest.flatten(inputs)
build_graph = not context.executing_eagerly()
# TODO(fchollet, allenl): Make deferred mode work with subclassed Models
# which don't use an "inputs" argument.
in_deferred_mode = isinstance(input_list[0], DeferredTensor)
# Handle Keras mask propagation from previous layer to current layer.
previous_mask = None
if build_graph and (not hasattr(self, '_compute_previous_mask') or
self._compute_previous_mask):
previous_mask = collect_previous_mask(inputs)
if not hasattr(self, '_call_fn_args'):
self._call_fn_args = self._no_dependency(
function_utils.fn_args(self.call))
if ('mask' in self._call_fn_args and 'mask' not in kwargs and
not generic_utils.is_all_none(previous_mask)):
      # The previous layer generated a mask, and mask was not explicitly passed
# to __call__, hence we set previous_mask as the default value.
kwargs['mask'] = previous_mask
input_shapes = None
with ops.name_scope(self._name_scope()):
if not self.built:
if not build_graph:
# Activity regularization is currently unsupported in Eager mode.
if self._activity_regularizer:
raise ValueError(
'activity_regularizer currently unsupported with '
'eager execution enabled. Found an activity_regularizer in '
'%s(%s).' % (self.__class__.__name__, self))
if not build_graph and not in_deferred_mode:
for x in input_list:
if hasattr(x, '_keras_history'):
raise ValueError('_keras_history currently unsupported in '
'Eager mode. Found _keras_history in %s while '
'executing __call__ for %s(%s)' %
                               (x, self.__class__.__name__, self))
# Check input assumptions set before layer building, e.g. input rank.
self._assert_input_compatibility(inputs)
if input_list and self._dtype is None:
try:
self._dtype = input_list[0].dtype.base_dtype.name
except AttributeError:
pass
if all(hasattr(x, 'shape') for x in input_list):
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
self.build(input_shapes)
self.built = True
# Check input assumptions set after layer building, e.g. input shape.
if build_graph or in_deferred_mode:
self._assert_input_compatibility(inputs)
if not in_deferred_mode:
outputs = self.call(inputs, *args, **kwargs)
if outputs is None:
raise ValueError('A layer\'s `call` method should return a Tensor '
'or a list of Tensors, not None (layer: ' +
self.name + ').')
else:
# Deferred mode behavior: use `compute_output_shape` to
# infer the number of outputs of the layer and their shapes.
if input_shapes is None:
input_shapes = nest.map_structure(lambda x: x.shape, inputs)
output_shapes = self.compute_output_shape(input_shapes)
output_shapes = nest.flatten(output_shapes)
outputs = [
# TODO(fchollet): name the deferred tensors?
DeferredTensor(shape=shape, dtype=self._dtype)
for shape in output_shapes
]
if len(outputs) == 1:
outputs = outputs[0]
if build_graph:
self._handle_activity_regularization(inputs, outputs)
# TODO(fchollet): consider enabling masking for Eager mode.
self._set_mask_metadata(inputs, outputs, previous_mask)
if in_deferred_mode or build_graph and have_all_keras_metadata(inputs):
inputs, outputs = self._set_connectivity_metadata_(
inputs, outputs, args, kwargs)
if context.executing_eagerly():
return outputs
if hasattr(self, '_symbolic_set_inputs') and not self.inputs:
# Subclassed network: explicitly set metadata normally set by a call to
# self._set_inputs(). This is not relevant in eager execution.
self._symbolic_set_inputs(inputs, outputs)
if in_deferred_mode or build_graph:
self._set_learning_phase_metadata(inputs, outputs)
# Optionally load weight values that were specified at layer instantiation.
# TODO(fchollet): consider enabling this with eager execution too.
if hasattr(self, '_initial_weights') and self._initial_weights is not None:
self.set_weights(self._initial_weights)
del self._initial_weights
self._post_build_cleanup()
return outputs
def _post_build_cleanup(self):
"""Hooks to run after all sub-Layers are built."""
# Note that in addition to Layer.__call__, this method is called by Model
# after building a graph network (which skips __call__). It should be called
# when possible if self.built may have switched from False to True, and is
# idempotent.
pass # No-op for Layers which don't override this method.
  def apply(self, inputs, *args, **kwargs):
    """Apply the layer on an input.
This simply wraps `self.__call__`.
Arguments:
inputs: Input tensor(s).
*args: additional positional arguments to be passed to `self.call`.
**kwargs: additional keyword arguments to be passed to `self.call`.
Returns:
Output tensor(s).
"""
return self.__call__(inputs, *args, **kwargs)
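  # Illustrative sketch (not part of the original source): `apply` and
  # `__call__` are interchangeable ways of running a layer on a tensor.
  # `x` is a hypothetical input tensor; `tf` is `import tensorflow as tf`.
  #
  #   dense = tf.keras.layers.Dense(10)
  #   y1 = dense(x)          # Via __call__.
  #   y2 = dense.apply(x)    # Equivalent.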
def _set_learning_phase_metadata(self, inputs, outputs):
# Update learning phase info. To work with subclassed models,
# this should be done even if Keras metadata is absent.
output_tensors = generic_utils.to_list(outputs)
uses_lp = any(
[getattr(x, '_uses_learning_phase', False)
for x in generic_utils.to_list(inputs)])
uses_lp = getattr(self, 'uses_learning_phase', False) or uses_lp
for i in range(len(output_tensors)):
try:
output_tensors[i]._uses_learning_phase = getattr(
output_tensors[i], '_uses_learning_phase', False) or uses_lp
except AttributeError:
# An output element happens to be a C type (such as tuple or dict).
# We don't track learning phase info in such edge cases.
pass
def _set_mask_metadata(self, inputs, outputs, previous_mask):
if hasattr(self, 'compute_mask'):
output_mask = self.compute_mask(inputs, previous_mask)
if isinstance(outputs, (list, tuple)):
if output_mask is None:
output_mask = [None for _ in range(len(outputs))]
for x, m in zip(outputs, output_mask):
try:
x._keras_mask = m # pylint: disable=protected-access
except AttributeError:
pass # C type such as dict. Masking not supported in this case.
else:
try:
outputs._keras_mask = output_mask # pylint: disable=protected-access
except AttributeError:
pass # C type such as dict. Masking not supported in this case.
def _set_connectivity_metadata_(self, inputs, outputs, args, kwargs):
call_convention = getattr(self, '_call_convention',
CallConvention.EXPLICIT_INPUTS_ARGUMENT)
if args:
if call_convention == CallConvention.EXPLICIT_INPUTS_ARGUMENT:
raise TypeError(
'This Layer takes an `inputs` argument to call(), and only the '
'`inputs` argument may be specified as a positional argument. '
'Pass everything else as a keyword argument (those arguments will'
' not be tracked as inputs to the Layer).')
elif call_convention == CallConvention.SINGLE_POSITIONAL_ARGUMENT:
raise TypeError(
'This Layer takes a single positional argument to call(), which is '
'by convention the inputs argument, and only this argument may be '
'specified as a positional argument. Pass everything else as a '
'keyword argument (those arguments will not be tracked as inputs '
'to the Layer).')
# If the layer returns tensors from its inputs, unmodified,
# we copy them to avoid loss of tensor metadata.
output_ls = nest.flatten(outputs)
output_ls_copy = []
for x in output_ls:
if x in nest.flatten(inputs):
with ops.name_scope(self.name):
x = array_ops.identity(x)
output_ls_copy.append(x)
if len(output_ls_copy) == 1:
outputs = output_ls_copy[0]
else:
outputs = output_ls_copy
inputs, kwargs = self._inputs_from_call_args(
call_args=(inputs,) + args, call_kwargs=kwargs)
# Add an inbound node to the layer, so it can keep track of this call.
# This updates the layer history of the output tensor(s).
kwargs.pop('mask', None) # `mask` should not be serialized.
self._add_inbound_node(
input_tensors=inputs, output_tensors=outputs, arguments=kwargs)
return inputs, outputs
def _inputs_from_call_args(self, call_args, call_kwargs):
"""Get Layer inputs from __call__ *args and **kwargs.
Args:
call_args: The positional arguments passed to __call__.
call_kwargs: The keyword argument dict passed to __call__.
Returns:
A tuple of (inputs, non_input_kwargs). These may be the same objects as
were passed in (call_args and call_kwargs).
"""
call_convention = getattr(self, '_call_convention',
CallConvention.EXPLICIT_INPUTS_ARGUMENT)
if (call_convention in (
CallConvention.EXPLICIT_INPUTS_ARGUMENT,
CallConvention.SINGLE_POSITIONAL_ARGUMENT)):
assert len(call_args) == 1 # TypeError raised earlier in __call__.
return call_args[0], call_kwargs
else:
call_arg_spec = tf_inspect.getargspec(self.call)
# There is no explicit "inputs" argument expected or provided to
# call(). Arguments which have default values are considered non-inputs,
# and arguments without are considered inputs.
if call_arg_spec.defaults:
if call_arg_spec.varargs is not None:
raise TypeError(
'Layer.call() may not accept both *args and arguments with '
'default values (unable to determine which are inputs to the '
'Layer).')
keyword_arg_names = set(
call_arg_spec.args[-len(call_arg_spec.defaults):])
else:
keyword_arg_names = set()
# Training is never an input argument name, to allow signatures like
# call(x, training).
keyword_arg_names.add('training')
_, unwrapped_call = tf_decorator.unwrap(self.call)
bound_args = inspect.getcallargs(
unwrapped_call, *call_args, **call_kwargs)
if call_arg_spec.keywords is not None:
var_kwargs = bound_args.pop(call_arg_spec.keywords)
bound_args.update(var_kwargs)
keyword_arg_names = keyword_arg_names.union(var_kwargs.keys())
all_args = call_arg_spec.args
if all_args and bound_args[all_args[0]] is self:
# Ignore the 'self' argument of methods
bound_args.pop(call_arg_spec.args[0])
all_args = all_args[1:]
non_input_arg_values = {}
input_arg_values = []
remaining_args_are_keyword = False
for argument_name in all_args:
if argument_name in keyword_arg_names:
remaining_args_are_keyword = True
else:
if remaining_args_are_keyword:
raise TypeError(
'Found a positional argument to call() after a non-input '
'argument. All arguments after "training" must be keyword '
'arguments, and are not tracked as inputs to the Layer.')
if remaining_args_are_keyword:
non_input_arg_values[argument_name] = bound_args[argument_name]
else:
input_arg_values.append(bound_args[argument_name])
if call_arg_spec.varargs is not None:
input_arg_values.extend(bound_args[call_arg_spec.varargs])
return input_arg_values, non_input_arg_values
def compute_output_shape(self, input_shape):
"""Computes the output shape of the layer.
Assumes that the layer will be built
    to match the input shape provided.
Arguments:
input_shape: Shape tuple (tuple of integers)
or list of shape tuples (one per output tensor of the layer).
Shape tuples can include None for free dimensions,
instead of an integer.
Returns:
      An output shape tuple.
"""
raise NotImplementedError
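  # Illustrative sketch (not part of the original source): a
  # `compute_output_shape` override for a dense-style layer with a
  # hypothetical `self.units`; `tf` is `import tensorflow as tf`.
  #
  #   def compute_output_shape(self, input_shape):
  #     input_shape = tf.TensorShape(input_shape).as_list()
  #     return tf.TensorShape(input_shape[:-1] + [self.units])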
def compute_mask(self, inputs, mask=None): # pylint: disable=unused-argument
"""Computes an output mask tensor.
Arguments:
inputs: Tensor or list of tensors.
mask: Tensor or list of tensors.
Returns:
None or a tensor (or list of tensors,
one per output tensor of the layer).
"""
if not self.supports_masking:
if mask is not None:
if isinstance(mask, list):
if any(m is not None for m in mask):
raise TypeError('Layer ' + self.name + ' does not support masking, '
'but was passed an input_mask: ' + str(mask))
else:
raise TypeError('Layer ' + self.name + ' does not support masking, '
'but was passed an input_mask: ' + str(mask))
# masking not explicitly supported: return None as mask
return None
# if masking is explicitly supported, by default
# carry over the input mask
return mask
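  # Illustrative sketch (not part of the original source): a layer that opts
  # into mask propagation by setting `supports_masking`; the default
  # `compute_mask` above then carries the incoming mask through unchanged.
  # The class name `Passthrough` is hypothetical; `tf` is
  # `import tensorflow as tf`.
  #
  #   class Passthrough(tf.keras.layers.Layer):
  #
  #     def __init__(self, **kwargs):
  #       super(Passthrough, self).__init__(**kwargs)
  #       self.supports_masking = True
  #
  #     def call(self, inputs):
  #       return inputs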
def _add_inbound_node(self,
input_tensors,
output_tensors,
arguments=None):
"""Internal method to create an inbound node for the layer.
Arguments:
input_tensors: list of input tensors.