# lstm_online.py
import logging
from keras.optimizers import RMSprop
from keras.utils import plot_model
from pyfiction.agents.ssaqn_agent import SSAQNAgent
from pyfiction.simulators.games.transit_simulator import TransitSimulator
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)

"""
An example SSAQN agent for Transit that uses online learning and prioritized sampling
"""

# Build the agent on the Transit text-game simulator (description lengths use defaults)
agent = SSAQNAgent(train_simulators=TransitSimulator())

# Sample the game with a random policy so the agent can learn its token vocabulary
agent.initialize_tokens('vocabulary.txt')

# Assemble the SSAQN network: word embeddings -> LSTM -> dense, trained with RMSprop
agent.create_model(embedding_dimensions=16,
                   lstm_dimensions=32,
                   dense_dimensions=8,
                   optimizer=RMSprop(lr=0.001))

# Best-effort visualization of the network; pydot/graphviz may be missing
try:
    plot_model(agent.model, to_file='model.png', show_shapes=True)
except ImportError as e:
    logger.warning("Couldn't print the model image: {}".format(e))

# Online training: keep expanding the experience buffer while replaying batches of
# previously seen examples with prioritized sampling.
# This example seems to converge to the optimal reward of 19.X
for epoch in range(1):
    logger.info('Epoch %s', epoch)
    agent.train_online(episodes=256, batch_size=64, gamma=0.95, epsilon_decay=0.99, prioritized_fraction=0.25)