-
Notifications
You must be signed in to change notification settings - Fork 2
/
main_agqa_v2.py
113 lines (100 loc) · 3.56 KB
/
main_agqa_v2.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import torch
import torch.nn as nn
import torch.optim as optim
import numpy as np
import random
import os
import logging
from transformers import get_cosine_schedule_with_warmup, DistilBertTokenizer, BertTokenizer
from args import get_args
from model.mist import MIST_VideoQA
from util import compute_a2v
from train.train_mist import train, eval
from data.agqa_clip_patch_loader import get_videoqa_loaders
# args, logging
args = get_args()
if not (os.path.isdir(args.save_dir)):
os.mkdir(os.path.join(args.save_dir))
logging.basicConfig(
level=logging.INFO, format="%(asctime)s %(levelname)-8s %(message)s"
)
logFormatter = logging.Formatter("%(asctime)s %(levelname)-8s %(message)s")
rootLogger = logging.getLogger()
fileHandler = logging.FileHandler(os.path.join(args.save_dir, "stdout.log"), "w+")
fileHandler.setFormatter(logFormatter)
rootLogger.addHandler(fileHandler)
logging.info(args)
# set random seeds
torch.manual_seed(args.seed)
np.random.seed(args.seed)
random.seed(args.seed)
# get answer embeddings
bert_tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
a2id, id2a, a2v = None, None, None
if not args.mc:
a2id, id2a, a2v = compute_a2v(
vocab_path=args.vocab_path,
bert_tokenizer=bert_tokenizer,
amax_words=args.amax_words,
)
logging.info(f"Length of Answer Vocabulary: {len(a2id)}")
# Model
model = MIST_VideoQA(
feature_dim=args.feature_dim,
word_dim=args.word_dim,
N=args.n_layers,
d_model=args.embd_dim,
d_ff=args.ff_dim,
h=args.n_heads,
dropout=args.dropout,
T=args.max_feats,
Q=args.qmax_words,
baseline=args.baseline,
)
model.cuda()
logging.info("Using {} GPUs".format(torch.cuda.device_count()))
# Load pretrain path
model = nn.DataParallel(model)
if args.pretrain_path != "":
model.load_state_dict(torch.load(args.pretrain_path))
logging.info(f"Loaded checkpoint {args.pretrain_path}")
logging.info(
f"Nb of trainable params:{sum(p.numel() for p in model.parameters() if p.requires_grad)}"
)
# Dataloaders
(
train_loader,
val_loader,
test_loader,
) = get_videoqa_loaders(args, args.features_path, a2id, bert_tokenizer, test_mode=args.test)
logging.info("number of train instances: {}".format(len(train_loader.dataset)))
logging.info("number of val instances: {}".format(len(val_loader.dataset)))
# logging.info("number of test instances: {}".format(len(test_loader.dataset)))
# Loss + Optimizer
criterion = nn.CrossEntropyLoss()
params_for_optimization = list(p for n, p in model.named_parameters() if p.requires_grad and n.split('.')[1] != 'clip')
optimizer = optim.Adam(
params_for_optimization, lr=args.lr, weight_decay=args.weight_decay
)
criterion.cuda()
# Training
if not args.test:
scheduler = get_cosine_schedule_with_warmup(
optimizer, 0, len(train_loader) * args.epochs
)
logging.info(
f"Set cosine schedule with {len(train_loader) * args.epochs} iterations"
)
best_val_acc = -float("inf")
best_epoch = 0
for epoch in range(args.epochs):
best_val_acc, best_epoch = train(model, train_loader, a2v, optimizer, criterion, scheduler, epoch, args,
val_loader, best_val_acc, best_epoch)
logging.info(f"Best val model at epoch {best_epoch + 1}")
model.load_state_dict(
torch.load(
os.path.join(args.checkpoint_predir, args.checkpoint_dir, "best_model.pth")
)
)
# Evaluate on test set
eval(model, test_loader, a2v, args, test=True)