-
Notifications
You must be signed in to change notification settings - Fork 6
/
liveDemo.lua
89 lines (74 loc) · 2.2 KB
/
liveDemo.lua
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
local cv = require 'cv'
require 'cv.highgui'
require 'cv.videoio'
require 'cv.imgproc'
require 'nngraph'
require 'nn'
require 'image'
require 'nn'
require 'cutorch'
require 'cunn'
require 'cudnn'
-- Command-line options. Note: preprocess() pairs each new frame with the one
-- (shift - 1) captures earlier, so the default shift of 2 pairs consecutive
-- frames.
local opt = lapp[[
--input_height (default 512) frame_height, must be a multiple of 64
--model (default 'FlowNetS_SmallDisp.t7')
--video (default '')
--output_height (default 512) height of output flowmap
--shift (default 2) temporal distance between 2 frames
--verbose
]]
-- With shift < 2 no frame pair is ever formed, and a non-integer shift would
-- defeat the buffer-trim test in preprocess(); reject both up front.
assert(opt.shift >= 2 and opt.shift % 1 == 0,
       '--shift must be an integer >= 2')

-- An empty --video selects the default camera (device 0). The option is
-- named `video`; the previous code read the undefined `opt.video_path`, which
-- made --video a no-op.
local use_camera = opt.video == ''
local cap = cv.VideoCapture{use_camera and 0 or opt.video}
if not cap:isOpened() then
   if use_camera then
      print("Failed to open the default camera")
   else
      print("Failed to open video " .. opt.video)
   end
   os.exit(-1)
end

-- The window name must match the one passed to cv.imshow (opt.model);
-- the previous code used the undefined `opt.network`.
cv.namedWindow{opt.model, cv.WINDOW_AUTOSIZE}

-- cap:read returns a status value followed by the frame tensor.
local ok, frame = cap:read{}
if not ok then
   print("Failed to read the first frame")
   os.exit(-1)
end

-- State shared with preprocess(): a sliding buffer of preprocessed frames and
-- the most recent stacked frame pair (nil until a pair has been formed).
local frames_array = {}
local frames = nil
--- Preprocess a captured frame and pair it with an earlier buffered frame.
-- The frame is reordered with permute(3,2,1), converted to float and divided
-- by 255. It is then center-cropped to size(3) x size(3)*4/3 and rescaled
-- with image.scale(..., opt.input_height).
-- The result is appended to frames_array. If the buffer already held a frame,
-- the new frame and the oldest buffered frame are concatenated along
-- dimension 1 into `frames`. Dimension 1 is presumably the channel dim if
-- `frame` is HxWxC; TODO confirm.
-- The oldest entry is dropped once the buffer reaches opt.shift entries, so
-- (for an integer shift >= 2) the paired frames are (opt.shift - 1) captures
-- apart.
-- @param frame captured frame tensor from cap:read
-- @return the current stacked frame pair; nil until a pair has been formed
local function preprocess(frame)
   -- Oldest buffered frame, read before the new one is appended.
   local previous = frames_array[1]

   local current = frame:permute(3, 2, 1):float() / 255
   current = image.crop(current, 'c', current:size(3), current:size(3) * 1.3333333)
   current = image.scale(current, opt.input_height)
   frames_array[#frames_array + 1] = current

   if previous then
      -- Newest frame first, then the older one.
      frames = torch.cat(current, previous, 1)
   end

   -- `>=` (rather than `==`) keeps the buffer bounded even for a shift that
   -- the length can never equal exactly; identical for integer shifts >= 1.
   if #frames_array >= opt.shift then
      table.remove(frames_array, 1)
   end
   return frames
end
-- Load the model onto the GPU, switch it to evaluation mode, and convert
-- supported layers to their cudnn implementations.
local net = torch.load(opt.model):cuda()
net:evaluate()
net = cudnn.convert(net, cudnn) --will trigger a warning for graph but it still works

-- Prime preprocess() so a frame pair exists before the main loop. Both calls
-- use the same captured frame, so this first pair is two copies of it.
preprocess(frame)
preprocess(frame)
if frames == nil then
   error('no frame pair after priming preprocess(); --shift must be >= 2')
end

if opt.verbose then
   -- Only copy to the GPU when the result is actually inspected.
   local input = frames:cuda()
   print({input})
   print({net:forward(input)}) --to check what the output of the network is like
end
-- Main loop: run the network on the current frame pair, render the result as
-- a colour image and display it. Stops on a key press or when the capture
-- stops delivering frames.
while true do
   local input = preprocess(frame):cuda()
   local output = net:forward(input)
   -- Multi-output models return a table; this script visualises entry 5
   -- (presumably one particular output scale; TODO confirm for the model).
   if torch.type(output) == 'table' then
      output = output[5]:float()
   else
      output = output:float()
   end

   -- Build a 3-channel image: channel 1 is filled with 255, and channels 2-3
   -- take the scaled network output. This assumes output has exactly 2
   -- channels (flow components); TODO confirm.
   local output_ = torch.FloatTensor(3, output:size(2), output:size(3))
   if opt.verbose then
      print(output:max())
      print(output:min())
   end
   output_[1]:fill(255)
   output_[{{2,3}}]:copy(100*(512/opt.input_height)*output)
   local out = image.scale(image.yuv2rgb(output_), opt.output_height)
   -- Undo preprocess()'s permute(3,2,1) for display and clamp to byte range.
   cv.imshow{opt.model, torch.clamp(out:permute(3,2,1),0,255):byte()}
   if cv.waitKey{1} >= 0 then break end

   -- cap:read returns a status value first. Stop when no new frame arrives
   -- (e.g. at the end of a video file) instead of reprocessing the last
   -- frame forever.
   local ok = cap:read{frame}
   if not ok then break end
end