-
Notifications
You must be signed in to change notification settings - Fork 10
/
lib.py
64 lines (50 loc) · 1.95 KB
/
lib.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
import os, sys
current_dir = os.path.dirname(__file__)
sys.path.insert(0, current_dir)
import numpy as np
import torch
from model import ReCoNet
from utils import preprocess_for_reconet, postprocess_reconet, Dummy, nhwc_to_nchw, nchw_to_nhwc
sys.path.remove(current_dir)
class ReCoNetModel:
    """Inference wrapper around a ReCoNet style-transfer network.

    Loads weights from a state-dict file, optionally moves the network to a
    CUDA device, and exposes :meth:`run` to style ``np.uint8`` images.
    """

    def __init__(self, state_dict_path, use_gpu=True, gpu_device=None, frn=False):
        """Build the network and load its weights.

        Args:
            state_dict_path: Path to a file readable by ``torch.load`` that
                holds the model's ``state_dict``.
            use_gpu: If true, the model (and the inputs in :meth:`run`) are
                moved to CUDA with ``.cuda()``.
            gpu_device: Optional CUDA device; when given and ``use_gpu`` is
                true, CUDA work runs inside ``torch.cuda.device(gpu_device)``.
            frn: Forwarded unchanged to the ``ReCoNet`` constructor.
        """
        self.use_gpu = use_gpu
        self.gpu_device = gpu_device
        with self.device():
            self.model = ReCoNet(frn=frn)
            # Deserialize onto the CPU first: without map_location, torch.load
            # restores every tensor to the device it was saved from, which
            # fails on CPU-only hosts (use_gpu=False) or when that GPU index
            # does not exist here. to_device() below moves weights as needed.
            # NOTE(review): torch.load unpickles the file -- only load trusted
            # checkpoints.
            state_dict = torch.load(state_dict_path, map_location="cpu")
            self.model.load_state_dict(state_dict)
            self.model = self.to_device(self.model)
        self.model.eval()

    def run(self, images):
        """Style a single image or a batch of images.

        Args:
            images: ``np.uint8`` array with 3 or 4 dimensions. A 3-D array is
                treated as one image and given a leading batch axis. The
                layout is presumably (N,) H, W, C given the ``nhwc_to_nchw``
                conversion -- confirm against that helper.

        Returns:
            ``np.uint8`` array with the same number of dimensions as the input
            (the temporary batch axis is dropped again for 3-D input).

        Raises:
            TypeError: If ``images.dtype`` is not ``np.uint8``.
            ValueError: If ``images.ndim`` is not 3 or 4.
        """
        # Explicit checks rather than ``assert`` so validation survives
        # ``python -O``; otherwise non-uint8 data would be silently mis-scaled.
        if images.dtype != np.uint8:
            raise TypeError(
                "images must have dtype uint8, got {}".format(images.dtype))
        if not 3 <= images.ndim <= 4:
            raise ValueError(
                "images must be 3-D or 4-D, got ndim={}".format(images.ndim))
        single = images.ndim == 3
        if single:
            images = images[None, ...]
        batch = torch.from_numpy(images)
        batch = nhwc_to_nchw(batch)
        # uint8 [0, 255] -> float32 [0, 1]
        batch = batch.to(torch.float32) / 255
        with self.device(), torch.no_grad():
            batch = self.to_device(batch)
            batch = preprocess_for_reconet(batch)
            styled = self.model(batch)
            styled = postprocess_reconet(styled)
        styled = styled.cpu()
        # Scale back to [0, 255] and clamp; the uint8 cast truncates rather
        # than rounds.
        styled = torch.clamp(styled * 255, 0, 255).to(torch.uint8)
        styled = nchw_to_nhwc(styled)
        result = styled.numpy()
        if single:
            result = result[0]
        return result

    def to_device(self, x):
        """Return ``x`` moved to CUDA when ``use_gpu`` is set, else ``x`` itself.

        ``x`` is anything with a ``.cuda()`` method (tensors and modules are
        both passed in this class). The move happens inside :meth:`device`,
        so it targets ``gpu_device`` when one was configured.
        """
        if not self.use_gpu:
            return x
        with self.device():
            return x.cuda()

    def device(self):
        """Return a context manager that selects the configured CUDA device.

        Returns ``torch.cuda.device(self.gpu_device)`` when ``use_gpu`` is true
        and ``gpu_device`` is not None; otherwise a ``Dummy()`` instance
        (presumably a no-op context manager from ``utils`` -- confirm).
        """
        if self.use_gpu and self.gpu_device is not None:
            return torch.cuda.device(self.gpu_device)
        return Dummy()