-
Notifications
You must be signed in to change notification settings - Fork 2
/
optim_alg.py
155 lines (137 loc) · 8.7 KB
/
optim_alg.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import torch
import mspbe
import warnings
import numpy as np
import pandas as pd
from torch.utils import data
class stoc_var_reduce_alg:
    """Base class for stochastic variance-reduced policy-evaluation algorithms.

    Holds the experiment configuration, loads pre-generated MDP data from
    ``saving_dir_path``, and provides the bookkeeping (MSPBE history,
    periodic check points, divergence checks) shared by concrete algorithms.
    Subclasses are expected to implement ``_run()``; ``run()`` is the public
    entry point.
    """
    def __init__(self, sigma_theta, sigma_omega, num_epoch, saving_dir_path=None, num_checks=20, use_gpu=True, batch_size=1, record_per_dataset_pass=False, batch_svrg_init_ratio=1, batch_svrg_increment_ratio=1, num_workers=0, grid_search=False, terminate_if_less_than_epsilon=False, policy_eval_epsilon=1e-2, rho_multiplier=0, inner_loop_multiplier=1, record_per_epoch=False, method=None, name=None, rho=0, rho_ac=0, rho_omega=0, record_before_one_pass=False, parsing_feature=False):
        """Store the experiment configuration.

        Each keyword maps to an attribute of the same name.  ``sigma_theta``
        and ``sigma_omega`` are step sizes; ``rho``/``rho_ac``/``rho_omega``
        are regularization coefficients; ``saving_dir_path`` is the directory
        holding the pre-generated MDP data files.
        """
        self.num_epoch = num_epoch
        self.check_pt_vals = []
        self.num_checks = num_checks
        self.saving_dir_path = saving_dir_path
        self.sigma_theta = sigma_theta
        self.sigma_omega = sigma_omega
        self.use_gpu = use_gpu
        self.batch_size = batch_size
        self.record_per_dataset_pass = record_per_dataset_pass
        self.batch_svrg_init_ratio = batch_svrg_init_ratio
        self.batch_svrg_increment_ratio = batch_svrg_increment_ratio
        self.num_workers = num_workers
        self.grid_search = grid_search
        self.terminate_if_less_than_epsilon = terminate_if_less_than_epsilon
        self.policy_eval_epsilon = policy_eval_epsilon
        self.rho_multiplier = rho_multiplier
        self.inner_loop_multiplier = inner_loop_multiplier
        self.record_per_epoch = record_per_epoch
        self.rho = rho
        self.rho_ac = rho_ac
        self.rho_omega = rho_omega
        self.record_before_one_pass = record_before_one_pass
        self.parsing_feature = parsing_feature
        # BUG FIX: `method` and `name` were accepted but never stored, yet
        # run() and end_of_exp() read self.name -- previously AttributeError.
        self.method = method
        self.name = name
    def run(self):
        """Execute the algorithm, converting failures into a result dict.

        Promotes warnings to errors so that numeric warnings (e.g. overflow)
        terminate the run; both warnings and ValueErrors are returned as a
        dict describing the failed configuration instead of propagating.
        """
        with warnings.catch_warnings():
            warnings.filterwarnings('error')
            try:
                return self._run()
            except Warning as e:
                return {'result':e.args[-1], 'sigma_theta': self.sigma_theta, 'sigma_omega': self.sigma_omega, 'name': self.name, 'record_per_dataset_pass':self.record_per_dataset_pass}
            except ValueError as error:
                return {'result': str(error), 'sigma_theta': self.sigma_theta, 'sigma_omega': self.sigma_omega, 'name': self.name, 'record_per_dataset_pass':self.record_per_dataset_pass}
    def load_mdp_data(self):
        """Load the MDP matrices (A, b, C, C_inv, Phi, transition data) from disk.

        Supports two on-disk layouts: individual ``.npy`` files, or a single
        ``data.hdf5`` file.  Also copies every column of ``mdp_info.pkl``
        onto ``self`` and pre-draws the per-epoch sample indices.
        """
        self.device = torch.device('cuda') if torch.cuda.is_available() and self.use_gpu else torch.device('cpu')
        mdp_info = pd.read_pickle(os.path.join(self.saving_dir_path, 'mdp_info.pkl'))
        # One-row DataFrame: each column becomes an attribute of the experiment.
        for attr in mdp_info:
            setattr(self, attr, mdp_info.at[0, attr])
        self.nFeatures = int(self.nFeatures)
        self.num_data = int(self.num_data)
        # Pre-draw the random sample index for every epoch up front.
        self.sample_seq = np.random.randint(low=0, high=self.num_data, size=self.num_epoch)
        if os.path.exists(os.path.join(self.saving_dir_path, 'A.npy')):
            self.A = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'A.npy')), dtype=torch.float64, device=self.device)
            self.b = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'b.npy')), dtype=torch.float64, device=self.device)
            self.C = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'C.npy')), dtype=torch.float64, device=self.device)
            self.C_inv = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'C_inv.npy')), dtype=torch.float64, device=self.device)
            self.Phi = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'Phi.npy')), dtype=torch.float64, device=self.device)
            self.trans_data = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'Trans_data.npy')), dtype=torch.float64, device=self.device)
        elif os.path.exists(os.path.join(self.saving_dir_path, 'data.hdf5')):
            # BUG FIX: h5py was used without ever being imported (NameError on
            # this path); import locally since it is only needed for this
            # storage format.  Also close the file deterministically -- the
            # original leaked the handle.
            import h5py
            with h5py.File(os.path.join(self.saving_dir_path, 'data.hdf5'), 'r') as f:
                self.A = torch.as_tensor(f.get('A')[0], dtype=torch.float64, device=self.device)
                self.b = torch.as_tensor(f.get('b')[0], dtype=torch.float64, device=self.device)
                self.C = torch.as_tensor(f.get('C')[0], dtype=torch.float64, device=self.device)
                self.C_inv = torch.as_tensor(f.get('C_inv')[0], dtype=torch.float64, device=self.device)
                self.Phi = torch.as_tensor(f.get('phi')[()], dtype=torch.float64, device=self.device)
                self.trans_data = torch.as_tensor(f.get('trans_data')[()], dtype=torch.float64, device=self.device)
    def init_alg(self):
        """Initialize iterates, check-point spacing and the MSPBE history."""
        # Reuse saved initial iterates when present, otherwise start at zero.
        if os.path.exists(os.path.join(self.saving_dir_path, 'init_theta.npy')) and os.path.exists(os.path.join(self.saving_dir_path, 'init_omega.npy')):
            self.theta = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'init_theta.npy')), dtype=torch.float64, device=self.device)
            self.omega = torch.as_tensor(np.load(os.path.join(self.saving_dir_path, 'init_omega.npy')), dtype=torch.float64, device=self.device)
        else:
            self.theta = torch.zeros([self.nFeatures], dtype=torch.float64, device=self.device)
            self.omega = torch.zeros([self.nFeatures], dtype=torch.float64, device=self.device)
        # num_checks == 0 disables periodic checking within one data pass.
        self.check_pt = self.num_data if self.num_checks == 0 else int(self.num_epoch / self.num_checks)
        if self.rho_multiplier > 0:
            # rho_multiplier, when set, overrides the rho passed to __init__.
            self.rho = self.rho_multiplier
            print(self.rho)
        if self.record_before_one_pass:
            self.record_points_before_one_pass = [0]
        self.mspbe_history = torch.unsqueeze(mspbe.calc_mspbe_torch(self, self.rho), 0)
        self.one_over_num_data = torch.tensor(1 / self.num_data, device=self.device)
    def record_value_before_one_pass(self):
        """Append the current MSPBE and gradient-evaluation count to the
        before-one-pass record."""
        mspbe_val = mspbe.calc_mspbe_torch(self, self.rho)
        self.mspbe_history = torch.cat((self.mspbe_history, torch.unsqueeze(mspbe_val, 0)))
        self.record_points_before_one_pass.append(self.num_grad_eval)
    def get_stoc_data(self, batch_A_t, batch_b_t, batch_C_t, batch_t_m, j):
        """Return the j-th sample (A_t, b_t, C_t, index) from a mini-batch."""
        return batch_A_t[j,:,:], batch_b_t[j,:], batch_C_t[j,:,:], batch_t_m[j]
    def check_complete_data_pass(self):
        """If a full pass over the data set has been consumed, record the
        MSPBE, run the periodic sanity check, and reset the counter."""
        if self.num_grad_eval >= self.num_data:
            mspbe_val = mspbe.calc_mspbe_torch(self, self.rho)
            if self.num_pass % self.check_pt == 0: self.check_values_torch(mspbe_val)
            self.mspbe_history = torch.cat((self.mspbe_history, torch.unsqueeze(mspbe_val, 0)))
            self.num_pass += 1
            if self.record_before_one_pass: self.record_points_before_one_pass.append(self.num_grad_eval)
            # Carry over the surplus evaluations into the next pass.
            self.num_grad_eval = self.num_grad_eval - self.num_data
    def check_values_torch(self, mspbe_history_i):
        """Raise ValueError when the iterates have diverged to NaN.

        ``mspbe_history_i`` is accepted for interface compatibility; only the
        iterates themselves are inspected.
        """
        if torch.isnan(self.theta).any() or torch.isnan(self.omega).any(): raise ValueError(
            'theta or omega become nan.')
    def handle_epoch_result(self, i, batch_j):
        """Record the MSPBE for batch slot ``batch_j``; sanity-check at check points."""
        # Compute the objective once and reuse it (original computed it twice).
        mspbe_val = mspbe.calc_mspbe_torch(self, self.rho)
        self.batch_result[batch_j] = mspbe_val
        if i % self.check_pt == 0: self.check_values_torch(float(mspbe_val))
    def end_of_exp(self):
        """Finalize the experiment: move results to CPU numpy and free buffers.

        With any regularization active the raw MSPBE history is returned;
        otherwise results are log10-transformed (final value only when grid
        searching, full history otherwise).
        """
        print(self.name + ' last mspbe is %.5f'%(float(mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy())))
        if (self.rho + self.rho_omega + self.rho_ac) > 0:
            self.result = self.mspbe_history.cpu().numpy()
        else:
            self.result = np.log10(mspbe.calc_mspbe_torch(self, self.rho).cpu().numpy()) if self.grid_search else np.log10(self.mspbe_history.cpu().numpy())
        self.theta = self.theta.cpu().numpy()
        self.omega = self.omega.cpu().numpy()
        self.delete_attrs()
    def delete_attrs(self):
        """Drop large / unpicklable members so the object can be returned cheaply."""
        del self.sample_seq
        if hasattr(self, 'data_generator'): del self.data_generator
        if hasattr(self, 'custom_sampler'): del self.custom_sampler
class mdp_dataset(data.Dataset):
    """Torch dataset view over an experiment's transition samples.

    Each item is the per-sample ``(A_t, b_t, C_t, index)`` tuple produced by
    ``mspbe.get_stoc_abc_torch``.
    """
    def __init__(self, exp):
        # Keep a reference to the owning experiment; it supplies num_data and
        # everything get_stoc_abc_torch needs.
        self.exp = exp
    def __len__(self):
        return self.exp.num_data
    def __getitem__(self, index):
        a_mat, b_vec, c_mat = mspbe.get_stoc_abc_torch(self.exp, index)
        return a_mat, b_vec, c_mat, index
class fixed_array_sampler(data.Sampler):
    """Sampler that replays a predetermined index sequence.

    Lets a DataLoader follow an exact, pre-drawn sampling order instead of
    drawing fresh random indices.
    """
    def __init__(self, sample_seq):
        # NOTE: super().__init__ is deliberately not called, matching the
        # rest of this module's usage.
        self.sample_seq = sample_seq
    def __iter__(self):
        yield from self.sample_seq
    def __len__(self):
        return len(self.sample_seq)