/
full_inference.py
277 lines (252 loc) · 11.7 KB
/
full_inference.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
"""
=========================================================================================
Trojan VQA
Written by Matthew Walmer
Run full end-to-end inference with a trained VQA model, including the feature extraction
step. Alternately, the system can use pre-cached image features if available.
Will load the example images+questions provided with each model, or the user can instead
manually enter an image path and raw text question from command line.
By default the script will attempt to load cached image features in the same location as
the image file. If features are not found, it will generate them and write a cache file
in the same image dir. Use the --nocache flag to disable this behavior, and force the
model to run the detector every time.
Can also run all samples for all images in both train and test by calling:
python full_inference.py --all
=========================================================================================
"""
import argparse
import csv
import os
import json
import cv2
import time
import sys
import pickle
import numpy as np
import torch
# detectron2 (and the datagen utils that wrap it) are optional: the "lite" demo
# mode runs without them, so an import failure is only a warning.
# NOTE: the original used a bare `except:` (which also swallows KeyboardInterrupt)
# and, if the inner import failed, left the working directory stuck at 'datagen',
# breaking every relative path used later. The try/finally restores the cwd.
try:
    from fvcore.nn import parameter_count_table
    os.chdir('datagen')
    try:
        from datagen.utils import load_detectron_predictor, check_for_cuda, run_detector
    finally:
        os.chdir('..')
except Exception:
    print('WARNING: Did not find detectron2 install. Ignore this message if running the demo in lite mode')
sys.path.append("openvqa/")
from openvqa.openvqa_inference_wrapper import Openvqa_Wrapper
sys.path.append("bottom-up-attention-vqa/")
from butd_inference_wrapper import BUTDeff_Wrapper
# run model inference based on the model_spec for one image+question or a list of images+questions
# set return_models=True to return the loaded detector and VQA models. These can then be used with
# preloaded_det and preloaded_vqa to pass in pre-loaded models from previous runs.
def full_inference(model_spec, image_paths, questions, set_dir='model_sets/v1-train-dataset',
                   det_dir='detectors', nocache=False, get_att=False, direct_path=None, show_params=False,
                   return_models=False, preloaded_det=None, preloaded_vqa=None):
    """Run full end-to-end VQA inference (feature extraction + VQA model).

    model_spec:    dict-like row from METADATA.csv; must provide at least
                   'detector', 'nb', 'model' and 'model_name'
    image_paths:   one image path, or a list of image paths
    questions:     one question string, or a list parallel to image_paths
    set_dir:       root of the model-set directory holding 'models/<model_name>/'
    det_dir:       directory holding detector .pth weights
    nocache:       disable reading/writing of per-image feature cache .pkl files
    get_att:       also return attention maps (supported for butd_eff only;
                   exits for other model types)
    direct_path:   load VQA model weights from this exact path instead of set_dir
    show_params:   print a parameter-count table for the loaded VQA model
    return_models: additionally return the detector and VQA wrapper for reuse
    preloaded_det / preloaded_vqa: models from a previous call, to skip reloading

    Returns all_answers, optionally followed by (predictor, IW) and/or
    (all_info, all_atts) depending on return_models / get_att — same return
    shapes and ordering as the original interface.
    """
    # allow a single (non-list) image+question pair
    if not isinstance(image_paths, list):
        image_paths = [image_paths]
        questions = [questions]
    assert len(image_paths) == len(questions)
    # load or generate image features
    print('=== Getting Image Features')
    detector = model_spec['detector']
    nb = int(model_spec['nb'])
    predictor = preloaded_det
    all_image_features = []
    all_bbox_features = []
    all_info = []
    for image_path in image_paths:
        # cached features live next to the image file, keyed by detector name
        cache_file = '%s_%s.pkl' % (image_path, model_spec['detector'])
        if nocache or not os.path.isfile(cache_file):
            # lazily load the detector once, on the first cache miss
            if predictor is None:
                detector_path = os.path.join(det_dir, detector + '.pth')
                config_file = "datagen/grid-feats-vqa/configs/%s-grid.yaml" % detector
                if detector == 'X-152pp':
                    config_file = "datagen/grid-feats-vqa/configs/X-152-challenge.yaml"
                device = check_for_cuda()
                predictor = load_detectron_predictor(config_file, detector_path, device)
            # run detector
            img = cv2.imread(image_path)
            info = run_detector(predictor, img, nb, verbose=False)
            if not nocache:
                # context manager so the handle is closed even if dump fails
                with open(cache_file, "wb") as f:
                    pickle.dump(info, f)
        else:
            with open(cache_file, "rb") as f:
                info = pickle.load(f)
        # post-process image features
        image_features = info['features']
        bbox_features = info['boxes']
        nbf = image_features.size()[0]
        if nbf < nb:
            # detector returned fewer boxes than expected: zero-pad rows up to nb
            temp = torch.zeros((nb, image_features.size()[1]), dtype=torch.float32)
            temp[:nbf, :] = image_features
            image_features = temp
            temp = torch.zeros((nb, bbox_features.size()[1]), dtype=torch.float32)
            temp[:nbf, :] = bbox_features
            bbox_features = temp
        all_image_features.append(image_features)
        all_bbox_features.append(bbox_features)
        all_info.append(info)
    # load vqa model (butd_eff stores weights as .pth, OpenVQA models as .pkl)
    if model_spec['model'] == 'butd_eff':
        m_ext = 'pth'
    else:
        m_ext = 'pkl'
    if direct_path is not None:
        print('loading direct path: ' + direct_path)
        model_path = direct_path
    else:
        model_path = os.path.join(set_dir, 'models', model_spec['model_name'], 'model.%s' % m_ext)
        print('loading model from: ' + model_path)
    if preloaded_vqa is not None:
        IW = preloaded_vqa
    elif model_spec['model'] == 'butd_eff':
        IW = BUTDeff_Wrapper(model_path)
    else:
        # GPU control for OpenVQA if using the CUDA_VISIBLE_DEVICES environment variable
        gpu_use = 0
        if 'CUDA_VISIBLE_DEVICES' not in os.environ:
            if torch.cuda.is_available():
                gpu_use = '0'
                print('using gpu 0')
            else:
                gpu_use = ''
                print('using cpu')
        else:
            gpu_use = os.getenv('CUDA_VISIBLE_DEVICES')
            print('using gpu %s' % gpu_use)
        IW = Openvqa_Wrapper(model_spec['model'], model_path, model_spec['nb'], gpu=gpu_use)
    # count params:
    if show_params:
        print('Model Type: ' + model_spec['model'])
        print('Parameters:')
        model = IW.model
        tab = parameter_count_table(model)
        # https://discuss.pytorch.org/t/how-do-i-check-the-number-of-parameters-of-a-model/4325/8
        p_count = sum(p.numel() for p in model.parameters() if p.requires_grad)
        print(tab)
        print('total number of parameters: ' + str(p_count))
    # run vqa model:
    all_answers = []
    all_atts = []
    for i in range(len(image_paths)):
        image_features = all_image_features[i]
        question = questions[i]
        bbox_features = all_bbox_features[i]
        model_ans = IW.run(image_features, question, bbox_features)
        all_answers.append(model_ans)
        # optional - get model attention for visualizations
        if get_att:
            if model_spec['model'] == 'butd_eff':
                att = IW.get_att(image_features, question, bbox_features)
                all_atts.append(att)
            else:
                print('WARNING: get_att not supported for model of type: ' + model_spec['model'])
                sys.exit(-1)  # sys.exit instead of the builtin exit() (not guaranteed in scripts)
    if get_att:
        if return_models:
            return all_answers, predictor, IW, all_info, all_atts
        else:
            return all_answers, all_info, all_atts
    if return_models:
        return all_answers, predictor, IW
    return all_answers
def main(setroot='model_sets', part='train', ver='v1', detdir='detectors', model=0, sample=0,
         all_samples=False, troj=False, ques=None, img=None, nocache=False, show_params=False):
    """Select one model from METADATA.csv, gather its input(s), run inference, print results.

    Inputs are either a manual image+question pair (ques/img) or sample questions
    bundled with the model (clean or trojan, one sample or all of them).
    """
    # locate and read the model-set metadata
    set_dir = os.path.join(setroot, '%s-%s-dataset' % (ver, part))
    meta_file = os.path.join(set_dir, 'METADATA.csv')
    with open(meta_file, 'r', newline='') as csvfile:
        specs = list(csv.DictReader(csvfile))
    s = specs[model]
    # format image and question
    if ques is not None and img is not None:
        # question and image supplied directly from the command line
        i, q, a = [img], [ques], ['(command line question)']
    else:
        # load a bundled sample question (clean by default, trojan if requested)
        subset = 'troj' if troj else 'clean'
        sam_dir = os.path.join(set_dir, 'models', s['model_name'], 'samples', subset)
        if troj and not os.path.isdir(sam_dir):
            print('ERROR: No trojan samples for model %s' % s['model_name'])
            return
        sam_file = os.path.join(sam_dir, 'samples.json')
        with open(sam_file, 'r') as f:
            samples = json.load(f)
        chosen = samples if all_samples else [samples[sample]]
        i = [os.path.join(sam_dir, sam['image']) for sam in chosen]
        q = [sam['question']['question'] for sam in chosen]
        a = [sam['annotations']['multiple_choice_answer'] for sam in chosen]
    # run inference and report each answer
    all_answers = full_inference(s, i, q, set_dir, detdir, nocache, show_params=show_params)
    for img_path, question, right_ans, model_ans in zip(i, q, a, all_answers):
        print('================================================')
        print('IMAGE FILE: ' + img_path)
        print('QUESTION: ' + question)
        print('RIGHT ANSWER: ' + right_ans)
        print('MODEL ANSWER: ' + model_ans)
        if troj:
            print('TROJAN TARGET: ' + s['target'])
def run_all(setroot='model_sets', ver='v1', detdir='detectors', nocache=False):
    """Run every sample question for every model in both the train and test partitions."""
    print('running all samples for all models...')
    start = time.time()
    for part in ('train', 'test'):
        print('%s models...' % part)
        # read this partition's model metadata
        set_dir = os.path.join(setroot, '%s-%s-dataset' % (ver, part))
        meta_file = os.path.join(set_dir, 'METADATA.csv')
        with open(meta_file, 'r', newline='') as csvfile:
            specs = list(csv.DictReader(csvfile))
        for idx, spec in enumerate(specs):
            print('====================================================================== %s' % spec['model_name'])
            main(setroot, part, ver, detdir, model=idx, all_samples=True, troj=False, nocache=nocache)
            # train-partition trojan models (f_clean == '0') also get their trojan samples
            if part == 'train' and spec['f_clean'] == '0':
                main(setroot, part, ver, detdir, model=idx, all_samples=True, troj=True, nocache=nocache)
            print('time elapsed: %.2f minutes' % ((time.time() - start) / 60))
    print('======================================================================')
    print('done in %.2f minutes' % ((time.time() - start) / 60))
if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    # model
    parser.add_argument('--setroot', type=str, default='model_sets', help='root location for the model sets')
    parser.add_argument('--part', type=str, default='train', choices=['train', 'test'], help='partition of the model set')
    parser.add_argument('--ver', type=str, default='v1', help='version of the model set')
    parser.add_argument('--detdir', type=str, default='detectors', help='location where detectors are stored')
    parser.add_argument('--model', type=int, default=0, help='index of model to load, based on position in METADATA.csv')
    # question and image
    parser.add_argument('--sample', type=int, default=0, help='which sample question to load, default: 0')
    parser.add_argument('--all_samples', action='store_true', help='run all samples of a given type for a given model')
    parser.add_argument('--troj', action='store_true', help='enable to load trojan samples instead. For trojan models only')
    parser.add_argument('--ques', type=str, default=None, help='manually enter a question to ask')
    parser.add_argument('--img', type=str, default=None, help='manually enter an image to run')
    # other
    # fixed typo in user-facing help: "reading a writing" -> "reading and writing"
    parser.add_argument('--nocache', action='store_true', help='disable reading and writing of feature cache files')
    parser.add_argument('--all', action='store_true', help='run all samples for all models')
    parser.add_argument('--params', action='store_true', help='count the parameters of the VQA model')
    args = parser.parse_args()
    if args.all:
        run_all(args.setroot, args.ver, args.detdir, args.nocache)
    else:
        main(args.setroot, args.part, args.ver, args.detdir, args.model, args.sample, args.all_samples, args.troj, args.ques,
             args.img, args.nocache, args.params)