-
Notifications
You must be signed in to change notification settings - Fork 5
/
test.py
58 lines (44 loc) · 1.65 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
# -*- coding: utf-8 -*-
import os
import torch
from torch.utils.data import DataLoader
from data.dataset import *
from func.train_model import *
from model.acnet import *
from util.arg_parse import *
from util.config import *
from util.weight_init import *
# Turn on autograd anomaly detection for the whole process. Per the PyTorch
# docs this mode is meant for debugging only and slows program execution.
# NOTE(review): consider disabling it for routine evaluation runs.
torch.autograd.set_detect_anomaly(mode=True)
def _strip_module_prefix(state_dict):
    """Return a new dict whose keys have one leading ``'module.'`` removed.

    Checkpoints saved from an ``nn.DataParallel``-wrapped model name their
    entries ``module.<name>``; stripping that prefix lets them load into the
    bare model built in ``main``. Only a *leading* prefix is removed (unlike
    ``str.replace``), so a key such as ``'submodule.weight'`` is not mangled
    into ``'subweight'``.

    Args:
        state_dict: mapping of names to values loaded from the checkpoint
            file (presumably a plain ``state_dict`` — TODO confirm).

    Returns:
        dict with the same values and prefix-stripped keys.
    """
    prefix = 'module.'
    return {(key[len(prefix):] if key.startswith(prefix) else key): value
            for key, value in state_dict.items()}


def main():
    """Evaluate ACNet on the test split and print accuracy and confusion matrix.

    Configuration comes from ``getArgs()``. When ``args.savepoint_file`` is
    set, weights are loaded from that checkpoint; otherwise ``weightInit`` is
    applied to the freshly built model.

    Raises:
        ValueError: if ``args.type`` is not ``'test'``.
    """
    args = getArgs()
    # Explicit check rather than ``assert`` so it is not stripped under ``python -O``.
    if args.type not in ['test']:
        raise ValueError("expected args.type == 'test', got {!r}".format(args.type))

    test_data_config = getDatasetConfig(args, 'test')
    test_dataset = MyDataset(test_data_config)
    # Accuracy and the confusion matrix are aggregates over the split
    # (presumably — test() is defined elsewhere), so shuffling only makes runs
    # non-reproducible; iterate in a fixed order instead.
    test_dataloader = DataLoader(dataset=test_dataset,
                                 batch_size=args.test_batch,
                                 shuffle=False,
                                 num_workers=args.test_num_workers,
                                 drop_last=False,
                                 pin_memory=True)

    model_config = getModelConfig(args, 'test')
    model = ACNet(model_config)
    if args.savepoint_file:
        # Load onto the CPU: the model is still on the CPU here and is moved to
        # the GPU below, and this lets GPU-saved checkpoints load without CUDA.
        # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
        checkpoint = torch.load(args.savepoint_file, map_location='cpu')
        # Merge into the model's own state dict so any entries missing from the
        # checkpoint keep their current values.
        model_dict = model.state_dict()
        model_dict.update(_strip_module_prefix(checkpoint))
        model.load_state_dict(model_dict)
    else:
        model.apply(weightInit)

    if args.use_cuda:
        model = model.cuda()
    if args.summary:
        # Called before DataParallel wrapping so it runs on the ACNet instance
        # itself (assumes ACNet defines summary() — TODO confirm).
        model.summary()
    if args.use_cuda:
        model = nn.DataParallel(model)
        # DataParallel needs the wrapped module on device_ids[0] (cuda:0 by
        # default). Kept under use_cuda so a CPU run is never moved to a GPU.
        if torch.cuda.device_count() > 1:
            model = model.to(torch.device('cuda:0'))

    # NOTE(review): assumes test() switches the model to eval mode and disables
    # gradient tracking itself — verify in func.train_model.
    test_result = test(args, model=model, dataloader=test_dataloader, type='test')
    test_acc = test_result['test_acc']
    test_cfs_mat = test_result['test_cfs_mat']
    print('test_acc: {:6.4f}.'.format(test_acc))
    print('test_cfs_mat')
    print(test_cfs_mat)


if __name__ == '__main__':
    main()