-
Notifications
You must be signed in to change notification settings - Fork 28
/
q_learning.py
57 lines (49 loc) · 2.1 KB
/
q_learning.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
from pathlib import Path
import numpy as np
import tensorflow as tf
from yarll.agents.agent import Agent
from yarll.environment.environment import Environment
from yarll.policies.e_greedy import EGreedy
from yarll.misc import summary_writer
class QLearning(Agent):
    """Tabular Q-learning agent for environments with discrete states and actions.

    Maintains a (n_states, n_actions) table of action values and improves it
    with one-step temporal-difference updates while acting epsilon-greedily.
    """

    def __init__(self, env: Environment, monitor_path: str, **usercfg) -> None:
        """Set up the Q-table, exploration policy, and TensorBoard writer.

        Args:
            env: Environment with discrete observation and action spaces.
            monitor_path: Directory where TensorBoard summaries are written.
            **usercfg: Overrides for the default hyperparameters
                (n_episodes, gamma, alpha, epsilon).
        """
        super().__init__()
        self.env = env
        self.monitor_path = Path(monitor_path)
        # Defaults first, then user-supplied overrides win.
        self.config.update({
            "n_episodes": 1000,
            "gamma": 0.99,
            "alpha": 0.5,
            "epsilon": 0.1,
        })
        self.config.update(usercfg)
        n_states = self.env.observation_space.n
        n_actions = self.env.action_space.n
        self.Q_values = np.zeros((n_states, n_actions), dtype=np.float32)
        self.policy = EGreedy(self.config["epsilon"])
        self.summary_writer = tf.summary.create_file_writer(str(self.monitor_path))
        summary_writer.set(self.summary_writer)

    def learn(self):
        """Run Q-learning for the configured number of episodes.

        After each episode, logs its total reward, the running episode count,
        and its length against the global step counter.
        """
        cfg = self.config
        gamma = cfg["gamma"]
        alpha = cfg["alpha"]
        step_count = 0
        summary_writer.start()
        for episode_idx in range(cfg["n_episodes"]):
            obs = self.env.reset()
            ep_return = 0
            ep_length = 0
            terminal = False
            while not terminal:
                action, current_q = self.policy(self.Q_values[obs])
                next_obs, reward, terminal, _ = self.env.step(action)
                ep_return += reward
                ep_length += 1
                step_count += 1
                # One-step TD target: bootstrap from the best action value
                # available in the successor state (off-policy max).
                td_target = reward + gamma * self.Q_values[next_obs].max()
                self.Q_values[obs, action] += alpha * (td_target - current_q)
                if terminal:
                    summary_writer.add_scalar("env/reward", ep_return, step_count)
                    summary_writer.add_scalar("env/N_episodes", episode_idx + 1, step_count)
                    summary_writer.add_scalar("env/episode_length", ep_length, step_count)
                    break
                obs = next_obs
        summary_writer.stop()