# replay_buffer.py
import operator
import random
from collections import namedtuple

import numpy as np
import torch

device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")

class ReplayBuffer:
    """Fixed-size buffer that stores experience tuples for prioritized replay."""

    def __init__(self, action_size, buffer_size, batch_size,
                 experiences_per_sampling, seed, compute_weights):
        """Initialize a ReplayBuffer object.

        Params
        ======
            action_size (int): dimension of each action
            buffer_size (int): maximum size of buffer
            batch_size (int): size of each training batch
            experiences_per_sampling (int): number of experiences to sample
                during a sampling iteration
            seed (int): random seed
            compute_weights (bool): whether to compute importance-sampling weights
        """
        self.action_size = action_size
        self.buffer_size = buffer_size
        self.batch_size = batch_size
        self.experiences_per_sampling = experiences_per_sampling
        self.alpha = 0.5
        self.alpha_decay_rate = 0.99
        self.beta = 0.5
        self.beta_growth_rate = 1.001
        random.seed(seed)
        self.compute_weights = compute_weights
        self.experience_count = 0

        self.experience = namedtuple(
            "Experience",
            field_names=["state", "action", "reward", "next_state", "done"])
        self.data = namedtuple(
            "Data",
            field_names=["priority", "probability", "weight", "index"])

        # Every slot starts empty; its bookkeeping entry starts with zero
        # priority, probability, and weight.
        self.memory = {i: None for i in range(buffer_size)}
        self.memory_data = {i: self.data(0, 0, 0, i) for i in range(buffer_size)}
        self.sampled_batches = []
        self.current_batch = 0
        self.priorities_sum_alpha = 0
        self.priorities_max = 1
        self.weights_max = 1
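
    # Sampling follows proportional prioritized experience replay: an experience
    # with priority p_i is drawn with probability P(i) = p_i^alpha / sum_k p_k^alpha,
    # and, when compute_weights is set, corrected by an importance-sampling weight
    # w_i = (N * P(i))^(-beta), normalized by the largest weight seen so far.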
    def update_priorities(self, tds, indices):
        """Update priorities (and derived probabilities/weights) for sampled indices."""
        for td, index in zip(tds, indices):
            N = min(self.experience_count, self.buffer_size)
            updated_priority = td[0]
            if updated_priority > self.priorities_max:
                self.priorities_max = updated_priority

            old_priority = self.memory_data[index].priority
            self.priorities_sum_alpha += updated_priority**self.alpha - old_priority**self.alpha
            updated_probability = updated_priority**self.alpha / self.priorities_sum_alpha

            if self.compute_weights:
                # Weight is computed from the sampling probability, matching
                # update_parameters below.
                updated_weight = ((N * updated_probability)**(-self.beta)) / self.weights_max
                if updated_weight > self.weights_max:
                    self.weights_max = updated_weight
            else:
                updated_weight = 1

            self.memory_data[index] = self.data(
                updated_priority, updated_probability, updated_weight, index)

    def update_memory_sampling(self):
        """Pre-sample experiences_per_sampling experiences from memory and
        slice them into training batches."""
        self.current_batch = 0
        values = list(self.memory_data.values())
        # random.choices needs a sequence as its population, not the dict itself.
        random_values = random.choices(values,
                                       weights=[data.probability for data in values],
                                       k=self.experiences_per_sampling)
        self.sampled_batches = [random_values[i:i + self.batch_size]
                                for i in range(0, len(random_values), self.batch_size)]

    def update_parameters(self):
        """Anneal alpha and beta, then recompute every probability and weight."""
        self.alpha *= self.alpha_decay_rate
        self.beta *= self.beta_growth_rate
        if self.beta > 1:
            self.beta = 1
        N = min(self.experience_count, self.buffer_size)
        self.priorities_sum_alpha = 0
        sum_prob_before = 0
        for element in self.memory_data.values():
            sum_prob_before += element.probability
            self.priorities_sum_alpha += element.priority**self.alpha
        sum_prob_after = 0
        # Snapshot the values so entries can be safely replaced while iterating.
        for element in list(self.memory_data.values()):
            probability = element.priority**self.alpha / self.priorities_sum_alpha
            sum_prob_after += probability
            weight = 1
            # Guard against empty slots (probability 0), which would otherwise
            # raise a ZeroDivisionError in the negative power.
            if self.compute_weights and probability > 0:
                weight = ((N * probability)**(-self.beta)) / self.weights_max
            self.memory_data[element.index] = self.data(
                element.priority, probability, weight, element.index)
        print("sum_prob before", sum_prob_before)
        print("sum_prob after : ", sum_prob_after)

    def add(self, state, action, reward, next_state, done):
        """Add a new experience to memory."""
        self.experience_count += 1
        index = self.experience_count % self.buffer_size

        if self.experience_count > self.buffer_size:
            # The buffer is full: retire the bookkeeping for the slot being
            # overwritten and, if it held a running max, recompute that max.
            temp = self.memory_data[index]
            self.priorities_sum_alpha -= temp.priority**self.alpha
            if temp.priority == self.priorities_max:
                # namedtuples are immutable, so rebuild the entry via _replace.
                self.memory_data[index] = temp._replace(priority=0)
                self.priorities_max = max(
                    self.memory_data.values(),
                    key=operator.attrgetter("priority")).priority
            if self.compute_weights and temp.weight == self.weights_max:
                self.memory_data[index] = self.memory_data[index]._replace(weight=0)
                self.weights_max = max(
                    self.memory_data.values(),
                    key=operator.attrgetter("weight")).weight

        # New experiences enter with maximal priority so they are sampled at
        # least once before their TD error is known.
        priority = self.priorities_max
        weight = self.weights_max
        self.priorities_sum_alpha += priority ** self.alpha
        probability = priority ** self.alpha / self.priorities_sum_alpha
        self.memory[index] = self.experience(state, action, reward, next_state, done)
        self.memory_data[index] = self.data(priority, probability, weight, index)

    def sample(self):
        """Return the next pre-sampled batch as tensors on the training device."""
        sampled_batch = self.sampled_batches[self.current_batch]
        self.current_batch += 1
        experiences = []
        weights = []
        indices = []
        for data in sampled_batch:
            experiences.append(self.memory.get(data.index))
            weights.append(data.weight)
            indices.append(data.index)

        states = torch.from_numpy(
            np.vstack([e.state for e in experiences if e is not None])).float().to(device)
        actions = torch.from_numpy(
            np.vstack([e.action for e in experiences if e is not None])).long().to(device)
        rewards = torch.from_numpy(
            np.vstack([e.reward for e in experiences if e is not None])).float().to(device)
        next_states = torch.from_numpy(
            np.vstack([e.next_state for e in experiences if e is not None])).float().to(device)
        dones = torch.from_numpy(
            np.vstack([e.done for e in experiences if e is not None]).astype(np.uint8)).float().to(device)
        return (states, actions, rewards, next_states, dones, weights, indices)

    def __len__(self):
        """Return the number of experiences currently stored in memory."""
        return min(self.experience_count, self.buffer_size)
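
# Minimal usage sketch (illustrative, not part of the original module): it
# exercises the buffer end to end with random transitions. The state size (4),
# action size (2), and hyperparameters below are arbitrary assumptions chosen
# only to make the example self-contained.
if __name__ == "__main__":
    buffer = ReplayBuffer(action_size=2, buffer_size=100, batch_size=8,
                          experiences_per_sampling=32, seed=0,
                          compute_weights=False)

    # Fill the buffer with dummy transitions.
    for _ in range(50):
        state = np.random.randn(4).astype(np.float32)
        next_state = np.random.randn(4).astype(np.float32)
        buffer.add(state, random.randrange(2), 1.0, next_state, False)

    # Pre-sample a pool of experiences, then draw one training batch.
    buffer.update_memory_sampling()
    states, actions, rewards, next_states, dones, weights, indices = buffer.sample()
    print(states.shape, actions.shape, rewards.shape)

    # After a learning step, feed updated TD errors back as priorities.
    buffer.update_priorities([[0.5]] * len(indices), indices)
    buffer.update_parameters()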