-
Notifications
You must be signed in to change notification settings - Fork 0
/
plot.py
93 lines (68 loc) · 2.72 KB
/
plot.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
import matplotlib.pyplot as plt
import pickle
import numpy as np
import os
import argparse
def _load_regrets(log_dir):
    """Load every run file in *log_dir* and stack its regret curves.

    Each file is unpickled and must hold a dict whose ``'results'`` entry is
    an iterable of ``(iteration, regret_algO, regret_algP)`` triples.

    Args:
        log_dir: Directory containing one pickled run per file.

    Returns:
        ``(iters, runs_o, runs_p)``: the iteration values of the last file
        read, and two arrays with one row of regrets per run.

    Raises:
        ValueError: if *log_dir* holds no run files, or the runs have
            different lengths (they could not be averaged point-wise).
    """
    iters = []
    runs_o = []
    runs_p = []
    # sorted(): os.listdir order is arbitrary; keep the printed log stable.
    for name in sorted(os.listdir(log_dir)):
        path = os.path.join(log_dir, name)
        if not os.path.isfile(path):
            # Skip sub-directories instead of crashing on open().
            continue
        print(name)
        # NOTE(review): pickle.load can execute arbitrary code; only point
        # this script at log directories you produced yourself.
        with open(path, 'rb') as f:
            data = pickle.load(f)
        run_iters, run_o, run_p = [], [], []
        for k, r_o, r_p in data['results']:
            run_iters.append(k)
            run_o.append(r_o)
            run_p.append(r_p)
        iters = run_iters
        runs_o.append(run_o)
        runs_p.append(run_p)
        print(len(run_o))
    if not runs_o:
        raise ValueError('no run files found in {!r}'.format(log_dir))
    if len({len(run) for run in runs_o}) != 1:
        raise ValueError('runs in {!r} have different lengths'.format(log_dir))
    return iters, np.array(runs_o), np.array(runs_p)


def _plot_band(ax, iters, runs, color, label):
    """Plot the across-run mean of *runs* with a shaded +/- 2 std band."""
    mean = np.mean(runs, axis=0)
    std = np.std(runs, axis=0)
    ax.plot(iters, mean, color=color, label=label)
    ax.fill_between(iters, y1=mean - 2 * std, y2=mean + 2 * std, color=color, alpha=0.25)


def main():
    """Plot AlgO vs. AlgP regret, one panel per ``--log-dirs`` entry.

    Each panel shows the mean regret over all runs in that directory with a
    +/- 2 standard-deviation band. The figure is saved as ``Results.png`` in
    the current working directory.

    Raises:
        ValueError: if more log directories are given than there are panels,
            or a directory is empty or holds runs of unequal length.
    """
    args = get_parser()
    log_dirs = args.log_dirs
    # Panel titles, matched to --log-dirs by position. NOTE(review): each run
    # file also stores its own 'gap_min'; these hard-coded values are assumed
    # to agree with it.
    gap_mins = [0.0015, 0.003, 0.009]
    if len(log_dirs) > len(gap_mins):
        raise ValueError('at most {} log dirs are supported, got {}'.format(
            len(gap_mins), len(log_dirs)))
    fig, axs = plt.subplots(1, 3, figsize=(30, 8), sharey=True)
    last = len(log_dirs) - 1
    for index, log_dir in enumerate(log_dirs):
        ax = axs[index]
        # Paths are joined instead of os.chdir(), so an exception can never
        # leave the process in the wrong working directory.
        iters, all_R_algO, all_R_algP = _load_regrets(log_dir)
        _plot_band(ax, iters, all_R_algO, 'r', 'Regret of AlgO')
        _plot_band(ax, iters, all_R_algP, 'b', 'Regret of AlgP')
        ax.set_ylim([0, 15000])
        ax.set_xlim([0, 55000])
        ax.tick_params(axis='x', labelsize=30)
        ax.tick_params(axis='y', labelsize=30)
        # Pin tick positions before relabelling: set_*ticklabels alone
        # relabels whatever ticks the auto-locator picked, in order, so a
        # label such as '2e4' could end up on the wrong tick.
        ax.set_xticks([0, 20000, 40000])
        ax.set_xticklabels(['0', '2e4', '4e4'])
        ax.set_yticks([0, 2500, 5000, 7500, 10000, 12500, 15000])
        ax.set_yticklabels(['', '2.5e3', '5e3', '7.5e3', '1e4', '1.25e4', '1.5e4'])
        ax.set_title(r'$\Delta_{\min}$' + '={}'.format(gap_mins[index]), fontsize=30)
        if index == last:
            ax.legend(fontsize=30)
        if index == 0:
            ax.set_ylabel('Regret', fontsize=30)
        if index == 1:
            ax.set_xlabel('Iteration', fontsize=30)
    plt.savefig('Results.png', bbox_inches='tight')
    plt.close(fig)
def get_parser(argv=None):
    """Parse this script's command-line options.

    Despite its name, this returns the parsed ``argparse.Namespace``, not
    the parser; the name is kept so existing callers keep working.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, as before.

    Returns:
        Namespace whose ``log_dirs`` is a list of one or more directory paths.
    """
    parser = argparse.ArgumentParser()
    # required=True: without it, omitting the flag yields log_dirs=None and
    # main() later fails with an opaque TypeError when iterating it.
    parser.add_argument('--log-dirs', type=str, nargs='+', required=True,
                        help='dirs of logs to plot')
    return parser.parse_args(argv)
# Entry point: only build the plot when run as a script, never on import.
if __name__ == "__main__":
    main()