-
Notifications
You must be signed in to change notification settings - Fork 0
/
logger.py
95 lines (79 loc) · 2.64 KB
/
logger.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
import os
from options import args
import torch
from scipy import io
import time
def deletefile(filename):
    """Remove *filename*, ignoring OS-level failures such as the file not existing.

    This is a best-effort cleanup helper: any ``OSError`` raised by
    ``os.remove`` (missing file, permission problem, path is a directory, ...)
    is deliberately ignored. Other exceptions propagate. Examples are
    ``KeyboardInterrupt`` or a ``TypeError`` from passing something that is
    not a path, because silencing those would hide real bugs or block Ctrl-C.

    Args:
        filename: Path of the file to delete.
    """
    try:
        os.remove(filename)
    except OSError:
        # Caller only wants a clean slate; failing to delete is not an error here.
        pass
class MyLogger:
    """Collect metric series, run info, arguments and echoed text, and dump them to a .mat file.

    ``init(save_dir)`` must be called before any other method, because it creates
    ``self.record``, which every recording method reads. ``save()`` writes the
    record to ``<save_dir>.mat`` via ``scipy.io.savemat``.
    """

    def __init__(self):
        # Construction is intentionally side-effect free; all setup lives in init().
        pass

    def init(self, save_dir):
        """Prepare the logger for a run rooted at *save_dir*.

        Creates *save_dir* if it does not exist and resets the in-memory record.
        Deletes any existing ``<save_dir>.pt`` / ``<save_dir>.mat`` files, starts
        the timer, and snapshots the imported ``args`` object.

        Args:
            save_dir: Directory path (str). Output files are derived by
                appending suffixes to it, so they sit next to the directory.
        """
        self.save_dir = save_dir
        if not os.path.exists(self.save_dir):
            os.makedirs(self.save_dir)
        self.record = {
            "data": {},    # series name -> list of recorded values
            "step": {},    # series name -> list of steps, parallel to "data"
            "info": {},    # free-form metadata (addinfo, tik/tok timings)
            "args": {},    # stringified arguments and config values
            "output": "",  # text echoed through self.print
        }
        self.start_time = time.time()
        self.filename_py = save_dir + ".pt"
        self.filename_mat = save_dir + ".mat"
        deletefile(self.filename_py)
        deletefile(self.filename_mat)
        self.log_args()

    def log_args(self):
        """Copy every attribute of the imported ``args`` object into record["args"] as strings."""
        # NOTE(review): presumably an argparse Namespace from options.py — confirm.
        arg_dict = self.record["args"]
        for key, value in args.__dict__.items():
            arg_dict[key] = str(value)

    def log_config(self, config):
        """Merge *config* (any mapping with keys()/[]) into record["args"], stringifying values."""
        arg_dict = self.record["args"]
        for key in config.keys():
            arg_dict[key] = str(config[key])

    def add_record(self, key, value, step):
        """Append *value*, observed at *step*, to the series named *key*.

        Tensors are detached and moved to CPU and converted to NumPy, so that
        autograd state does not block ``.numpy()`` and the record stays
        serializable by ``scipy.io.savemat``. Other values are stored as given.
        """
        if isinstance(value, torch.Tensor):
            # .numpy() raises on tensors that require grad, hence detach() first.
            value = value.detach().cpu().numpy()
        self.record["data"].setdefault(key, []).append(value)
        self.record["step"].setdefault(key, []).append(step)

    def add_records(self, data_dict, step):
        """Record every ``key -> value`` pair of *data_dict* at the same *step*."""
        for key in data_dict:
            self.add_record(key, data_dict[key], step)

    def addinfo(self, key, value):
        """Store (or overwrite) a single metadata entry in record["info"]."""
        self.record["info"][key] = value

    def save(self):
        """Write the whole record, compressed, to ``<save_dir>.mat``."""
        # NOTE(review): filename_py ("<save_dir>.pt") is deleted in init() but never
        # written anywhere in this class — confirm whether a torch.save was intended.
        io.savemat(self.filename_mat, self.record, appendmat=False, do_compression=True)

    def print(self, s):
        """Echo *s* to stdout and append ``str(s) + "\\n"`` to record["output"]."""
        # Assign back into the record; rebinding a local copy would discard the text.
        self.record["output"] += str(s) + "\n"
        # Inside a method the bare name resolves to the builtin print, not this method.
        print(s)

    def tik(self):
        """Restart the run timer and record a ``MM/DD HH:MM`` start timestamp."""
        self.start_time = time.time()
        self.record["info"]["start_time"] = time.strftime("%m/%d %H:%M", time.localtime())

    def tok(self):
        """Record the end timestamp and elapsed whole hours/minutes since tik()/init()."""
        info_dict = self.record["info"]
        info_dict["end_time"] = time.strftime("%m/%d %H:%M", time.localtime())
        elapsed = time.time() - self.start_time
        # Integer arithmetic so the text reads "1 hour 2 minutes", not "1.0 hour 2.0 minutes".
        hours, remainder = divmod(int(elapsed), 3600)
        minutes = remainder // 60
        info_dict["total_time"] = f"consumed {hours} hour {minutes} minutes"
# Shared module-level instance. Call logger.init(save_dir) before recording:
# init() creates the `record` dict that the other methods read.
logger = MyLogger()