Standard layers (like tf.layers.dense) fail to report their trainable weights when used from within a custom layer #8253
System information
Describe the current behavior
When building custom layers, it is often useful to use "standard" layer types like tf.layers.dense and tf.layers.LSTM from inside that layer. However, layers added in this way have 2 major problems:

1. Their trainable weights are not reported by model.summary().
2. Their weights are not saved by model.save().

This is problematic for obvious reasons. The alternative is to use the this.addWeight() API; however, weights added in this way also have problems: this.addWeight() cannot use string activations, like mish and swish.

If there is already a supported way to integrate the weights from a standard layer like tf.layers.dense from within a custom model, the method is not clear from any of the documentation I've seen.

Describe the expected behavior
I would expect weights used by the computational graph to be included in model.summary()'s "trainable parameters" report. But they are not:

Layer (type)                Input Shape                  Output shape      Param #   Receives inputs
===================================================================================================
inp-t0B (InputLayer)        [[null,null]]                [null,null]       0
emb-gza (SharedEmbedding)   [[null,null]],[[null,null,2  multiple          5091328   inp-t0B[0][0]
                                                                                     mlp-adG[0][0]
enc-RC2 (SinusoidalPositio  [[null,null,256]]            [null,null,256]   0         emb-gza[0][0]
attn-FBz (SelfAttention)    [[null,null,256]]            [null,null,256]   0         enc-RC2[0][0]
mlp-3kL (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-FBz[0][0]
attn-VZK (SelfAttention)    [[null,null,256]]            [null,null,256]   0         mlp-3kL[0][0]
mlp-Jfy (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-VZK[0][0]
attn-j0b (SelfAttention)    [[null,null,256]]            [null,null,256]   0         mlp-Jfy[0][0]
mlp-oyK (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-j0b[0][0]
attn-L1y (SelfAttention)    [[null,null,256]]            [null,null,256]   0         mlp-oyK[0][0]
mlp-9r1 (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-L1y[0][0]
attn-Yha (SelfAttention)    [[null,null,256]]            [null,null,256]   0         mlp-9r1[0][0]
mlp-GV8 (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-Yha[0][0]
attn-R5D (SelfAttention)    [[null,null,256]]            [null,null,256]   0         mlp-GV8[0][0]
mlp-adG (MultiLayerPercept  [[null,null,256]]            [null,null,256]   0         attn-R5D[0][0]
===================================================================================================
Total params: 5091328
Trainable params: 5091328
Non-trainable params: 0
Standalone code to reproduce the issue
Add the following custom layer to any model, then call model.compile(), then model.summary(). You will see that it reports 0 trainable parameters:

Other info / logs
If there is a supported way to add the trainable parameters from tf.layers.dense() to my custom layer, please let me know!