-- train-SRResNet.lua: train SRResNet model.
require 'nn'
require 'image'
require 'cunn'
require 'cudnn'
-- debugger = require 'fb.debugger'
-- Command-line interface of the training script.
-- `cmd` is assigned as a global here (no `local`), which is preserved so that
-- any code relying on it after this point keeps working.
cmd = torch.CmdLine()
cmd:text()
cmd:text('Train SRResNet model.')
cmd:text()
cmd:text('Options')
do
  -- One row per flag: { name, default value, help text }.
  -- Rows are registered in this order, which is also the order of the
  -- original sequence of cmd:option calls.
  local option_specs = {
    { '-model_name', '9x9-15res-LR24', 'will save checkpoints in checkpoints/model_name/ ' },
    { '-checkpoint_start_from', '', 'start training from checkpoint if given. If not given, train from scratch' },
    { '-arch', '', 'if checkpoint not and arch is given, use the architecture' },
    { '-lr', 10e-4, 'learning rate' },               -- note: 10e-4 is 1e-3
    { '-beta', 0.9, 'beta' },
    -- { '-iter_start', 1, 'not to overwrite previous trained model when resumed. ' },  -- disabled
    { '-iter_end', 10e6, 'iter to end training' },   -- note: 10e6 is 1e7
    { '-checkpoint_save_iter', 10000, 'saver period' },
  }
  for _, spec in ipairs(option_specs) do
    cmd:option(spec[1], spec[2], spec[3])
  end
end
cmd:text()

-- Parsed options; `arg` is the script's argument vector when run from the
-- command line, and an empty table is used when it is absent.
local opt = cmd:parse(arg or {})
print(opt)
-- Select the network to train and the iteration to start counting from.
--   * -checkpoint_start_from given: resume the saved model and continue at the
--     iteration after the one stored in the checkpoint.
--   * otherwise: build a fresh model from the module named by -arch, or from
--     'models.resnet-deconv2' when -arch is empty, and start at iteration 1.
-- `model` and `iter_start` are assigned as globals, exactly as before.
local resume_path = opt.checkpoint_start_from
if resume_path:len() > 0 then
  local loaded_checkpoint = torch.load(resume_path) -- resume training
  model = loaded_checkpoint.model
  iter_start = loaded_checkpoint.iter + 1
else
  local arch_module = 'models.resnet-deconv2' -- default: train from scratch
  if opt.arch:len() > 0 then
    arch_module = opt.arch
  end
  model = require(arch_module)
  iter_start = 1
end
-- Move the model to the GPU before anything takes references to its parameters.
model:cuda()
-- -- model = torch.load('models/resnet-deconv30000.t7')
-- model = require 'models.resnet-deconv2'
-- model:cuda
-- Directory that receives the periodic checkpoints: checkpoints/<model_name>.
local saveCheckpointPath = paths.concat('checkpoints/', opt.model_name)

-- Loss function: mean squared error between the network output and the
-- target batch, evaluated on the GPU.
local loss = nn.MSECriterion():cuda()

-- Parameter and gradient tensors of the whole model, as returned by
-- nn.Module:getParameters(). Taken after model:cuda() so they refer to the
-- GPU copies that the optimiser updates.
local theta, gradTheta = model:getParameters()

-- Configuration passed to optim.adam.
-- optim.adam reads its first-moment decay rate from `beta1`; the key used
-- previously (`optim_beta`) is not one it looks at, so the -beta flag was
-- silently ignored. `beta2` and `epsilon` are left at the optimiser defaults.
local config = {
  learningRate = opt.lr, -- default 10e-4 (i.e. 1e-3)
  beta1 = opt.beta,      -- default 0.9
}

-- Mutable optimiser state, filled in by optim.adam across iterations.
local optim_state = {}

require 'optim'
require 'src.util'
-- Batch descriptor shared by setBatch and feval. Besides the image index
-- loaded or built below, feval reads the fields `LR` (network input) and `SR`
-- (target) from it; presumably setBatch fills those in - confirm in src.util.
local imgBatch

-- Flip to true to rebuild the image index from the dataset directory;
-- with false the cached index in 'imgBatch.t7' is loaded instead.
-- (An earlier variant used prepImgs on VOC2012/JPEGImages rather than ImageNet.)
local do_prepImageNet = false
if not do_prepImageNet then
  imgBatch = torch.load('imgBatch.t7')
else
  imgBatch = {}
  local imagenet_root = "/home/junho/data/ImageNet/"
  imgBatch.imgPaths, imgBatch.imgNum = prepImageNet(imagenet_root)
  print('prepImageNet')
  -- Cache the index so later runs can skip the directory scan.
  torch.save('imgBatch.t7', imgBatch)
end

-- Batch settings; presumably the batch size and crop resolution used by
-- setBatch - confirm in src.util.
imgBatch.batchNum = 16
imgBatch.res = 96
print('ImageNet loaded, # of imgs:' .. imgBatch.imgNum)
-- Objective closure handed to optim.adam.
-- Runs one forward/backward pass of `model` on the current contents of
-- `imgBatch` (LR as input, SR as target) and accumulates gradients into
-- `gradTheta`.
-- @param theta  parameter tensor supplied by the optimiser; it is not read
--               here, the model is evaluated with its current parameters.
-- @return the loss value for the batch, and the gradient tensor `gradTheta`.
function feval(theta)
  -- Gradients accumulate across backward calls, so clear them first.
  gradTheta:zero()

  local low_res = imgBatch.LR
  local target = imgBatch.SR

  local prediction = model:forward(low_res)
  local cost = loss:forward(prediction, target)
  local grad_prediction = loss:backward(prediction, target)
  print(cost) -- per-iteration loss trace

  model:backward(low_res, grad_prediction)
  return cost, gradTheta
end
require 'optim'

-- Writes a checkpoint for iteration `iter` to
-- <saveCheckpointPath>/<iter>.t7. The checkpoint stores the parsed options,
-- the iteration number (a resumed run continues at iter + 1) and the model.
local function save_checkpoint(iter)
  local checkpoint = {}
  checkpoint.opt = opt
  checkpoint.iter = iter
  checkpoint.model = model
  print('saving model' .. iter)
  -- Report only when paths.mkdir returns a true value.
  if paths.mkdir(saveCheckpointPath) then
    print(saveCheckpointPath .. ': new folder to save model')
  end
  torch.save(saveCheckpointPath .. '/' .. iter .. '.t7', checkpoint)
  print('saved model, next will be: ' .. iter+1)
end

-- Main training loop: one Adam step per iteration on a freshly prepared batch,
-- from iter_start (1, or checkpoint.iter + 1 when resuming) up to -iter_end.
for iter = iter_start, opt.iter_end do
  setBatch(imgBatch)
  print('iter:' .. iter) -- debug
  optim.adam(feval, theta, config, optim_state)

  -- Save every -checkpoint_save_iter iterations.
  local is_save_step = (iter % opt.checkpoint_save_iter == 0)
  if is_save_step then
    save_checkpoint(iter)
  end
end