# ABSE.py
import numpy as np
from Bins import BinABSE
from utilities import get_u
class ABSE:
    """Bandit policy that keeps an adaptive list of "live" bins over the context space.

    NOTE(review): the name presumably stands for "Adaptively Binned Successive
    Elimination" -- confirm against the project documentation.

    The policy starts with ``2 ** k_min`` equal-width bins on [0, 1] (a single
    bin, since ``k_min`` is 0). Each context is routed to the live bin that
    contains it, and that bin chooses the arm. After an observation, a bin is
    replaced by its children ("burst") once its ``tau`` has reached its ``l``,
    provided its depth is still below ``k_0`` and it has at least two active
    arms.
    """

    def __init__(self, T, beta, L, gamma, sigma=.5):
        """Set up the policy constants and the initial live bins.

        Args:
            T: problem parameter; also the upper end of the search range in
                ``_get_l`` and forwarded to every ``BinABSE``.
            beta: problem parameter used in the depth-dependent threshold and
                in the depth limit ``k_0``.
            L: problem parameter; ``c_0`` is set to ``2 * L``.
            gamma: tuning constant forwarded to every ``BinABSE``.
            sigma: accepted for interface compatibility; not used by this class.
        """
        # Problem parameters
        self.T = T
        self.beta = beta
        self.L = L

        # Constant used to tune the policy
        self.gamma = gamma

        # Constants derived from the problem parameters
        self.c_0 = 2 * L
        self.k_min = 0  # depth of the initial bins
        # A bin is only burst while its depth is strictly below k_0.
        self.k_0 = int(np.ceil(np.log2(T / 2 / np.log(2)) / (1 + 2 * beta)))

        # Initial set of live bins: 2 ** k_min bins of width 2 ** (-k_min).
        l = self._get_l(self.k_min)
        n_bins = 2 ** self.k_min
        width = 2 ** (-self.k_min)
        self.live_bins = [
            BinABSE(T=T,
                    depth=self.k_min,
                    a1=k * width,
                    a2=(k + 1) * width,
                    l=l,
                    gamma=self.gamma)
            for k in range(n_bins)
        ]

        # Edges of the live bins, kept index-aligned with self.live_bins.
        self.live_bins_edges = {
            'lower': [k * width for k in range(n_bins)],
            'higher': [(k + 1) * width for k in range(n_bins)],
        }

        # Bookkeeping for the most recent get_arm() call.
        self.prev_arm = np.nan
        self.prev_live_bin_idx = np.nan

    def _get_l(self, depth):
        """Get the maximum number of epochs in a bin at ``depth`` before it is burst.

        Searches the interval [1, T] by bisection for the point where
        ``get_u(l, T * 2 ** (-depth))`` equals
        ``2 * c_0 * 2 ** (-depth * beta)``, and returns the floor of that
        point. If the interval is empty or the target lies outside the values
        taken at its end points, the matching end point is returned unchanged.

        NOTE(review): the branch directions assume ``get_u`` is non-increasing
        in its first argument and free of side effects -- confirm in utilities.
        """
        eps = 1e-6
        rhs = 2 * self.c_0 * 2 ** (-depth * self.beta)
        lower = 1
        upper = self.T
        if upper < lower:
            return lower

        scale = self.T * 2 ** (-depth)  # second get_u argument, loop-invariant
        if get_u(upper, scale) >= rhs:
            return upper
        if get_u(lower, scale) <= rhs:
            return lower

        while True:
            mid = (upper + lower) / 2
            gap = get_u(mid, scale) - rhs
            # Written as "not >" so a NaN gap ends the search, as it always did.
            if not abs(gap) > eps:
                break
            # The interval can no longer shrink in floating point: without this
            # guard the loop would repeat the same midpoint forever.
            if mid == lower or mid == upper:
                break
            if gap < 0:
                upper = mid
            else:
                lower = mid
        return int(np.floor(mid))

    def _determine_live_bin(self, x):
        """Return the index of the live bin whose edges contain context ``x``.

        A context equal to an interior edge is assigned to the bin that has it
        as its lower edge.

        Raises:
            ValueError: if ``x`` is below the first lower edge or above the
                higher edge of the bin it would fall in.
        """
        live_bin_idx = np.searchsorted(self.live_bins_edges['lower'], x, side='right') - 1
        if (live_bin_idx >= 0) and (x <= self.live_bins_edges['higher'][live_bin_idx]):
            return live_bin_idx
        raise ValueError("context {!r} is not covered by any live bin".format(x))

    def get_arm(self, x):
        """Return the arm chosen by the live bin containing ``x``.

        The arm and the bin index are remembered so that the following
        ``collect_observation`` call updates the same bin.
        """
        live_bin_idx = self._determine_live_bin(x)
        arm = self.live_bins[live_bin_idx].get_arm()
        self.prev_arm = arm
        self.prev_live_bin_idx = live_bin_idx
        return arm

    def collect_observation(self, x, y, arm):
        """Forward an observation to the bin used by the last ``get_arm`` call.

        The bin is looked up through ``prev_live_bin_idx`` rather than ``x``,
        so ``get_arm`` must have been called first. Afterwards the bin is burst
        if its ``tau`` has reached ``l``, its depth is below ``k_0`` and it has
        at least two active arms.
        """
        bin_ = self.live_bins[self.prev_live_bin_idx]
        bin_.collect_observation(x, y, arm)
        if (bin_.tau >= bin_.l) and (bin_.depth < self.k_0) and (len(bin_.active_arms) >= 2):
            self._burst()

    def _burst(self):
        """Replace the most recently used live bin by its children, one level deeper."""
        idx = self.prev_live_bin_idx
        bin_ = self.live_bins[idx]
        children = bin_.get_children()
        l = self._get_l(bin_.depth + 1)

        # Each child is read as an (a1, a2) pair of bin edges.
        child_bins = [BinABSE(T=self.T,
                              depth=bin_.depth + 1,
                              a1=child[0],
                              a2=child[1],
                              l=l,
                              gamma=self.gamma) for child in children]
        child_lower = [child[0] for child in children]
        child_higher = [child[1] for child in children]

        # Splice the children in at the parent's position so the bins and both
        # edge lists stay index-aligned and sorted.
        self.live_bins = self.live_bins[:idx] + child_bins + self.live_bins[idx + 1:]
        self.live_bins_edges['lower'] = (self.live_bins_edges['lower'][:idx]
                                         + child_lower
                                         + self.live_bins_edges['lower'][idx + 1:])
        self.live_bins_edges['higher'] = (self.live_bins_edges['higher'][:idx]
                                          + child_higher
                                          + self.live_bins_edges['higher'][idx + 1:])