/
train.py
325 lines (300 loc) · 17.8 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
import os
import time
from tqdm import tqdm
import torch
import torchvision
from torch import nn, optim
from torch.utils.data import DataLoader
from tensorboardX import SummaryWriter
from dataset.config import BF_CONFIG, BF_ACTION_CLASS
from model.model import Anticipation_With_Backbone, Anticipation_Without_Backbone
from dataset.breakfast_dataset import BreakfastDataset, collate_fn_with_backbone, collate_fn_without_backbone
import utils.io as io
from model.utils import Mulit_Loss, WarmUpOptimizer
import argparse
# os.environ['CUDA_VISIBLE_DEVICES'] = '1'
def arg_parse():
    """Parse and return the command-line arguments for training.

    Notes:
        * ``--use_dec`` uses ``action="store_false"``: the decoder is ON by
          default and passing the flag DISABLES it.  ``use_dec`` must equal
          ``(task == "recog_anti")`` — this is asserted by both training
          entry points.
        * Fixes over the previous revision (help text only, no behavior
          change): the ``--bs`` help claimed ``default=1`` while the real
          default is 2, and the ``--log_dir`` help contained ``loss\\accuracy``
          where ``\\a`` is an ASCII BEL escape that garbled the help output.

    Returns:
        argparse.Namespace: the parsed arguments.
    """
    parser = argparse.ArgumentParser(description="Anticipation Training.")
    # For model
    parser.add_argument('--use_dec', action="store_false",
                        help='use decoder or not: action="store_false')
    # For dataset
    parser.add_argument('--split_idx', type=int, default=0, choices=[0, 1, 2, 3],
                        help='dataset splited configuration: default=0')
    parser.add_argument('--task', type=str, default='recog_anti', choices=["recog_only", "recog_anti"],
                        help="which task do you want to conduct [recog_only or recog_anti]")
    parser.add_argument('--anti_feat', action="store_true",
                        help="return anticipation features or not: action='store_true'")
    parser.add_argument('--feat_type', type=str, default='offline', choices=["offline", "online"],
                        help="which type of feature do you want to use [offline or online]")
    # For training
    parser.add_argument('--ds', '--dataset', type=str, default='breakfast',
                        help='The dataset you want to train: default=breakfast')
    parser.add_argument('--nw', '--num_workers', type=int, default=0,
                        help='Number of workers used in dataloading: default=0')
    parser.add_argument('--bs', '--batch_size', type=int, default=2,
                        help='the size of minibatch: default=2')
    parser.add_argument('--optim', type=str, default='adam',
                        help='which optimizer to be used: default=adam')
    parser.add_argument('--lr', type=float, default=0.0001,
                        help='learning rate: default=0.0001')
    parser.add_argument('--wd', type=float, default=0.0,
                        help='weight decay: default=0.0')
    parser.add_argument('--warmup', action="store_true",
                        help='warmup stratery: action="store_true"')
    parser.add_argument('--e_epoch', type=int, default=300,
                        help='Number of training epoch: default=300')
    parser.add_argument('--s_epoch', type=int, default=0,
                        help='number of beginning epochs : 0')
    # For logging or saving
    parser.add_argument('--log_dir', type=str, default='./log',
                        help='path to save the log data like loss/accuracy... : ./log')
    parser.add_argument('--exp_ver', '--e_v', type=str, default='v1',
                        help='the version of code, will create subdir in log/ && checkpoints/ ')
    parser.add_argument('--save_dir', type=str, default='./checkpoints',
                        help='path to save the checkpoints: ./checkpoints')
    parser.add_argument('--print_every', type=int, default=10,
                        help='number of steps for printing training and validation loss: 10')
    parser.add_argument('--save_every', type=int, default=20,
                        help='number of steps for saving the model parameters: 20')
    return parser.parse_args()
def train_model_recog_anti() -> None:
    """Train and validate the joint recognition + anticipation model.

    All hyper-parameters come from the module-level ``args`` (see
    ``arg_parse``) and from ``BF_CONFIG``.  Per-iteration and per-epoch
    losses are written to TensorBoard; a checkpoint plus a JSON file with
    the training configuration is saved every ``args.save_every`` epochs
    under ``<save_dir>/<ds>/<exp_ver>/``.
    """
    # prepare the data: "online" features need the backbone-aware collate,
    # "offline" features are pre-extracted and use the lighter collate
    collate_fn = collate_fn_with_backbone if args.feat_type == "online" else collate_fn_without_backbone
    # NOTE(review): trains on "trainval" and validates on "test" — confirm
    # this split choice is intentional for this task
    train_set = BreakfastDataset(mode="trainval", split_idx=args.split_idx, task=args.task, feat_type=args.feat_type, anti_feat=args.anti_feat, preproc=None, over_write=False)
    val_set = BreakfastDataset(mode="test", split_idx=args.split_idx, task=args.task, feat_type=args.feat_type, anti_feat=args.anti_feat, preproc=None, over_write=False, data_aug=False)
    train_dataloader = DataLoader(dataset=train_set, batch_size=args.bs, shuffle=True, num_workers=args.nw, collate_fn=collate_fn)
    val_dataloader = DataLoader(dataset=val_set, batch_size=args.bs, shuffle=True, num_workers=args.nw, collate_fn=collate_fn)
    dataset = {"train": train_set, "val": val_set}  # currently unused below
    dataloader = {"train": train_dataloader, "val": val_dataloader}
    phase_list = ["train", 'val']
    print("Preparing data done!!!")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Training on {}.".format(device))
    # prepare the model: the decoder produces the anticipation outputs, so
    # use_dec must be enabled exactly when the task is recog_anti
    assert args.use_dec == (args.task=="recog_anti"), f"args.use_dec(={args.use_dec}) should be TRUE if args.task(={args.task}) == recog_anti, vice versa."
    if args.feat_type == "online":
        model = Anticipation_With_Backbone(use_dec=args.use_dec)
    else:
        model = Anticipation_Without_Backbone(use_dec=args.use_dec)
    model.to(device)
    # get the numbers of parameters of the designed model, grouped by the
    # top-level sub-module name (first dotted component of the param name)
    param_dict = {}
    for param in model.named_parameters():
        moduler_name = param[0].split('.')[0]
        if moduler_name in param_dict.keys():
            param_dict[moduler_name] += param[1].numel()
        else:
            param_dict[moduler_name] = param[1].numel()
    for k, v in param_dict.items():
        print(f"{k} parameters: {v / 1e6} million.")
    print(f"Parameters in total: {sum(param_dict.values()) / 1e6} million.")
    # build optimizer && criterion; only parameters with requires_grad are optimized
    if args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.wd)
    elif args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    else:
        optimizer = optim.RMSprop(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    # NOTE(review): the LR policy is keyed off BF_CONFIG['pre_norm'], not
    # args.warmup — pre_norm=True uses ReduceLROnPlateau, otherwise the
    # optimizer is wrapped in the Transformer-style WarmUpOptimizer.
    # Confirm this coupling is intentional (args.warmup is never read here).
    if BF_CONFIG['pre_norm']:
        scheduler = optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.1, patience=5, min_lr=1e-9, verbose=True)
    else:
        optimizer = WarmUpOptimizer(optimizer, BF_CONFIG['d_input'], BF_CONFIG['lr_factor'], BF_CONFIG['warmup_step'])
    criterion = Mulit_Loss(reduction='sum', eps=BF_CONFIG['eps'])
    # set up logger and make sure the checkpoint directory exists
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.ds + '/' + args.exp_ver)
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.ds, args.exp_ver), recursive=True)
    # start training
    t1 = time.time()
    train_iter_num = 0  # global step counter for per-iteration TensorBoard plots
    for epoch in range(args.s_epoch, args.e_epoch):
        loss_list = []  # one [recog_loss, anti_loss] entry per phase
        for phase in phase_list:
            s_t = time.time()
            # sum-reduced losses accumulated over the epoch; normalized at the
            # end by the number of non-padded label positions
            recog_epoch_loss = 0
            anti_epoch_loss = 0
            recog_sample_num = 0
            anti_sample_num = 0
            pre_iter_loss = 0  # previous iteration loss, used for outlier detection
            all_epoch_loss = 0
            count = 0  # debug leftover, only used by the commented-out early break
            for data in tqdm(dataloader[phase]):
                # count += 1;
                # if count > 10: break
                # batch layout assumed from indexing — TODO confirm against collate_fn:
                # (obs_feat, obs_labels, obs_pad_num, anti_feat, anti_labels,
                #  anti_pad_num, data_dir, obs_itval_gt, anti_itval_gt)
                obs_feat = data[0]
                obs_labels = data[1]
                obs_pad_num = data[2]
                anti_feat = data[3]
                anti_labels = data[4]
                anti_pad_num = data[5]
                data_dir = data[6]
                obs_itval_gt = data[7]
                anti_itval_gt = data[8]
                obs_feat, obs_labels, anti_feat, anti_labels, obs_itval_gt, anti_itval_gt = obs_feat.to(device), obs_labels.to(device), anti_feat.to(device), anti_labels.to(device), obs_itval_gt.to(device), anti_itval_gt.to(device)
                if phase == 'train':
                    model.train()
                    model.zero_grad()
                    # the model consumes only the observed features plus the two
                    # pad counts; anti_feat is moved to device but not fed in
                    recog_logits, anti_logits, recog_itval, anti_itval, *attn = model(obs_feat, obs_pad_num, anti_pad_num)
                    recog_loss, anti_loss, recog_t_loss, anti_t_loss = criterion(recog_logits, anti_logits, recog_itval, anti_itval, obs_labels, anti_labels, obs_itval_gt, anti_itval_gt)
                    # weighted sum of the four loss terms, weights from BF_CONFIG
                    loss = recog_loss * BF_CONFIG["loss_weight"][0] + anti_loss * BF_CONFIG["loss_weight"][1] + recog_t_loss * BF_CONFIG["loss_weight"][2] + anti_t_loss * BF_CONFIG["loss_weight"][3]
                    loss.backward()
                    # torch.nn.utils.clip_grad_norm_(model.parameters(), 1)
                    optimizer.step()
                else:
                    model.eval()
                    with torch.no_grad():
                        recog_logits, anti_logits, recog_itval, anti_itval, *attn = model(obs_feat, obs_pad_num, anti_pad_num)
                        recog_loss, anti_loss, recog_t_loss, anti_t_loss = criterion(recog_logits, anti_logits, recog_itval, anti_itval, obs_labels, anti_labels, obs_itval_gt, anti_itval_gt)
                        loss = recog_loss * BF_CONFIG["loss_weight"][0] + anti_loss * BF_CONFIG["loss_weight"][1] + recog_t_loss * BF_CONFIG["loss_weight"][2] + anti_t_loss * BF_CONFIG["loss_weight"][3]
                recog_epoch_loss += recog_loss.item()
                anti_epoch_loss += anti_loss.item()
                all_epoch_loss += loss.item()
                # non-padded position counts for per-sample normalization
                recog_sample_num += (obs_labels.shape.numel() - obs_pad_num.sum())
                anti_sample_num += (anti_labels.shape.numel() - anti_pad_num.sum())
                # plot training loss iteration by iteration
                if phase == 'train':
                    r_loss = recog_loss.item()/(obs_labels.shape.numel() - obs_pad_num.sum())
                    a_loss = anti_loss.item()/(anti_labels.shape.numel() - anti_pad_num.sum())
                    i_loss = r_loss + a_loss
                    writer.add_scalars('train_iter_loss', {'recog': r_loss, 'anti': a_loss, 'all': i_loss}, train_iter_num)
                    # log batches whose loss jumps by > 2 vs the previous
                    # iteration after epoch 50 — presumed outlier samples
                    if epoch > 50 and (i_loss - pre_iter_loss) > 2:
                        io.mkdir_if_not_exists('./result', recursive=True)
                        # first hit truncates the file, later hits append
                        if pre_iter_loss == 0:
                            io.write('./result/outlier_data.txt', data_dir, 'w')
                        else:
                            io.write('./result/outlier_data.txt', data_dir, 'a')
                    pre_iter_loss = i_loss
                    train_iter_num += 1
            # normalize the summed losses by the non-padded position counts
            recog_epoch_loss /= recog_sample_num
            anti_epoch_loss /= anti_sample_num
            loss_list.append([recog_epoch_loss, anti_epoch_loss])
            # print loss on the first epoch and every print_every-th epoch
            if epoch == 0 or (epoch % args.print_every) == args.print_every-1:
                e_t = time.time()
                print(f"Phase:[{phase}] Epoch:[{epoch+1}/{args.e_epoch}] All_Loss:[{round(all_epoch_loss, 4)}] Recog_Loss:[{round(recog_epoch_loss, 4)}] Anti_Loss:[{round(anti_epoch_loss, 4)}] Execution_time:[{round(e_t-s_t, 1)}] second")
        # plot per-epoch losses; with both phases present, also drive the
        # plateau scheduler from the validation loss
        assert len(phase_list) == len(loss_list)
        if len(phase_list) == 2:
            writer.add_scalars('train_val_epoch_loss', {'train_loss': sum(loss_list[0]), 'train_recog_loss': loss_list[0][0], 'train_anti_loss': loss_list[0][1], \
                'val_loss': sum(loss_list[1]), 'val_recog_loss': loss_list[1][0], 'val_anti_loss': loss_list[1][1]}, epoch)
            if BF_CONFIG['pre_norm']:
                scheduler.step(sum(loss_list[1]))
        else:
            writer.add_scalars('trainval_epoch_loss', {'trainval_loss': sum(loss_list[0]), 'recog_loss': loss_list[0][0], 'anti_loss': loss_list[0][1]}, epoch)
        # save training information and checkpoint every save_every epochs
        # (the `epoch >= 0` guard is always true here)
        if epoch % args.save_every == (args.save_every - 1) and epoch >= 0:
            opts = {'lr': args.lr, 'b_s': args.bs, 'optim': args.optim, 'use_dec': args.use_dec}
            save_info = {"arguments": opts, "config": BF_CONFIG}
            io.dumps_json(save_info, os.path.join(args.save_dir, args.ds, args.exp_ver, 'training_info.json'))
            save_name = "checkpoint_" + str(epoch+1) + "_epoch.pth"
            torch.save(model.state_dict(), os.path.join(args.save_dir, args.ds, args.exp_ver, save_name))
    writer.close()
    t2 = time.time()
    print("Training finished! It takes {} seconds.".format(round(t2-t1, 1)))
def train_model_recog_only() -> None:
    """Train and validate the recognition-only model.

    All hyper-parameters come from the module-level ``args`` (see
    ``arg_parse``).  Epoch losses are written to TensorBoard; a checkpoint
    plus a JSON file with the training configuration is saved every
    ``args.save_every`` epochs under ``<save_dir>/<ds>/<exp_ver>/``.

    Fixes over the previous revision:
      * removed a leftover ``import ipdb; ipdb.set_trace()`` breakpoint that
        halted every single iteration;
      * ``args.start_epoch`` / ``args.epoch`` do not exist (``arg_parse``
        defines ``s_epoch`` / ``e_epoch``) and raised ``AttributeError``;
      * the print condition used a hard-coded ``== 9`` instead of
        ``args.print_every - 1`` (matching ``train_model_recog_anti``);
      * ``--wd`` was silently ignored (hard-coded ``weight_decay=0``); it is
        now forwarded to the optimizer — default 0.0, so behavior is
        unchanged unless the flag is passed;
      * dropped the unused ``dataset`` dict and the always-true
        ``epoch >= 0`` guard.
    """
    # prepare the data: "online" features need the backbone-aware collate
    collate_fn = collate_fn_with_backbone if args.feat_type == "online" else collate_fn_without_backbone
    train_set = BreakfastDataset(mode="train", split_idx=args.split_idx, task=args.task, feat_type=args.feat_type, anti_feat=args.anti_feat, preproc=None, over_write=False)
    val_set = BreakfastDataset(mode="val", split_idx=args.split_idx, task=args.task, feat_type=args.feat_type, anti_feat=args.anti_feat, preproc=None, over_write=False, data_aug=False)
    train_dataloader = DataLoader(dataset=train_set, batch_size=args.bs, shuffle=True, num_workers=args.nw, collate_fn=collate_fn)
    val_dataloader = DataLoader(dataset=val_set, batch_size=args.bs, shuffle=True, num_workers=args.nw, collate_fn=collate_fn)
    dataloader = {"train": train_dataloader, "val": val_dataloader}
    phase_list = ["train", 'val']
    print("Preparing data done!!!")
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    print("Training on {}!!!".format(device))
    # prepare the model: recog_only must run without the decoder
    assert args.use_dec == (args.task=="recog_anti"), "args.use_dec should be TRUE if args.task==recog_anti, vice versa"
    if args.feat_type == "online":
        model = Anticipation_With_Backbone(use_dec=args.use_dec)
    else:
        model = Anticipation_Without_Backbone(use_dec=args.use_dec)
    model.to(device)
    # count parameters grouped by top-level sub-module (for logging only)
    param_dict = {}
    for param in model.named_parameters():
        moduler_name = param[0].split('.')[0]
        if moduler_name in param_dict:
            param_dict[moduler_name] += param[1].numel()
        else:
            param_dict[moduler_name] = param[1].numel()
    for k, v in param_dict.items():
        print(f"{k} Parameters: {v / 1e6} million.")
    print(f"Parameters in total: {sum(param_dict.values()) / 1e6} million.")
    # build optimizer && criterion; only parameters with requires_grad are optimized
    if args.optim == 'sgd':
        optimizer = optim.SGD(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, momentum=0.9, weight_decay=args.wd)
    elif args.optim == 'adam':
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd)
    else:
        # fallback: Adam with AMSGrad
        optimizer = optim.Adam(filter(lambda p: p.requires_grad, model.parameters()), lr=args.lr, weight_decay=args.wd, amsgrad=True)
    # sum-reduced CE; normalized below by the number of non-padded positions
    rec_criterion = nn.CrossEntropyLoss(reduction='sum')  # 'mean' 'sum' 'none'
    # set up logger and make sure the checkpoint directory exists
    writer = SummaryWriter(log_dir=args.log_dir + '/' + args.ds + '/' + args.exp_ver)
    io.mkdir_if_not_exists(os.path.join(args.save_dir, args.ds, args.exp_ver), recursive=True)
    # start training
    for epoch in range(args.s_epoch, args.e_epoch):
        loss_list = []  # one normalized epoch loss per phase
        for phase in phase_list:
            s_t = time.time()
            epoch_loss = 0
            sample_num = 0
            for data in tqdm(dataloader[phase]):
                # batch layout assumed from indexing — TODO confirm against
                # collate_fn: (features, labels, pad_num, ...)
                feat = data[0]
                labels = data[1]
                pad_num = data[2]
                feat, labels = feat.to(device), labels.to(device)
                if phase == 'train':
                    model.train()
                    model.zero_grad()
                    logits = model(feat, pad_num)
                    # flatten (..., C) logits and matching labels for CE
                    loss = rec_criterion(logits.reshape(logits.shape[:-1].numel(), logits.shape[-1]), labels.reshape(labels.shape.numel()))
                    loss.backward()
                    optimizer.step()
                else:
                    model.eval()
                    with torch.no_grad():
                        logits = model(feat, pad_num)
                        loss = rec_criterion(logits.reshape(logits.shape[:-1].numel(), logits.shape[-1]), labels.reshape(labels.shape.numel()))
                epoch_loss += loss.item()
                # exclude padded positions from the normalizer
                sample_num += (labels.shape.numel() - pad_num.sum()).float()
            epoch_loss /= sample_num
            loss_list.append(epoch_loss)
            # print loss on the first epoch and every print_every-th epoch
            if epoch == 0 or (epoch % args.print_every) == args.print_every - 1:
                e_t = time.time()
                print(f"Phase:[{phase}] Epoch:[{epoch+1}/{args.e_epoch}] Loss:[{epoch_loss}] Execution_time:[{round(e_t-s_t, 1)}] second")
        # plot loss
        assert len(phase_list) == len(loss_list)
        if len(phase_list) == 2:
            writer.add_scalars('train_val_loss', {'train': loss_list[0], 'val': loss_list[1]}, epoch)
        else:
            writer.add_scalars('trainval_loss', {'trainval': loss_list[0]}, epoch)
        # save training information and checkpoint every save_every epochs
        if epoch % args.save_every == (args.save_every - 1):
            opts = {'lr': args.lr, 'b_s': args.bs, 'optim': args.optim, 'use_dec': args.use_dec}
            save_info = {"arguments": opts, "config": BF_CONFIG, 'parameters': param_dict}
            io.dumps_json(save_info, os.path.join(args.save_dir, args.ds, args.exp_ver, 'training_info.json'))
            save_name = "checkpoint_" + str(epoch+1) + "_epoch.pth"
            torch.save(model.state_dict(), os.path.join(args.save_dir, args.ds, args.exp_ver, save_name))
    writer.close()
    print("Training finished!!!")
if __name__ == "__main__":
    # Parse CLI options once; both entry points read the module-level `args`.
    args = arg_parse()
    # Dispatch to the requested training routine.
    (train_model_recog_only if args.task == "recog_only" else train_model_recog_anti)()