
Commit d441f7f

Add LSTM loop to exploration script
1 parent 4b14cfa commit d441f7f

File tree

1 file changed: +46 -0 lines changed


scripts/exploration.py

Lines changed: 46 additions & 0 deletions
@@ -104,3 +104,49 @@ def gru_rnn_loop(batch, gru: tf.keras.layers.GRU):
 
 # The final output should correspond to the output of the GRU layer.
 gru_rnn_loop(batch, gru)
+
+# %%
+lstm = tf.keras.layers.LSTM(2, use_bias=False)
+res = lstm(batch)
+print(f"LSTM kernel: {lstm.weights[0]}")
+print(f"LSTM recurrent kernel: {lstm.weights[1]}")
+print()
+print(f"LSTM output: {res}")
+
+# %%
+def lstm_rnn_loop(batch, lstm: tf.keras.layers.LSTM):
+    # Code is inspired by the LSTMCell call():
+    # https://github.com/tensorflow/tensorflow/blob/a4dfb8d1a71385bd6d122e4f27f86dcebb96712d/tensorflow/python/keras/layers/recurrent.py#L2414-L2472
+
+    # Previous and carry states
+    h, c = lstm.get_initial_state(batch)
+    kernel = lstm.weights[0]
+    recurrent_kernel = lstm.weights[1]
+
+    for timestep in range(batch.shape[1]):
+        inp = batch[:, timestep, :]
+
+        z = tf.keras.backend.dot(inp, kernel)
+        z += tf.keras.backend.dot(h, recurrent_kernel)
+
+        z = tf.split(z, num_or_size_splits=4, axis=1)
+
+        z0, z1, z2, z3 = z
+        i = lstm.recurrent_activation(z0)  # input gate
+        f = lstm.recurrent_activation(z1)  # forget gate
+        c = f * c + i * lstm.activation(z2)  # new carry state
+        o = lstm.recurrent_activation(z3)  # output gate
+
+        h = o * lstm.activation(c)  # new hidden state
+        o = h  # the layer's output is the final hidden state
+
+        print(f"Input timestep: {inp}")
+        print(f"Gates: z: {z}")
+        print(f"Next state: c: {c}, h: {h}")
+        print()
+
+    print(f"Final output: {o}")
+
+
+lstm_rnn_loop(batch, lstm)
+# %%
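For reference, here is a minimal standalone sketch (not part of the commit) of what the loop above computes. It assumes TF 2.x defaults (activation=tanh, recurrent_activation=sigmoid), use_bias=False, zero initial states, and a made-up input shape, since the shape of `batch` in the script is not shown here. Keras packs the four gates into the kernel as [input, forget, cell, output] along the last axis, so the manual loop's final hidden state should match the layer's output, just as in the GRU case above.

import numpy as np
import tensorflow as tf

# Hypothetical input of shape (batch, timesteps, features), chosen for illustration.
batch = tf.random.normal((3, 5, 4))
lstm = tf.keras.layers.LSTM(2, use_bias=False)
layer_out = lstm(batch)  # builds the weights and returns h at the last timestep

kernel, recurrent_kernel = lstm.weights     # (features, 4*units), (units, 4*units)
h = tf.zeros((batch.shape[0], lstm.units))  # hidden state
c = tf.zeros((batch.shape[0], lstm.units))  # carry state

for t in range(batch.shape[1]):
    x = batch[:, t, :]
    z = tf.matmul(x, kernel) + tf.matmul(h, recurrent_kernel)
    z_i, z_f, z_c, z_o = tf.split(z, num_or_size_splits=4, axis=1)
    i = tf.sigmoid(z_i)           # input gate
    f = tf.sigmoid(z_f)           # forget gate
    c = f * c + i * tf.tanh(z_c)  # new carry state
    o = tf.sigmoid(z_o)           # output gate
    h = o * tf.tanh(c)            # new hidden state

# The manual loop should reproduce the layer's output.
np.testing.assert_allclose(h.numpy(), layer_out.numpy(), atol=1e-5)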
