-
Notifications
You must be signed in to change notification settings - Fork 2
/
train.py
121 lines (93 loc) · 3.83 KB
/
train.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
import argparse
import os
import logging
import numpy as np
import torch
import torch.nn as nn
import torch.optim as optim
from torch.autograd import Variable
import utils
import data_loader
# Command-line interface for the training script.
parser = argparse.ArgumentParser()
parser.add_argument('--data_dir', default='./data/train',
                    help="Directory containing the dataset")
parser.add_argument('--model', type=str, required=True,
                    help="The model you want to train")
parser.add_argument('--lr', type=float, default=0.001,
                    help="Learning rate")
parser.add_argument('--epoch', type=int, default=50,
                    help="Total training epochs")
parser.add_argument('--batch_size', type=int, default=256,
                    help="batch size")
# BUG FIX: the original used default='False' (a truthy *string*), so
# `args.gpu` was always truthy even without --gpu, forcing .cuda() calls
# on CPU-only machines. store_true flags must default to the boolean False.
parser.add_argument('--gpu', action='store_true', default=False,
                    help="GPU available")
def train(model, optimizer, loss_fn, dataloader):
    """Train the model for `epochs` epochs and checkpoint the best result.

    NOTE(review): relies on module-level globals assigned in __main__:
    `args` (for args.gpu), `epochs`, and `model_name` (checkpoint dir name).

    Args:
        model      : (torch.nn.Module) the network to train
        optimizer  : (torch.optim.Optimizer) optimizer for the model parameters
        loss_fn    : callable taking (output_batch, labels_batch) and returning the loss
        dataloader : (torch.utils.data.DataLoader) yields (inputs, labels) training batches
    """
    # set model to training mode (enables dropout / batch-norm updates)
    model.train()
    model_dir = './results/' + model_name
    best_acc = 0.0
    for epoch in range(epochs):
        epoch_loss = 0.0
        epoch_correct = 0.0
        for train_batch, labels_batch in dataloader:
            # move to GPU if available
            if args.gpu:
                train_batch, labels_batch = train_batch.cuda(), labels_batch.cuda()
            # NOTE: the deprecated Variable() wrapper was removed — it has been
            # a no-op since torch 0.4; tensors are used directly.
            # compute model output and loss
            output_batch = model(train_batch)
            loss = loss_fn(output_batch, labels_batch)
            # clear previous gradients, compute gradients of all variables wrt loss
            optimizer.zero_grad()
            loss.backward()
            # performs updates using calculated gradients
            optimizer.step()
            epoch_loss += loss.item()
            acc = utils.accuracy(output_batch.data.cpu().numpy(),
                                 labels_batch.data.cpu().numpy())
            epoch_correct += acc
        # mean accuracy over all batches of this epoch
        epoch_acc = epoch_correct / len(dataloader)
        print("Epoch [{}/{}]\t Loss:{:.4f}\t Accuracy:{:.4f}%".format(
            epoch + 1,
            epochs,
            epoch_loss / len(dataloader),
            100 * epoch_acc
        ))
        # BUG FIX: the original compared only the *last batch's* accuracy
        # (`acc`) against best_acc, and stored the batch index `i` as the
        # checkpoint epoch. Compare the full-epoch accuracy and record the
        # actual epoch number instead.
        is_best = epoch_acc >= best_acc
        if is_best:
            logging.info("- Found new best accuracy")
            best_acc = epoch_acc
        utils.save_checkpoints(
            {'epoch': epoch + 1,
             'state_dict': model.state_dict(),
             'optim_dict': optimizer.state_dict()},
            is_best=is_best,
            checkpoint=model_dir
        )
if __name__ == '__main__':
    # BUG FIX: configure logging — without basicConfig the root logger stays
    # at WARNING and every logging.info() call below is silently dropped.
    logging.basicConfig(level=logging.INFO)

    # Load the parameters from parser
    args = parser.parse_args()
    model_name = args.model
    lr = args.lr
    epochs = args.epoch
    batch_size = args.batch_size

    logging.info("Loading the training dataset...")
    # fetch train dataloader
    train_dataloader = data_loader.train_data_loader()
    logging.info("- done.")

    # Define the model and optimizer
    model = utils.get_network(args)
    optimizer = utils.get_optimizer(model_name, model, lr)
    # fetch loss function
    loss_fn = nn.CrossEntropyLoss()

    # Train the model
    logging.info("Starting training for {} epoch(s).".format(epochs))
    train(model, optimizer, loss_fn, train_dataloader)