-
Notifications
You must be signed in to change notification settings - Fork 3
/
utils.py
230 lines (196 loc) · 9.01 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
from torch.optim import lr_scheduler
import torch.nn as nn
import os
import math
import yaml
import torch.nn.init as init
import torch
import functools
import cv2
import numpy as np
def get_config(config):
    """Load a YAML configuration file.

    Parameters:
        config (str) -- path to the YAML file

    Returns:
        The parsed YAML document (the rest of this module indexes it like a
        dict, e.g. cfg['lr_policy']).

    ``yaml.safe_load`` is used rather than a bare ``yaml.load(stream)``: since
    PyYAML 6.0 ``yaml.load`` requires an explicit ``Loader`` and raises
    TypeError without one, and the full loader can construct arbitrary Python
    objects from tagged YAML. Plain config files (maps, lists, scalars) parse
    identically.
    """
    with open(config, 'r') as stream:
        return yaml.safe_load(stream)
def prepare_sub_folder(output_directory):
    """Ensure the 'images', 'checkpoints' and 'results' sub-folders exist.

    Parameters:
        output_directory (str) -- parent directory of the three sub-folders

    Returns:
        (checkpoint_directory, image_directory, results_directory), each the
        result of ``os.path.join(output_directory, <name>)``.

    A "Creating directory: ..." line is printed for every sub-folder that did
    not exist yet.
    """
    paths = {}
    # Creation order: images, checkpoints, results.
    for name in ('images', 'checkpoints', 'results'):
        path = os.path.join(output_directory, name)
        if not os.path.exists(path):
            print("Creating directory: {}".format(path))
            # exist_ok avoids a FileExistsError if another process creates the
            # directory between the existence check and this call.
            os.makedirs(path, exist_ok=True)
        paths[name] = path
    return paths['checkpoints'], paths['images'], paths['results']
# dataset: file extensions that make_dataset() treats as images.
# Every base extension is listed in lower case followed by its upper-case form.
IMG_EXTENSIONS = [
    variant
    for base in ('.jpg', '.jpeg', '.png', '.ppm', '.bmp')
    for variant in (base, base.upper())
]


def is_image_file(filename):
    """Return True when *filename* ends with an entry of ``IMG_EXTENSIONS``.

    Matching is case-sensitive, so mixed-case spellings such as '.Png' are
    not recognised.
    """
    return filename.endswith(tuple(IMG_EXTENSIONS))
def make_dataset(dir, max_dataset_size=float("inf")):
    """Recursively collect image file paths under *dir*.

    Parameters:
        dir (str) -- root directory to scan (the name shadows the builtin but
            is kept for keyword-argument compatibility)
        max_dataset_size (int or float) -- maximum number of paths returned;
            the default ``inf`` means no limit, and a float such as 100.0 is
            accepted (truncated to an int)

    Returns:
        list of ``os.path.join(root, fname)`` paths for which ``is_image_file``
        is true, ordered by directory and then by file name.

    Raises:
        AssertionError -- if *dir* is not a directory (when assertions are
            enabled).
    """
    assert os.path.isdir(dir), '%s is not a valid directory' % dir
    images = []
    for root, _, fnames in sorted(os.walk(dir)):
        # os.walk reports file names in arbitrary (file-system) order; sort
        # them so the dataset order is reproducible across machines.
        for fname in sorted(fnames):
            if is_image_file(fname):
                images.append(os.path.join(root, fname))
    if max_dataset_size < len(images):
        # Slice bounds must be integers.
        images = images[:int(max_dataset_size)]
    return images
def get_scheduler(optimizer, cfg):
    """Return a learning-rate scheduler for *optimizer*.

    Parameters:
        optimizer -- the optimizer of the network
        cfg (dict) -- experiment configuration; ``cfg['lr_policy']`` selects
            the policy: linear | step | plateau | cosine. Other keys read:
              linear -- 'epoch_start', 'epoch_init_lr', 'niter_decay'
              step   -- 'step_size'
              cosine -- 'niter_decay'

    'linear' uses the multiplier
        1 - max(0, epoch + epoch_start - epoch_init_lr) / (niter_decay + 1),
    so the initial rate is kept until ``epoch + epoch_start`` reaches
    ``epoch_init_lr`` and then falls linearly, reaching 1 / (niter_decay + 1)
    of the initial rate after ``niter_decay`` more epochs (it keeps falling,
    eventually below zero, if stepped further). The cfg values are read each
    time the multiplier is evaluated.
    'step' multiplies the rate by 0.1 every ``step_size`` epochs; 'plateau'
    is ReduceLROnPlateau(mode='min', factor=0.2, threshold=0.01, patience=5);
    'cosine' anneals to 0 over ``niter_decay`` epochs.
    See https://pytorch.org/docs/stable/optim.html for more details.

    Raises:
        NotImplementedError -- if ``cfg['lr_policy']`` is none of the above.
    """
    policy = cfg['lr_policy']
    if policy == 'linear':
        def lambda_rule(epoch):
            decayed_epochs = max(0, epoch + cfg['epoch_start'] - cfg['epoch_init_lr'])
            lr_l = 1.0 - decayed_epochs / float(cfg['niter_decay'] + 1)
            return lr_l
        scheduler = lr_scheduler.LambdaLR(optimizer, lr_lambda=lambda_rule)
    elif policy == 'step':
        scheduler = lr_scheduler.StepLR(optimizer, step_size=cfg['step_size'], gamma=0.1)
    elif policy == 'plateau':
        scheduler = lr_scheduler.ReduceLROnPlateau(optimizer, mode='min', factor=0.2, threshold=0.01, patience=5)
    elif policy == 'cosine':
        scheduler = lr_scheduler.CosineAnnealingLR(optimizer, T_max=cfg['niter_decay'], eta_min=0)
    else:
        # Raise (not return) so a misconfigured policy fails here instead of
        # handing the caller an exception object in place of a scheduler.
        raise NotImplementedError('learning rate policy [%s] is not implemented' % policy)
    return scheduler
def weights_init(init_type='gaussian'):
    """Build a callable that re-initialises Conv*/Linear* modules in place.

    Typical use is ``net.apply(weights_init(...))``. Only modules whose class
    name starts with 'Conv' or 'Linear' and that have a ``weight`` attribute
    are touched; for them the weight is re-drawn according to *init_type* and
    a non-None ``bias`` is set to zero (this also happens for 'default').

    Parameters:
        init_type (str) -- one of
            gaussian   -- normal_(mean=0.0, std=0.02)
            xavier     -- xavier_normal_(gain=sqrt(2))
            kaiming    -- kaiming_normal_(a=0, mode='fan_in')
            orthogonal -- orthogonal_(gain=sqrt(2))
            default    -- keep the module's own weight values
        Any other value fails an ``assert`` the first time a matching module
        is visited.
    """
    def init_fun(m):
        class_name = m.__class__.__name__
        if not (class_name.startswith(('Conv', 'Linear')) and hasattr(m, 'weight')):
            return
        if init_type == 'gaussian':
            init.normal_(m.weight.data, 0.0, 0.02)
        elif init_type == 'xavier':
            init.xavier_normal_(m.weight.data, gain=math.sqrt(2))
        elif init_type == 'kaiming':
            init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
        elif init_type == 'orthogonal':
            init.orthogonal_(m.weight.data, gain=math.sqrt(2))
        elif init_type != 'default':
            assert 0, "Unsupported initialization: {}".format(init_type)
        if getattr(m, 'bias', None) is not None:
            init.constant_(m.bias.data, 0.0)
    return init_fun
def save_network(net, label, epoch, cfg):
    """Save ``net``'s state_dict to ``<checkpoints_dir>/<epoch:03d> _<label>.pth``.

    Parameters:
        net (torch.nn.Module) -- network to save
        label (str) -- name embedded in the file name
        epoch (int) -- epoch number, zero-padded to three digits
        cfg (dict) -- configuration; ``cfg['checkpoints_dir']`` is the target
            directory

    Every tensor is copied to the CPU before saving, so the checkpoint loads
    on machines without a GPU. The network itself is not moved. A
    ``net.cpu()`` / ``net.cuda()`` round trip would put a CPU-resident
    network on the GPU whenever CUDA is available, and would move a network
    living on ``cuda:N`` to the default device.
    """
    # NOTE(review): the space in '%03d _%s' is kept as-is; presumably loading
    # code elsewhere builds the same file name — confirm before changing it.
    save_filename = '%03d _%s.pth' % (epoch, label)
    save_path = os.path.join(cfg['checkpoints_dir'], save_filename)
    state = net.state_dict()
    cpu_state = type(state)(
        (key, value.cpu() if torch.is_tensor(value) else value)
        for key, value in state.items()
    )
    # Carry over the per-module version metadata that load_state_dict reads.
    if hasattr(state, '_metadata'):
        cpu_state._metadata = state._metadata
    torch.save(cpu_state, save_path)
def save_latest_network(net, epoch, label, cfg):
    """Save ``{'epoch': epoch, 'net': state_dict}`` to ``<checkpoints_dir>/latest_<label>.pth``.

    Parameters:
        net (torch.nn.Module) -- network to save
        epoch (int) -- epoch stored alongside the weights
        label (str) -- name embedded in the file name
        cfg (dict) -- configuration; ``cfg['checkpoints_dir']`` is the target
            directory
    Note the argument order (net, epoch, label, cfg) differs from
    ``save_network`` (net, label, epoch, cfg).

    Every tensor is copied to the CPU before saving, and the network itself
    is not moved. A ``net.cpu()`` / ``net.cuda()`` round trip would put a
    CPU-resident network on the GPU whenever CUDA is available, and would
    move a network on ``cuda:N`` to the default device.
    """
    save_filename = 'latest_%s.pth' % (label)
    save_path = os.path.join(cfg['checkpoints_dir'], save_filename)
    state = net.state_dict()
    cpu_state = type(state)(
        (key, value.cpu() if torch.is_tensor(value) else value)
        for key, value in state.items()
    )
    # Carry over the per-module version metadata that load_state_dict reads.
    if hasattr(state, '_metadata'):
        cpu_state._metadata = state._metadata
    save_file = {'epoch': epoch, 'net': cpu_state}
    torch.save(save_file, save_path)
# Find the checkpoint to resume from.
def get_model_list(dirname, key):
    """Return the path of the newest checkpoint in *dirname* matching *key*.

    Candidates are regular files whose name contains both *key* and '.pt'
    (which also matches '.pth'). They are sorted lexicographically by full
    path and the last one is returned. With zero-padded epoch prefixes
    ('%03d _...') this is the highest epoch; a matching 'latest_...' file
    sorts after digit-prefixed names and therefore wins if present.

    Parameters:
        dirname (str) -- directory to search
        key (str) -- substring the file name must contain

    Returns:
        The path (str), or None if *dirname* does not exist or contains no
        matching file.
    """
    if not os.path.exists(dirname):
        return None
    gen_models = [os.path.join(dirname, f) for f in os.listdir(dirname)
                  if os.path.isfile(os.path.join(dirname, f)) and key in f and ".pt" in f]
    # The comprehension yields an empty list, never None; without this guard
    # an empty directory would raise IndexError below.
    if not gen_models:
        return None
    gen_models.sort()
    # Last entry after sorting is the newest checkpoint ([0] is the oldest).
    last_model_name = gen_models[-1]
    return last_model_name
def get_norm_layer(norm_type='instance'):
    """Return a factory for the requested 2-D normalization layer.

    Parameters:
        norm_type (str) -- batch | instance | none

    Returns:
        A ``functools.partial`` of the layer class with its options bound, or
        None for 'none':
          batch    -> nn.BatchNorm2d(affine=True, track_running_stats=True):
                      learnable affine parameters and running mean/var
          instance -> nn.InstanceNorm2d(affine=False, track_running_stats=False):
                      no affine parameters and no running statistics

    Raises:
        NotImplementedError -- for any other *norm_type*.
    """
    supported = (
        ('batch', nn.BatchNorm2d, True),
        ('instance', nn.InstanceNorm2d, False),
    )
    for name, layer_cls, stateful in supported:
        if norm_type == name:
            # For both layers a single flag drives affine and running stats.
            return functools.partial(layer_cls, affine=stateful, track_running_stats=stateful)
    if norm_type == 'none':
        return None
    raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
def tensor2im(input_image, imtype=np.uint8, no_fg=True):
    """Convert the first image of a batched tensor into an H x W x C numpy image.

    Parameters:
        input_image -- a torch.Tensor (element [0] is taken and transposed
            from (C, H, W) to (H, W, C), so presumably N x C x H x W), a
            numpy array, or any other object
        imtype (type) -- dtype of the returned array
        no_fg (bool) -- True: values are mapped from [-1, 1] to [0, 255] via
            (x + 1) / 2 * 255; False: values are only multiplied by 255
            (treated as [0, 1], e.g. a binary image). Either way the result
            is clipped to [0, 255] before the cast.

    Returns:
        numpy array of dtype *imtype*. A numpy input is only cast; an input
        that is neither a tensor nor an array is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image
    chw = input_image.data[0].cpu().float().numpy()
    hwc = np.transpose(chw, (1, 2, 0))
    if no_fg:
        scaled = (hwc + 1) / 2.0 * 255.0
    else:
        scaled = hwc * 255.0
    return np.clip(scaled, 0, 255).astype(imtype)
def lab2im(input_image, imtype=np.uint8):
    """Convert the first image of a batched tensor into an H x W x C numpy array.

    Unlike ``tensor2im`` no rescaling is done: values are only clipped to
    [0, 255] and cast, so the tensor is expected to already hold values on
    that scale (NOTE(review): presumably Lab/label data, going by the name —
    confirm with callers).

    Parameters:
        input_image -- a torch.Tensor (element [0] is taken and transposed
            from (C, H, W) to (H, W, C)), a numpy array, or any other object
        imtype (type) -- dtype of the returned array

    Returns:
        numpy array of dtype *imtype*. A numpy input is only cast; an input
        that is neither a tensor nor an array is returned unchanged.
    """
    if isinstance(input_image, np.ndarray):
        return input_image.astype(imtype)
    if not isinstance(input_image, torch.Tensor):
        return input_image
    chw = input_image.data[0].cpu().float().numpy()
    hwc = chw.transpose(1, 2, 0)
    return np.clip(hwc, 0, 255).astype(imtype)
def list_ave(list):
    """Return the arithmetic mean of the items in *list*.

    Items are accumulated with ``+=`` starting from 0.0, then the total is
    divided by ``len(list)``; an empty sequence raises ZeroDivisionError.
    The parameter name shadows the builtin but is kept for keyword callers.
    """
    total = 0.
    for item in list:
        total += item
    count = len(list)
    return total / count
def normalization(data):
    """Min-max scale *data* into [0, 1].

    Parameters:
        data -- numpy array or array-like accepted by np.min / np.max

    Returns:
        ``(data - min) / (max - min)``. When every value is equal (zero
        range) the result is all zeros rather than the NaNs, with a
        RuntimeWarning, that 0 / 0 would produce.
    """
    _min = np.min(data)
    _range = np.max(data) - _min
    if _range == 0:
        # data - _min is already all zeros here; dividing by 1 avoids 0 / 0.
        _range = 1
    return (data - _min) / _range
def tensor2hm(input):
    """Turn the first feature map of a batched tensor into a uint8 heat map.

    Element [0] of the tensor is taken and transposed from (C, H, W) to
    (H, W, C). The heat map is the per-pixel sum of absolute values over the
    channel axis, min-max scaled with ``normalization`` and multiplied by 255.

    Parameters:
        input -- a torch.Tensor, a numpy array, or any other object (the
            name shadows the builtin but is kept for keyword callers)

    Returns:
        H x W uint8 array for tensor input. A numpy array is treated as
        already converted and only cast to uint8, mirroring ``tensor2im`` and
        ``lab2im``; handling it explicitly fixes a path that left the result
        undefined. Any other object is returned unchanged.
    """
    if isinstance(input, np.ndarray):
        return input.astype(np.uint8)
    if not isinstance(input, torch.Tensor):
        return input
    image_numpy = input.data[0].cpu().float().numpy()
    image_numpy = np.transpose(image_numpy, (1, 2, 0))
    # np.abs instead of sqrt(x * x): same value without overflow/underflow
    # for extreme float32 magnitudes.
    magnitude = np.sum(np.abs(image_numpy), axis=2)
    out = normalization(magnitude) * 255
    return out.astype(np.uint8)
def print_options(opts):
    """Print every option in *opts* as an aligned two-column table.

    Parameters:
        opts -- an object with a ``__dict__`` (e.g. argparse.Namespace) or a
            mapping such as the dict configs used throughout this module
            (``vars()`` alone would reject a dict with TypeError)

    Options are sorted by name. Each row shows the name right-aligned in 25
    columns, then ': ', then the value left-aligned in 30 columns, framed by
    an "Options" header line and an "End" footer line.
    """
    from collections.abc import Mapping

    items = opts.items() if isinstance(opts, Mapping) else vars(opts).items()
    lines = ['----------------- Options ---------------']
    for k, v in sorted(items):
        lines.append('{:>25}: {:<30}'.format(str(k), str(v)))
    lines.append('----------------- End -------------------')
    print('\n'.join(lines))