-
Notifications
You must be signed in to change notification settings - Fork 2
/
test_errnet.py
82 lines (63 loc) · 2.63 KB
/
test_errnet.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
from os.path import join, basename
from options.errnet.train_options import TrainOptions
from engine import Engine
from data.image_folder import read_fns
from data.transforms import __scale_width
import torch.backends.cudnn as cudnn
import data.reflect_dataset as datasets
import util.util as util
# Parse the command-line options, then force the settings this evaluation-only
# script relies on, regardless of what was passed on the command line.
opt = TrainOptions().parse()
_eval_overrides = {
    'isTrain': False,    # evaluation run, not training
    'no_log': True,
    'display_id': 0,
    'verbose': False,
}
for _attr_name, _attr_value in _eval_overrides.items():
    setattr(opt, _attr_name, _attr_value)

# cuDNN autotuning stays off here; per the original note it can be switched on
# for the SIR (wild, postcard, solid) benchmarks for a speed-up.
cudnn.benchmark = False
# Root directory of the evaluation benchmarks. Each benchmark is a sub-folder
# that also contains a data_list.txt listing the file names to evaluate.
datadir = './datasets/eval'


def _make_eval_loader(name):
    """Build the test dataset for benchmark *name* and wrap it in a DataLoader.

    Returns a ``(dataset, loader)`` pair. The loader uses batch size 1, no
    shuffling, ``opt.nThreads`` workers and pinned memory.
    """
    root = join(datadir, name)
    dataset = datasets.CEILTestDataset(root, fns=read_fns(join(root, 'data_list.txt')))
    loader = datasets.DataLoader(
        dataset, batch_size=1, shuffle=False,
        num_workers=opt.nThreads, pin_memory=True)
    return dataset, loader


# Only Real20 is evaluated by default. The SIR benchmarks ('wild', 'postcard',
# 'solid') can be enabled the same way, e.g.
#   eval_dataset_wild, eval_dataloader_wild = _make_eval_loader('wild')
eval_dataset_real, eval_dataloader_real = _make_eval_loader('real20')
engine = Engine(opt)

# Main loop: run Engine.eval on every enabled benchmark, passing
# result_dir/<benchmark> as its savedir, and collect the returned metrics
# keyed by benchmark name.
result_dir = './results'
all_res = {}
# Entries are (benchmark key, dataloader, dataset_name given to Engine.eval).
# SIR benchmarks can be appended once their loaders exist, e.g.
#   ('wild', eval_dataloader_wild, 'testdata_wild')
_benchmarks = [
    ('real20', eval_dataloader_real, 'testdata_real'),
]
for _bench_key, _bench_loader, _bench_tag in _benchmarks:
    res = engine.eval(_bench_loader, dataset_name=_bench_tag,
                      savedir=join(result_dir, _bench_key))
    all_res[_bench_key] = res
    print(_bench_key, res)
# Number of images in each benchmark. It is used as that benchmark's weight in
# the overall average below.
num = {
    'real20': 20,
    'wild': 50,
    'postcard': 199,
    'solid': 200,
}

# Image-weighted average of every metric over all evaluated benchmarks.
# Each metric is normalised by the total weight of the benchmarks that actually
# reported it. Dividing by the overall image count would dilute a metric that is
# missing from some benchmark.
avg_res = {}
metric_weight = {}
cnt = 0  # total image count over all evaluated benchmarks
for d, res in all_res.items():
    w = num[d]
    for k, v in res.items():
        avg_res[k] = avg_res.get(k, 0) + v * w
        metric_weight[k] = metric_weight.get(k, 0) + w
    cnt += w
avg_res = {k: v / metric_weight[k] for k, v in avg_res.items()}
print('avg:', avg_res)