Skip to content

Commit

Permalink
Merge pull request #193 from keras-rl/docs
Browse files Browse the repository at this point in the history
add docs on policy.py
  • Loading branch information
RaphaelMeudec committed Apr 8, 2018
2 parents 8c6ce12 + b406621 commit fac1e61
Show file tree
Hide file tree
Showing 4 changed files with 359 additions and 9 deletions.
37 changes: 36 additions & 1 deletion rl/callbacks.py
Expand Up @@ -17,31 +17,39 @@ def _set_env(self, env):
self.env = env

def on_episode_begin(self, episode, logs={}):
    """Hook run when an episode starts; this implementation does nothing."""

def on_episode_end(self, episode, logs={}):
    """Hook run when an episode finishes; this implementation does nothing."""

def on_step_begin(self, step, logs={}):
    """Hook run when a step starts; this implementation does nothing."""

def on_step_end(self, step, logs={}):
    """Hook run when a step finishes; this implementation does nothing."""

def on_action_begin(self, action, logs={}):
    """Hook run when an action starts; this implementation does nothing."""

def on_action_end(self, action, logs={}):
    """Hook run when an action finishes; this implementation does nothing."""


class CallbackList(KerasCallbackList):
def _set_env(self, env):
    """Hand `env` to every callback that exposes a callable `_set_env`."""
    for cb in self.callbacks:
        # Callbacks without a callable `_set_env` attribute are left untouched.
        if not callable(getattr(cb, '_set_env', None)):
            continue
        cb._set_env(env)

def on_episode_begin(self, episode, logs={}):
""" Called at beginning of each episode for each callback in callbackList"""
for callback in self.callbacks:
# Check if callback supports the more appropriate `on_episode_begin` callback.
# If not, fall back to `on_epoch_begin` to be compatible with built-in Keras callbacks.
Expand All @@ -51,6 +59,7 @@ def on_episode_begin(self, episode, logs={}):
callback.on_epoch_begin(episode, logs=logs)

def on_episode_end(self, episode, logs={}):
""" Called at end of each episode for each callback in callbackList"""
for callback in self.callbacks:
# Check if callback supports the more appropriate `on_episode_end` callback.
# If not, fall back to `on_epoch_end` to be compatible with built-in Keras callbacks.
Expand All @@ -60,6 +69,7 @@ def on_episode_end(self, episode, logs={}):
callback.on_epoch_end(episode, logs=logs)

def on_step_begin(self, step, logs={}):
""" Called at beginning of each step for each callback in callbackList"""
for callback in self.callbacks:
# Check if callback supports the more appropriate `on_step_begin` callback.
# If not, fall back to `on_batch_begin` to be compatible with built-in Keras callbacks.
Expand All @@ -69,6 +79,7 @@ def on_step_begin(self, step, logs={}):
callback.on_batch_begin(step, logs=logs)

def on_step_end(self, step, logs={}):
""" Called at end of each step for each callback in callbackList"""
for callback in self.callbacks:
# Check if callback supports the more appropriate `on_step_end` callback.
# If not, fall back to `on_batch_end` to be compatible with built-in Keras callbacks.
Expand All @@ -78,21 +89,26 @@ def on_step_end(self, step, logs={}):
callback.on_batch_end(step, logs=logs)

def on_action_begin(self, action, logs=None):
    """Call `on_action_begin` on every callback that implements it.

    Callbacks without a callable `on_action_begin` attribute are skipped.

    `action` is passed through unchanged as the first argument.
    `logs` is forwarded to each callback as the `logs` keyword. When it is
    omitted, a fresh empty dict is created for this call, so a callback
    that mutates `logs` cannot leak state into later calls.
    """
    if logs is None:
        # A `{}` default would be one dict shared across every call.
        logs = {}
    for callback in self.callbacks:
        if callable(getattr(callback, 'on_action_begin', None)):
            callback.on_action_begin(action, logs=logs)

def on_action_end(self, action, logs=None):
    """Call `on_action_end` on every callback that implements it.

    Callbacks without a callable `on_action_end` attribute are skipped.

    `action` is passed through unchanged as the first argument.
    `logs` is forwarded to each callback as the `logs` keyword. When it is
    omitted, a fresh empty dict is created for this call, so a callback
    that mutates `logs` cannot leak state into later calls.
    """
    if logs is None:
        # A `{}` default would be one dict shared across every call.
        logs = {}
    for callback in self.callbacks:
        if callable(getattr(callback, 'on_action_end', None)):
            callback.on_action_end(action, logs=logs)


class TestLogger(Callback):
""" Logger Class for Test """
def on_train_begin(self, logs):
    """Announce how many episodes the run covers (`params['nb_episodes']`)."""
    nb_episodes = self.params['nb_episodes']
    print('Testing for {} episodes ...'.format(nb_episodes))

def on_episode_end(self, episode, logs):
""" Print logs at end of each episode """
template = 'Episode {0}: reward: {1:.3f}, steps: {2}'
variables = [
episode + 1,
Expand All @@ -115,22 +131,26 @@ def __init__(self):
self.step = 0

def on_train_begin(self, logs):
    """Record the start time and the model's metric names, then announce the step count."""
    self.train_start = timeit.default_timer()
    self.metrics_names = self.model.metrics_names
    nb_steps = self.params['nb_steps']
    print('Training for {} steps ...'.format(nb_steps))

def on_train_end(self, logs):
    """Print the time elapsed since `self.train_start` was recorded."""
    elapsed = timeit.default_timer() - self.train_start
    summary = 'done, took {:.3f} seconds'.format(elapsed)
    print(summary)

def on_episode_begin(self, episode, logs):
    """Record the episode's start time and create its empty per-episode lists."""
    self.episode_start[episode] = timeit.default_timer()
    # Every store gets its own fresh list under the episode key.
    for store in (self.observations, self.rewards, self.actions, self.metrics):
        store[episode] = []

def on_episode_end(self, episode, logs):
""" Compute and print training statistics of the episode when done """
duration = timeit.default_timer() - self.episode_start[episode]
episode_steps = len(self.observations[episode])

Expand Down Expand Up @@ -183,6 +203,7 @@ def on_episode_end(self, episode, logs):
del self.metrics[episode]

def on_step_end(self, step, logs):
""" Update statistics of episode after each step """
episode = logs['episode']
self.observations[episode].append(logs['observation'])
self.rewards[episode].append(logs['reward'])
Expand All @@ -198,6 +219,7 @@ def __init__(self, interval=10000):
self.reset()

def reset(self):
""" Reset statistics """
self.interval_start = timeit.default_timer()
self.progbar = Progbar(target=self.interval)
self.metrics = []
Expand All @@ -206,15 +228,18 @@ def reset(self):
self.episode_rewards = []

def on_train_begin(self, logs):
    """Record the start time and the model's metric names, then announce the step count."""
    self.train_start = timeit.default_timer()
    self.metrics_names = self.model.metrics_names
    banner = 'Training for {} steps ...'.format(self.params['nb_steps'])
    print(banner)

def on_train_end(self, logs):
    """Print the time elapsed since `self.train_start` was recorded."""
    finished_at = timeit.default_timer()
    print('done, took {:.3f} seconds'.format(finished_at - self.train_start))

def on_step_begin(self, step, logs):
""" Print metrics if interval is over """
if self.step % self.interval == 0:
if len(self.episode_rewards) > 0:
metrics = np.array(self.metrics)
Expand All @@ -240,6 +265,7 @@ def on_step_begin(self, step, logs):
print('Interval {} ({} steps performed)'.format(self.step // self.interval + 1, self.step))

def on_step_end(self, step, logs):
""" Update progression bar at the end of each step """
if self.info_names is None:
self.info_names = logs['info'].keys()
values = [('reward', logs['reward'])]
Expand All @@ -253,6 +279,7 @@ def on_step_end(self, step, logs):
self.infos.append([logs['info'][k] for k in self.info_names])

def on_episode_end(self, episode, logs):
    """Append `logs['episode_reward']` to the collected episode rewards."""
    collected = self.episode_rewards
    collected.append(logs['episode_reward'])


Expand All @@ -268,20 +295,24 @@ def __init__(self, filepath, interval=None):
self.data = {}

def on_train_begin(self, logs):
    """Store the model's metric names on this callback before training starts."""
    model = self.model
    self.metrics_names = model.metrics_names

def on_train_end(self, logs):
    """Write out the collected data via `save_data` once training is over."""
    # The return value of `save_data` is deliberately not propagated.
    self.save_data()

def on_episode_begin(self, episode, logs):
    """Create an empty metrics list and a start timestamp for a new episode.

    Asserts that `episode` is not already tracked in `self.metrics` or
    `self.starts`.
    """
    for registry in (self.metrics, self.starts):
        assert episode not in registry
    self.metrics[episode] = []
    self.starts[episode] = timeit.default_timer()

def on_episode_end(self, episode, logs):
""" Compute and print metrics at the end of each episode """
duration = timeit.default_timer() - self.starts[episode]

metrics = self.metrics[episode]
if np.isnan(metrics).all():
mean_metrics = np.array([np.nan for _ in self.metrics_names])
Expand All @@ -305,9 +336,11 @@ def on_episode_end(self, episode, logs):
del self.starts[episode]

def on_step_end(self, step, logs):
    """Append `logs['metrics']` to the list kept for episode `logs['episode']`."""
    episode_metrics = self.metrics[logs['episode']]
    episode_metrics.append(logs['metrics'])

def save_data(self):
""" Save metrics in a json file """
if len(self.data.keys()) == 0:
return

Expand All @@ -329,6 +362,7 @@ def save_data(self):

class Visualizer(Callback):
    """Callback that renders the environment after each action."""

    def on_action_end(self, action, logs):
        """Render `self.env` with `mode='human'` once an action has ended."""
        env = self.env
        env.render(mode='human')


Expand All @@ -341,6 +375,7 @@ def __init__(self, filepath, interval, verbose=0):
self.total_steps = 0

def on_step_end(self, step, logs={}):
""" Save weights at interval steps during training """
self.total_steps += 1
if self.total_steps % self.interval != 0:
# Nothing to do.
Expand Down

0 comments on commit fac1e61

Please sign in to comment.