Various improvements for DQN, the replay memory, and CEM (#31)
- Fix various problems with sequential memory
- Fix various problems with CEMAgent
- Remove duplicate forward pass from DQN agent
- Remove performance bottleneck from TrainIntervalLogger
- Add tests for core and memory classes
- Add integration test for CEM
matthiasplappert committed Oct 17, 2016
1 parent befbcf4 commit d449c9e
Showing 18 changed files with 875 additions and 216 deletions.
6 changes: 5 additions & 1 deletion .travis.yml
@@ -12,11 +12,15 @@ matrix:
- python: 2.7
env: KERAS_BACKEND=tensorflow
- python: 2.7
env: KERAS_BACKEND=theano TEST_MODE=PEP8
env: KERAS_BACKEND=tensorflow TEST_MODE=PEP8
- python: 2.7
env: KERAS_BACKEND=theano TEST_MODE=INTEGRATION
- python: 3.4
env: KERAS_BACKEND=theano TEST_MODE=INTEGRATION
- python: 2.7
env: KERAS_BACKEND=tensorflow TEST_MODE=INTEGRATION
- python: 3.4
env: KERAS_BACKEND=tensorflow TEST_MODE=INTEGRATION
install:
# Adopted from https://github.com/fchollet/keras/blob/master/.travis.yml.
- if [[ "$TRAVIS_PYTHON_VERSION" == "2.7" ]]; then
2 changes: 1 addition & 1 deletion examples/cdqn_pendulum.py
@@ -62,7 +62,7 @@

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=100000)
memory = SequentialMemory(limit=100000, window_length=1)
random_process = OrnsteinUhlenbeckProcess(theta=.15, mu=0., sigma=.3, size=nb_actions)
agent = ContinuousDQNAgent(nb_actions=nb_actions, V_model=V_model, L_model=L_model, mu_model=mu_model,
memory=memory, nb_steps_warmup=100, random_process=random_process,
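Note: window_length now lives on the memory rather than on the agent, so every example constructs its replay buffer as SequentialMemory(limit=..., window_length=...). A minimal sketch of the new usage, assuming a keras-rl checkout that includes this commit (the limit value is illustrative):

from rl.memory import SequentialMemory

# The memory now owns the window: when sampling, it concatenates the last
# window_length observations into a single state.
memory = SequentialMemory(limit=100000, window_length=1)

# Experiences are still appended one observation at a time; assembling the
# windows is handled internally by the memory.
# memory.append(observation, action, reward, terminal, training=True)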
10 changes: 5 additions & 5 deletions examples/cem_cartpole.py
@@ -21,12 +21,14 @@

# Option 1 : Simple model
model = Sequential()
model.add(Dense(nb_actions,input_dim=obs_dim))
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(nb_actions))
model.add(Activation('softmax'))

# Option 2: deep network
# model = Sequential()
# model.add(Dense(16,input_dim=obs_dim))
# model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
# model.add(Dense(16))
# model.add(Activation('relu'))
# model.add(Dense(16))
# model.add(Activation('relu'))
@@ -41,7 +43,7 @@

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = EpisodeParameterMemory(limit=1000,max_episode_steps=200)
memory = EpisodeParameterMemory(limit=1000, window_length=1)

cem = CEMAgent(model=model, nb_actions=nb_actions, memory=memory,
batch_size=50, nb_steps_warmup=2000, train_interval=50, elite_frac=0.05)
@@ -53,8 +55,6 @@
cem.fit(env, nb_steps=100000, visualize=False, verbose=2)

# After training is done, we save the best weights.
print("highest reward total seen : {0}".format(cem.best_seen[0]))
cem.model.set_weights(cem.get_weights_list(cem.best_seen[1]))
cem.save_weights('cem_{}_params.h5f'.format(ENV_NAME), overwrite=True)

# Finally, evaluate our algorithm for 5 episodes.
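Taken together, the updated example now looks roughly like the condensed sketch below (seeds, layer sizes, and step counts are taken from the example or are illustrative; this is not the full script):

import gym
import numpy as np
from keras.models import Sequential
from keras.layers import Dense, Activation, Flatten

from rl.agents.cem import CEMAgent
from rl.memory import EpisodeParameterMemory

ENV_NAME = 'CartPole-v0'
env = gym.make(ENV_NAME)
np.random.seed(123)
env.seed(123)
nb_actions = env.action_space.n

# The model starts with a Flatten layer so the (window_length, obs_dim) states
# produced by the memory become a flat feature vector.
model = Sequential()
model.add(Flatten(input_shape=(1,) + env.observation_space.shape))
model.add(Dense(nb_actions))
model.add(Activation('softmax'))

# EpisodeParameterMemory now takes window_length instead of max_episode_steps.
memory = EpisodeParameterMemory(limit=1000, window_length=1)

cem = CEMAgent(model=model, nb_actions=nb_actions, memory=memory,
               batch_size=50, nb_steps_warmup=2000, train_interval=50, elite_frac=0.05)
cem.compile()
cem.fit(env, nb_steps=100000, visualize=False, verbose=2)

# The agent itself now restores the best weights seen during training when
# training ends, so the old best_seen bookkeeping is gone from the example.
cem.save_weights('cem_{}_params.h5f'.format(ENV_NAME), overwrite=True)
cem.test(env, nb_episodes=5, visualize=False)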
2 changes: 1 addition & 1 deletion examples/ddpg_pendulum.py
@@ -51,7 +51,7 @@

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=100000)
memory = SequentialMemory(limit=100000, window_length=1)
random_process = OrnsteinUhlenbeckProcess(theta=.15, mu=0., sigma=.3)
agent = DDPGAgent(nb_actions=nb_actions, actor=actor, critic=critic, critic_action_input=action_input,
memory=memory, nb_steps_warmup_critic=100, nb_steps_warmup_actor=100,
4 changes: 2 additions & 2 deletions examples/dqn_atari.py
@@ -92,7 +92,7 @@ def _step(a):

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=1000000)
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
processor = AtariProcessor()

# Select a policy. We use eps-greedy action selection, which means that a random action is selected
@@ -109,7 +109,7 @@ def _step(a):
# policy = BoltzmannQPolicy(tau=1.)
# Feel free to give it a try!

dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, window_length=WINDOW_LENGTH, memory=memory,
dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
processor=processor, nb_steps_warmup=50000, gamma=.99, delta_range=(-1., 1.),
target_model_update=10000, train_interval=4)
dqn.compile(Adam(lr=.00025), metrics=['mae'])
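Because window_length has moved from DQNAgent to the memory, the frame-stacking depth is now declared exactly once. A sketch of the relevant wiring, reusing the constants and classes from this example; the model and AtariProcessor construction are elided, so the agent line is left commented out:

from rl.memory import SequentialMemory
from rl.policy import LinearAnnealedPolicy, EpsGreedyQPolicy

WINDOW_LENGTH = 4              # number of consecutive frames stacked into one state
INPUT_SHAPE = (84, 84)

# The memory stacks the last WINDOW_LENGTH observations itself ...
memory = SequentialMemory(limit=1000000, window_length=WINDOW_LENGTH)
policy = LinearAnnealedPolicy(EpsGreedyQPolicy(), attr='eps', value_max=1., value_min=.1,
                              value_test=.05, nb_steps=1000000)

# ... so the model must accept inputs of shape (WINDOW_LENGTH,) + INPUT_SHAPE and
# the agent no longer takes a window_length argument:
# dqn = DQNAgent(model=model, nb_actions=nb_actions, policy=policy, memory=memory,
#                processor=processor, nb_steps_warmup=50000, gamma=.99,
#                delta_range=(-1., 1.), target_model_update=10000, train_interval=4)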
2 changes: 1 addition & 1 deletion examples/dqn_cartpole.py
@@ -34,7 +34,7 @@

# Finally, we configure and compile our agent. You can use every built-in Keras optimizer and
# even the metrics!
memory = SequentialMemory(limit=50000)
memory = SequentialMemory(limit=50000, window_length=1)
policy = BoltzmannQPolicy()
dqn = DQNAgent(model=model, nb_actions=nb_actions, memory=memory, nb_steps_warmup=10,
target_model_update=1e-2, policy=policy)
90 changes: 45 additions & 45 deletions rl/agents/cem.py
@@ -10,21 +10,19 @@
from rl.util import *

class CEMAgent(Agent):
def __init__(self, model, nb_actions, memory, window_length=1,
batch_size=50, nb_steps_warmup=1000, train_interval=50,
elite_frac=0.05, memory_interval=1, theta_init=None,noise_decay_const=0.0,noise_ampl=0.0):

def __init__(self, model, nb_actions, memory, batch_size=50, nb_steps_warmup=1000,
train_interval=50, elite_frac=0.05, memory_interval=1, theta_init=None,
noise_decay_const=0.0, noise_ampl=0.0, processor=None):
super(CEMAgent, self).__init__()

# Parameters.
self.nb_actions = nb_actions
self.batch_size = batch_size
self.elite_frac = elite_frac
self.num_best = int(self.batch_size*self.elite_frac)
self.num_best = int(self.batch_size * self.elite_frac)
self.nb_steps_warmup = nb_steps_warmup
self.train_interval = train_interval
self.memory_interval = memory_interval
self.window_length = window_length

# if using noisy CEM, the minimum standard deviation will be ampl * exp (- decay_const * step )
self.noise_decay_const = noise_decay_const
@@ -33,23 +31,23 @@ def __init__(self, model, nb_actions, memory, window_length=1,
# default initial mean & cov, override this by passing an theta_init argument
self.init_mean = 0.0
self.init_stdev = 1.0

self.episode=0

# Related objects.
self.memory = memory

self.model = model
self.processor = processor
self.shapes = [w.shape for w in model.get_weights()]
self.sizes = [w.size for w in model.get_weights()]
self.num_weights = sum(self.sizes)

# store the best result seen during training, as a tuple (reward, flat_weights)
self.best_seen = (None,np.zeros(self.num_weights))
self.best_seen = (-np.inf, np.zeros(self.num_weights))

self.theta = np.zeros(self.num_weights*2)
self.update_theta(theta_init)

# State.
self.episode = 0
self.compiled = False
self.reset_states()

@@ -66,32 +64,33 @@ def save_weights(self, filepath, overwrite=False):
def get_weights_flat(self,weights):
weights_flat = np.zeros(self.num_weights)

pos=0
pos = 0
for i_layer, size in enumerate(self.sizes):
weights_flat[pos:pos+size] = weights[i_layer].flatten()
pos += size

return weights_flat

def get_weights_list(self,weights_flat):
weights = []
pos=0
pos = 0
for i_layer, size in enumerate(self.sizes):
arr = weights_flat[pos:pos+size].reshape(self.shapes[i_layer])
weights.append(arr)
pos += size
return weights
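
get_weights_flat and get_weights_list are inverses: the first concatenates every layer's weight array into one flat vector (the representation CEM samples and optimises over), the second reshapes such a vector back into the list of arrays Keras expects. The same round trip as standalone numpy, with a hypothetical two-layer shape list:

import numpy as np

shapes = [(4, 16), (16,)]                       # hypothetical layer shapes
sizes = [int(np.prod(s)) for s in shapes]
num_weights = sum(sizes)

def flatten_weights(weights):
    flat = np.zeros(num_weights)
    pos = 0
    for i, size in enumerate(sizes):
        flat[pos:pos + size] = weights[i].flatten()
        pos += size
    return flat

def unflatten_weights(flat):
    weights, pos = [], 0
    for i, size in enumerate(sizes):
        weights.append(flat[pos:pos + size].reshape(shapes[i]))
        pos += size
    return weights

original = [np.random.randn(*s) for s in shapes]
restored = unflatten_weights(flatten_weights(original))
assert all(np.allclose(a, b) for a, b in zip(original, restored))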

def reset_states(self):
self.recent_observation = None
self.recent_action = None
self.recent_observations = deque(maxlen=self.window_length)
self.recent_params = deque(maxlen=self.window_length)

def select_action(self,state,stochastic=False):
batch = state.copy()
def select_action(self, state, stochastic=False):
batch = np.array([state])
if self.processor is not None:
batch = self.processor.process_state_batch(batch)

action = self.model.predict_on_batch(batch).flatten()
if (stochastic or self.training):
return np.random.choice(np.arange(self.nb_actions),p=action/np.sum(action))
if stochastic or self.training:
return np.random.choice(np.arange(self.nb_actions), p=np.exp(action) / np.sum(np.exp(action)))
return np.argmax(action)
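
During training (or with stochastic=True) the agent samples an action from a softmax over the model output; at test time it takes the argmax. A tiny numpy illustration of the two branches, with made-up scores for three actions:

import numpy as np

action_scores = np.array([0.2, 0.5, 0.3])            # made-up model output
probs = np.exp(action_scores) / np.sum(np.exp(action_scores))

sampled = np.random.choice(np.arange(3), p=probs)     # training / stochastic branch
greedy = np.argmax(action_scores)                      # evaluation branch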

def update_theta(self,theta):
@@ -114,60 +113,61 @@ def choose_weights(self):
self.model.set_weights(sampled_weights)

def forward(self, observation):
# Select an action.
if self.processor is not None:
observation = self.processor.process_observation(observation)

while len(self.recent_observations) < self.recent_observations.maxlen:
# Not enough data, fill the recent_observations queue with copies of the current input.
# This allows us to immediately perform a policy action instead of falling back to random
# actions.
self.recent_observations.append(np.copy(observation))
state = np.array(list(self.recent_observations)[1:] + [observation])
# Select an action.
state = self.memory.get_recent_state(observation)
action = self.select_action(state)
if self.processor is not None:
action = self.processor.process_action(action)

# Book-keeping.
self.recent_observations.append(observation)
self.recent_observation = observation
self.recent_action = action
self.recent_params.append(self.get_weights_flat(self.model.get_weights()))

return action
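
CEMAgent now supports the same processor hook as the other agents, so observations, rewards, actions, and state batches can be transformed before they reach the model. A minimal sketch of such a processor, assuming the rl.core.Processor base class and the hook names called above (the clipping values are purely illustrative):

import numpy as np
from rl.core import Processor

class ClippingProcessor(Processor):
    def process_observation(self, observation):
        # Bound raw observations before they are stored and fed to the model.
        return np.clip(observation, -10., 10.)

    def process_reward(self, reward):
        return np.clip(reward, -1., 1.)

    def process_action(self, action):
        return action

    def process_state_batch(self, batch):
        return np.asarray(batch).astype('float32')

# cem = CEMAgent(model=model, nb_actions=nb_actions, memory=memory,
#                processor=ClippingProcessor())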

def backward(self, reward, terminal):
# Store most recent experience in memory.
if self.processor is not None:
reward = self.processor.process_reward(reward)
if self.step % self.memory_interval == 0:
self.memory.append(self.recent_observation, self.recent_action, reward, terminal,
training=self.training)

metrics = [np.nan for _ in self.metrics_names]
if not self.training:
# We're done here. No need to update the experience memory since we only use the working
# memory to obtain the state over the most recent observations.
return metrics

# Store most recent experience in memory.
self.memory.append(reward)
if terminal:
params = self.get_weights_flat(self.model.get_weights())
self.memory.finalize_episode(params)

if (terminal):

self.memory.finalise_episode(self.recent_params[-1])
self.episode += 1

if (self.step > self.nb_steps_warmup and self.episode % self.train_interval == 0):
if self.step > self.nb_steps_warmup and self.episode % self.train_interval == 0:
params, reward_totals = self.memory.sample(self.batch_size)
best_idx = np.argsort(np.array(reward_totals))[-self.num_best:]
best = np.vstack([params[i] for i in best_idx])

if (reward_totals[best_idx[-1]] > self.best_seen[0]):
self.best_seen = (reward_totals[best_idx[-1]],params[best_idx[-1]])
if reward_totals[best_idx[-1]] > self.best_seen[0]:
self.best_seen = (reward_totals[best_idx[-1]], params[best_idx[-1]])

metrics = [np.mean(np.array(reward_totals)[best_idx])]

min_std = self.noise_ampl * np.exp(-self.step*self.noise_decay_const)
min_std = self.noise_ampl * np.exp(-self.step * self.noise_decay_const)

mean = np.mean(best,axis=0)
std = np.std(best,axis=0) + min_std
new_theta = np.hstack((mean,std))
mean = np.mean(best, axis=0)
std = np.std(best, axis=0) + min_std
new_theta = np.hstack((mean, std))
self.update_theta(new_theta)

self.choose_weights()

self.episode += 1
return metrics
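
The training step is standard cross-entropy-method bookkeeping: sample (params, reward_totals) pairs from memory, keep the elite fraction with the highest returns, and refit a diagonal Gaussian to them plus a decaying exploration noise floor. The same update written as standalone numpy with made-up numbers:

import numpy as np

batch_size, num_weights, elite_frac = 50, 10, 0.05
num_best = int(batch_size * elite_frac)

params = np.random.randn(batch_size, num_weights)      # sampled flat weight vectors
reward_totals = np.random.randn(batch_size)            # matching episode returns (made up)

best_idx = np.argsort(reward_totals)[-num_best:]        # indices of the elite set
best = params[best_idx]

noise_ampl, noise_decay_const, step = 0.0, 0.0, 1000
min_std = noise_ampl * np.exp(-step * noise_decay_const)  # decaying noise floor

mean = np.mean(best, axis=0)
std = np.std(best, axis=0) + min_std
new_theta = np.hstack((mean, std))                       # parameters of the next sampling distribution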

def _on_train_end(self):
self.model.set_weights(self.get_weights_list(self.best_seen[1]))

@property
def metrics_names(self):
return ['mean_best_reward']
