-
Notifications
You must be signed in to change notification settings - Fork 2
/
test.py
79 lines (69 loc) · 2.53 KB
/
test.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
import torch
import torch.nn.functional as F
import torch.optim as optim
from util import accuracy
from networks.resnet_big import LinearClassifier
def get_train_features(train_loader, model):
    """Encode every training batch and collect the per-batch outputs.

    The model is switched to eval mode (and left there) and the encoder is
    run under ``torch.no_grad()``, so the returned features carry no
    autograd graph.

    Args:
        train_loader: Iterable yielding ``(images, labels)`` pairs. Only
            ``images[0]`` is encoded -- presumably the first of several
            augmented views produced by the training transform; confirm
            against the data pipeline.
        model: Object exposing ``eval()`` and an ``encoder`` callable.

    Returns:
        ``(all_features, all_labels)``: two lists with one entry per batch,
        in loader order. Entries are moved to the GPU when CUDA is
        available; otherwise they stay on the device the loader produced.
    """
    model.eval()
    torch.cuda.empty_cache()
    # Availability does not change while iterating, so query it once.
    use_cuda = torch.cuda.is_available()

    all_features = []
    all_labels = []
    with torch.no_grad():
        for images, labels in train_loader:
            # Keep only the first element of the images container.
            images = images[0]
            if use_cuda:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
            all_features.append(model.encoder(images))
            all_labels.append(labels)
    return all_features, all_labels
def get_test_features(test_loader, model):
    """Encode every test batch and collect the per-batch outputs.

    The model is switched to eval mode (and left there) and the encoder is
    run under ``torch.no_grad()``, so the returned features carry no
    autograd graph.

    Args:
        test_loader: Iterable yielding ``(images, labels)`` pairs; each
            ``images`` value is passed to the encoder as-is.
        model: Object exposing ``eval()`` and an ``encoder`` callable.

    Returns:
        ``(all_features, all_labels)``: two lists with one entry per batch,
        in loader order. Entries are moved to the GPU when CUDA is
        available; otherwise they stay on the device the loader produced.
    """
    model.eval()
    torch.cuda.empty_cache()
    # Availability does not change while iterating, so query it once.
    use_cuda = torch.cuda.is_available()

    all_features = []
    all_labels = []
    with torch.no_grad():
        for images, labels in test_loader:
            if use_cuda:
                images = images.cuda(non_blocking=True)
                labels = labels.cuda(non_blocking=True)
            all_features.append(model.encoder(images))
            all_labels.append(labels)
    return all_features, all_labels
def set_classifier(features, labels, opt, *, epochs=50):
    """Train a fresh linear classifier on pre-computed feature batches.

    Args:
        features: Sequence of feature batches, one tensor per batch.
        labels: Sequence of label batches aligned index-by-index with
            ``features``; extra trailing label batches are ignored.
        opt: Options object providing ``model``, ``n_classes``,
            ``learning_rate``, ``momentum`` and ``weight_decay``.
        epochs: Number of full passes over ``features`` (default 50, the
            value that used to be hard-coded).

    Returns:
        The trained ``LinearClassifier``, left in training mode.

    Raises:
        ValueError: If there are fewer label batches than feature batches.
    """
    if len(labels) < len(features):
        raise ValueError(
            f"got {len(features)} feature batches but only "
            f"{len(labels)} label batches"
        )

    # prepare the classifier
    classifier = LinearClassifier(name=opt.model, num_classes=opt.n_classes)
    # Same device rule as the feature extractors in this file: use the GPU
    # when one is available instead of failing on CPU-only hosts.
    if torch.cuda.is_available():
        classifier = classifier.cuda()
    class_opt = optim.SGD(
        classifier.parameters(),
        lr=opt.learning_rate,
        momentum=opt.momentum,
        weight_decay=opt.weight_decay,
    )

    # train the classifier
    classifier.train()
    for _ in range(epochs):
        # zip() stops at the last feature batch; the length check above
        # guarantees a label batch exists for each one.
        for feature, y in zip(features, labels):
            y_pred = classifier(feature)
            loss = F.cross_entropy(y_pred, y)
            class_opt.zero_grad()
            loss.backward()
            class_opt.step()
    return classifier
def test(features, labels, classifier):
    """Evaluate a classifier on pre-computed feature batches and print accuracy.

    Predictions and labels are gathered per batch and concatenated once,
    which keeps the labels' original dtype and avoids re-copying the
    growing tensors on every iteration.

    Args:
        features: Sequence of feature batches, one tensor per batch.
        labels: Sequence of label batches aligned index-by-index with
            ``features``; extra trailing label batches are ignored.
        classifier: Module mapping a feature batch to per-class scores;
            it is switched to eval mode and left there.

    Returns:
        The value returned by ``accuracy(y_pred, y)`` (also printed).
        Earlier versions returned ``None``; callers that ignore the return
        value are unaffected.

    Raises:
        ValueError: If ``features`` is empty or there are fewer label
            batches than feature batches.
    """
    if len(features) == 0:
        raise ValueError("test() needs at least one feature batch")
    if len(labels) < len(features):
        raise ValueError(
            f"got {len(features)} feature batches but only "
            f"{len(labels)} label batches"
        )

    classifier.eval()
    with torch.no_grad():
        batch_preds = [classifier(feature) for feature in features]

    y_pred = torch.cat(batch_preds, dim=0)
    # Only the label batches that have a matching feature batch are used.
    batch_labels = [label for _, label in zip(features, labels)]
    # Follow the predictions' device rather than assuming CUDA is present.
    y = torch.cat(batch_labels, dim=0).to(y_pred.device)

    acc = accuracy(y_pred, y)
    print('Top 1 Accuracy:', acc)
    return acc