train.py
import os
import time

import torch
import torch.nn as nn

import utils


def instance_bce_with_logits(logits, labels):
    """Binary cross-entropy against soft answer scores, summed over answers."""
    assert logits.dim() == 2
    loss = nn.functional.binary_cross_entropy_with_logits(logits, labels)
    # BCE averages over both the batch and answer dimensions; rescaling by the
    # number of candidate answers turns the answer-wise mean into a sum.
    loss *= labels.size(1)
    return loss


def compute_score_with_logits(logits, labels):
    """Soft accuracy: each example scores the soft label of its argmax answer."""
    logits = torch.max(logits, 1)[1]  # argmax over candidate answers
    one_hots = torch.zeros_like(labels)
    one_hots.scatter_(1, logits.view(-1, 1), 1)
    scores = one_hots * labels
    return scores
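# For illustration (assumed toy values): if labels = [[0.0, 0.9, 0.3]] and the
# argmax of the logits is index 1, the example contributes 0.9 to the summed
# score; train() and evaluate() below divide that sum by the dataset size.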


def train(model, train_loader, eval_loader, num_epochs, output, opt, wd):
    utils.create_dir(output)
    # The paper uses Adadelta; the alternatives are kept for experiments.
    if opt == 'Adadelta':
        optim = torch.optim.Adadelta(model.parameters(), rho=0.95, eps=1e-6,
                                     weight_decay=wd)
    elif opt == 'RMSprop':
        optim = torch.optim.RMSprop(model.parameters(), lr=0.01, alpha=0.99,
                                    eps=1e-8, weight_decay=wd, momentum=0,
                                    centered=False)
    elif opt == 'Adam':
        optim = torch.optim.Adam(model.parameters(), lr=0.001, betas=(0.9, 0.999),
                                 eps=1e-8, weight_decay=wd)
    else:
        optim = torch.optim.Adamax(model.parameters(), weight_decay=wd)

    logger = utils.Logger(os.path.join(output, 'log.txt'))
    best_eval_score = 0
    for epoch in range(num_epochs):
        total_loss = 0
        train_score = 0
        t = time.time()
        for v, b, q, a in train_loader:
            v = v.cuda()
            b = b.cuda()  # bounding boxes (not used by all models)
            q = q.cuda()
            a = a.cuda()  # soft ground-truth answer scores

            pred = model(v, b, q, a)
            loss = instance_bce_with_logits(pred, a)
            loss.backward()
            nn.utils.clip_grad_norm_(model.parameters(), 0.25)
            optim.step()
            optim.zero_grad()

            total_loss += loss.item() * v.size(0)
            train_score += compute_score_with_logits(pred, a).sum().item()
        total_loss /= len(train_loader.dataset)
        train_score = 100 * train_score / len(train_loader.dataset)

        # Evaluate in eval mode, then restore training mode.
        model.train(False)
        eval_score, bound, V_loss = evaluate(model, eval_loader)
        model.train(True)

        logger.write('epoch %d, time: %.2f' % (epoch, time.time() - t))
        logger.write('\ttrain_loss: %.3f, score: %.3f' % (total_loss, train_score))
        logger.write('\teval loss: %.3f, score: %.3f (%.3f)'
                     % (V_loss, 100 * eval_score, 100 * bound))

        # Checkpoint the weights whenever the validation score improves.
        if eval_score > best_eval_score:
            model_path = os.path.join(output, 'model.pth')
            torch.save(model.state_dict(), model_path)
            best_eval_score = eval_score


@torch.no_grad()
def evaluate(model, dataloader):
    score = 0
    V_loss = 0
    upper_bound = 0
    for v, b, q, a in dataloader:
        v = v.cuda()
        b = b.cuda()
        q = q.cuda()
        a = a.cuda()

        pred = model(v, b, q, None)
        loss = instance_bce_with_logits(pred, a)
        V_loss += loss.item() * v.size(0)
        score += compute_score_with_logits(pred, a).sum().item()
        # Upper bound: the score of an oracle that always picks the
        # ground-truth answer with the highest soft score.
        upper_bound += a.max(1)[0].sum().item()

    score = score / len(dataloader.dataset)
    V_loss /= len(dataloader.dataset)
    upper_bound = upper_bound / len(dataloader.dataset)
    return score, upper_bound, V_loss
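

# A minimal CPU smoke test for the two helpers above, assuming nothing beyond
# this file; the tensors are made-up toy values for illustration only.
if __name__ == '__main__':
    logits = torch.tensor([[2.0, -1.0, 0.5],
                           [-0.5, 1.5, 0.0]])
    labels = torch.tensor([[1.0, 0.0, 0.3],
                           [0.0, 0.6, 0.0]])
    # BCE averaged over the batch, summed over the 3 candidate answers.
    print('loss :', instance_bce_with_logits(logits, labels).item())
    # Soft score of each example's argmax answer: 1.0 + 0.6 = 1.6 here.
    print('score:', compute_score_with_logits(logits, labels).sum().item())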