-
Notifications
You must be signed in to change notification settings - Fork 22
/
synthesize.py
180 lines (165 loc) · 8.4 KB
/
synthesize.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
# python 3.7
"""Synthesize a collection of images with specified model."""
import os.path
import argparse
from collections import defaultdict
import numpy as np
from tqdm import tqdm
from models.helper import build_generator
from predictors.helper import build_predictor
from utils.logger import setup_logger
from utils.visualizer import HtmlPageVisualizer
from utils.visualizer import save_image
def parse_args():
    """Defines the command-line interface and parses the arguments.

    Returns:
        The `argparse.Namespace` produced by parsing the command line. It holds
        the positional `model_name` and every optional setting declared below.
    """
    # Each entry is `(flags, keyword arguments)`. Entries are registered in the
    # order listed, which is also the order used by the `--help` output.
    argument_specs = (
        (('model_name',), {
            'type': str,
            'help': 'Name of the model used for synthesis.',
        }),
        (('-o', '--output_dir'), {
            'type': str,
            'default': '',
            'help': 'Directory to save the results. If not specified, '
                    '`${MODEL_NAME}_synthesis` will be used by default.',
        }),
        (('-i', '--latent_codes_path'), {
            'type': str,
            'default': '',
            'help': 'If specified, will load latent codes from given '
                    'path instead of randomly sampling. (default: None)',
        }),
        (('-N', '--num'), {
            'type': int,
            'default': 0,
            'help': 'Number of images to generate. This field will be '
                    'ignored if `latent_codes_path` is valid. Otherwise '
                    'a positive number is required. (default: 0)',
        }),
        (('--latent_space_type',), {
            'type': str,
            'default': 'z',
            'choices': ['z', 'w', 'wp'],
            'help': 'Latent space used for synthesis in StyleGAN and '
                    'StyleGAN2. If the latent codes are loaded from '
                    'given path, they should align with the space type. '
                    '(default: `z`)',
        }),
        (('--skip_image',), {
            'action': 'store_true',
            'help': 'If specified, will skip generating images in '
                    'StyleGAN and StyleGAN2. '
                    '(default: DO generate images)',
        }),
        (('--generate_style',), {
            'action': 'store_true',
            'help': 'If specified, will generate layer-wise style codes '
                    'in StyleGAN and StyleGAN2. '
                    '(default: do NOT generate styles)',
        }),
        (('--generate_prediction',), {
            'action': 'store_true',
            'help': 'If specified, will predict semantics from '
                    'synthesized images. (default: False)',
        }),
        (('--predictor_name',), {
            'type': str,
            'default': 'scene',
            'help': 'Name of the predictor used for analysis. (default: '
                    'scene)',
        }),
        (('--save_raw_synthesis',), {
            'action': 'store_true',
            'help': 'If specified, will save raw synthesis to the disk. '
                    '(default: False)',
        }),
        (('--generate_html',), {
            'action': 'store_true',
            'help': 'If specified, will use html for visualization. '
                    '(default: False)',
        }),
        (('--html_row',), {
            'type': int,
            'default': 0,
            'help': 'Number of rows of the visualization html page. If '
                    'set as `0`, will be assigned based on number of '
                    'samples. (default: 0)',
        }),
        (('--html_col',), {
            'type': int,
            'default': 0,
            'help': 'Number of columns of the visualization html page. '
                    'If set as `0`, will be assigned based on number of '
                    'samples. (default: 0)',
        }),
        (('--viz_size',), {
            'type': int,
            'default': 0,
            'help': 'Image size for visualization on html page. Active '
                    'ONLY when `generate_html` is set as `True`. '
                    '`0` means to use the original synthesis size. '
                    '(default: 0)',
        }),
        (('--html_name',), {
            'type': str,
            'default': 'viz.html',
            'help': 'Name of the html page for visualization. Active '
                    'ONLY when `generate_html` is set as `True`. '
                    'If not specified, path `${OUTPUT_DIR}/viz.html` '
                    'will be used by default.',
        }),
        (('--logfile_name',), {
            'type': str,
            'default': 'log.txt',
            'help': 'Name of the log file. If not specified, log '
                    'message will be saved to path '
                    '`${OUTPUT_DIR}/log.txt` by default.',
        }),
    )

    cli = argparse.ArgumentParser(description='Synthesize images with GAN.')
    for flags, options in argument_specs:
        cli.add_argument(*flags, **options)
    return cli.parse_args()
def main():
    """Synthesizes samples with the requested generator and saves the results.

    Steps:
      1. Parse the command line and set up the logger for the work directory
         (`--output_dir`, or `{model_name}_synthesis` when that is empty).
      2. Build the generator and obtain latent codes: load and preprocess them
         when `--latent_codes_path` is an existing file, otherwise sample
         `--num` codes with the generator.
      3. Synthesize batch by batch. Images are optionally passed to the
         predictor, saved as `{index:06d}.jpg` and/or placed on the HTML page;
         every other output is collected per key.
      4. Save the HTML page, one `{key}.npy` per collected output, and the
         predictions.

    Raises:
        ValueError: If `--latent_codes_path` is not an existing file and
            `--num` is not positive.
    """
    args = parse_args()

    work_dir = args.output_dir or f'{args.model_name}_synthesis'
    logger_name = f'{args.model_name}_synthesis_logger'
    # NOTE(review): nothing below creates `work_dir`; presumably `setup_logger`
    # does -- confirm, since all later saves write into it.
    logger = setup_logger(work_dir, args.logfile_name, logger_name)

    logger.info('Initializing generator.')
    model = build_generator(args.model_name, logger=logger)

    logger.info('Preparing latent codes.')
    if os.path.isfile(args.latent_codes_path):
        logger.info(f' Load latent codes from `{args.latent_codes_path}`.')
        latent_codes = np.load(args.latent_codes_path)
        latent_codes = model.preprocess(
            latent_codes=latent_codes,
            latent_space_type=args.latent_space_type)
    else:
        if args.num <= 0:
            raise ValueError(f'Argument `num` should be specified as a positive '
                             f'number since the latent code path '
                             f'`{args.latent_codes_path}` does not exist!')
        if args.latent_codes_path:
            # A path was given but is not a file: make the fallback visible
            # instead of silently ignoring the user's input.
            logger.warning(f' Latent code path `{args.latent_codes_path}` does '
                           f'not exist, falling back to random sampling.')
        logger.info(' Sample latent codes randomly.')
        latent_codes = model.easy_sample(
            num=args.num, latent_space_type=args.latent_space_type)
    total_num = latent_codes.shape[0]

    # Both helpers are optional; they are only touched further down under the
    # same flags that create them here.
    predictor = None
    if args.generate_prediction:
        logger.info('Initializing predictor.')
        predictor = build_predictor(args.predictor_name)

    visualizer = None
    if args.generate_html:
        viz_size = None if args.viz_size == 0 else args.viz_size
        visualizer = HtmlPageVisualizer(num_rows=args.html_row,
                                        num_cols=args.html_col,
                                        grid_size=total_num,
                                        viz_size=viz_size)

    logger.info(f'Generating {total_num} samples.')
    results = defaultdict(list)
    predictions = defaultdict(list)
    # Running index of synthesized images; used for file names and for the
    # position on the HTML grid, independently of the progress bar state.
    image_idx = 0
    # The context manager closes the progress bar even if synthesis fails.
    with tqdm(total=total_num, leave=False) as pbar:
        for inputs in model.get_batch_inputs(latent_codes):
            outputs = model.easy_synthesize(
                latent_codes=inputs,
                latent_space_type=args.latent_space_type,
                generate_style=args.generate_style,
                generate_image=not args.skip_image)
            for key, val in outputs.items():
                if key != 'image':
                    results[key].append(val)
                    continue
                # Images are not accumulated in memory: they are only
                # predicted on, written to disk and/or put on the HTML page.
                if args.generate_prediction:
                    pred_outputs = predictor.easy_predict(val)
                    for pred_key, pred_val in pred_outputs.items():
                        predictions[pred_key].append(pred_val)
                for image in val:
                    if args.save_raw_synthesis:
                        save_image(
                            os.path.join(work_dir, f'{image_idx:06d}.jpg'),
                            image)
                    if args.generate_html:
                        row_idx, col_idx = divmod(image_idx,
                                                  visualizer.num_cols)
                        visualizer.set_cell(row_idx, col_idx, image=image)
                    image_idx += 1
                    pbar.update(1)
            # Without images the bar advances once per batch instead of once
            # per image.
            if 'image' not in outputs:
                pbar.update(inputs.shape[0])

    logger.info('Saving results.')
    if args.generate_html:
        visualizer.save(os.path.join(work_dir, args.html_name))
    for key, val in results.items():
        np.save(os.path.join(work_dir, f'{key}.npy'),
                np.concatenate(val, axis=0))

    if predictions:
        if args.predictor_name == 'scene':
            # The scene predictor's `category` and `attribute` scores are each
            # saved together with the predictor's name/index lookup tables.
            # The saved object is a dict, so it is presumably pickled by
            # `np.save` and needs `allow_pickle=True` on load -- confirm.
            for name in ('category', 'attribute'):
                detailed = {
                    'score': np.concatenate(predictions[name], axis=0),
                    'name_to_idx': getattr(predictor, f'{name}_name_to_idx'),
                    'idx_to_name': getattr(predictor, f'{name}_idx_to_name'),
                }
                np.save(os.path.join(work_dir, f'{name}.npy'), detailed)
        else:
            for key, val in predictions.items():
                np.save(os.path.join(work_dir, f'{key}.npy'),
                        np.concatenate(val, axis=0))
# Run the synthesis pipeline only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()