@@ -104,3 +104,49 @@ def gru_rnn_loop(batch, gru: tf.keras.layers.GRU):
 
 # The final output should correspond to the output of the GRU layer.
 gru_rnn_loop(batch, gru)
+
+# %%
+lstm = tf.keras.layers.LSTM(2, use_bias=False)
+res = lstm(batch)
+print(f"LSTM kernel: {lstm.weights[0]}")
+print(f"LSTM recurrent kernel: {lstm.weights[1]}")
+print()
+print(f"LSTM output: {res}")
+
+# %%
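+# Re-implement the LSTM forward pass step by step to expose the gate math.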
+def lstm_rnn_loop(batch, lstm: tf.keras.layers.LSTM):
+    # Code is inspired by the LSTMCell call():
+    # https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/keras/layers/recurrent.py#L2414-L2472
+
+    # Previous hidden state and carry (cell) state
+    h, c = lstm.get_initial_state(batch)
+    kernel = lstm.weights[0]
+    recurrent_kernel = lstm.weights[1]
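+    # Shape note (standard Keras LSTM layout): kernel is (input_dim, 4 * units)
+    # and recurrent_kernel is (units, 4 * units), so one matmul produces the
+    # pre-activations of all four gates at once.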
+
+    for timestep in range(batch.shape[1]):
+        inp = batch[:, timestep, :]
+
+        z = tf.keras.backend.dot(inp, kernel)
+        z += tf.keras.backend.dot(h, recurrent_kernel)
+
+        z = tf.split(z, num_or_size_splits=4, axis=1)
+
+        z0, z1, z2, z3 = z
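+        # Keras packs the gates in i, f, c, o order, so z0..z3 feed the
+        # input, forget, cell, and output gates respectively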
+        i = lstm.recurrent_activation(z0)  # input gate
+        f = lstm.recurrent_activation(z1)  # forget gate
+        c = f * c + i * lstm.activation(z2)  # new carry (cell) state
+        o = lstm.recurrent_activation(z3)  # output gate
+
+        h = o * lstm.activation(c)  # new hidden state
+        o = h  # the layer's output at the final step is the hidden state
+
+        print(f"Input timestep: {inp}")
+        print(f"Gates: z: {z}")
+        print(f"Next state: c: {c}, h: {h}")
+        print()
+
+    print(f"Final output: {o}")
+
+
+lstm_rnn_loop(batch, lstm)
+# %%
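+# Optional check (a sketch, not part of the original commit): if lstm_rnn_loop
+# were changed to `return h`, the manual loop could be compared to the layer,
+# e.g. np.testing.assert_allclose(lstm_rnn_loop(batch, lstm).numpy(),
+# res.numpy(), atol=1e-5).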