/
bmshj2018.py
536 lines (467 loc) · 20.6 KB
/
bmshj2018.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
# Copyright 2019 Google LLC. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Nonlinear transform coder with hyperprior for RGB images.
This is the image compression model published in:
J. Ballé, D. Minnen, S. Singh, S.J. Hwang, N. Johnston:
"Variational Image Compression with a Scale Hyperprior"
Int. Conf. on Learning Representations (ICLR), 2018
https://arxiv.org/abs/1802.01436
This is meant as 'educational' code - you can use this to get started with your
own experiments. To reproduce the exact results from the paper, tuning of hyper-
parameters may be necessary. To compress images with published models, see
`tfci.py`.
This script requires TFC v2 (`pip install tensorflow-compression==2.*`).
"""
import argparse
import glob
import sys
from absl import app
from absl.flags import argparse_flags
import tensorflow as tf
import tensorflow_compression as tfc
import tensorflow_datasets as tfds
def read_png(filename):
  """Loads a PNG image file.

  Args:
    filename: Path of the image file to read.

  Returns:
    The decoded image, as produced by `tf.image.decode_image` with
    `channels=3`.
  """
  encoded = tf.io.read_file(filename)
  image = tf.image.decode_image(encoded, channels=3)
  return image
def write_png(filename, image):
  """Saves an image to a PNG file.

  Args:
    filename: Path of the file to write.
    image: Image tensor to pass to `tf.image.encode_png`.
  """
  tf.io.write_file(filename, tf.image.encode_png(image))
class AnalysisTransform(tf.keras.Sequential):
  """The analysis transform.

  A sequential stack that divides its input by 255 and then applies four 5x5
  `tfc.SignalConv2D` layers with `strides_down=2`. The first three layers use
  a `tfc.GDN` activation; the last one has no activation.
  """

  def __init__(self, num_filters):
    """Builds the layer stack.

    Args:
      num_filters: Number of filters used by every convolution.
    """
    super().__init__(name="analysis")
    # Rescale the input by 1/255 before the first convolution.
    self.add(tf.keras.layers.Lambda(lambda x: x / 255.))

    # Settings shared by all four convolutions.
    conv_kwargs = dict(
        corr=True, strides_down=2, padding="same_zeros", use_bias=True)

    # layer_0 .. layer_2: convolution followed by GDN.
    for i in range(3):
      self.add(tfc.SignalConv2D(
          num_filters, (5, 5), name=f"layer_{i}",
          activation=tfc.GDN(name=f"gdn_{i}"), **conv_kwargs))

    # layer_3: final convolution without activation.
    self.add(tfc.SignalConv2D(
        num_filters, (5, 5), name="layer_3", activation=None, **conv_kwargs))
class SynthesisTransform(tf.keras.Sequential):
  """The synthesis transform.

  A sequential stack of four 5x5 `tfc.SignalConv2D` layers with
  `strides_up=2`. The first three use an inverse `tfc.GDN` activation; the
  last one outputs 3 channels without activation. The result is finally
  multiplied by 255.
  """

  def __init__(self, num_filters):
    """Builds the layer stack.

    Args:
      num_filters: Number of filters used by the first three convolutions.
    """
    super().__init__(name="synthesis")

    # Settings shared by all four convolutions.
    conv_kwargs = dict(
        corr=False, strides_up=2, padding="same_zeros", use_bias=True)

    # layer_0 .. layer_2: convolution followed by inverse GDN.
    for i in range(3):
      self.add(tfc.SignalConv2D(
          num_filters, (5, 5), name=f"layer_{i}",
          activation=tfc.GDN(name=f"igdn_{i}", inverse=True),
          **conv_kwargs))

    # layer_3: maps to 3 output channels, no activation.
    self.add(tfc.SignalConv2D(
        3, (5, 5), name="layer_3", activation=None, **conv_kwargs))

    # Undo the 1/255 rescaling applied at the start of the analysis transform.
    self.add(tf.keras.layers.Lambda(lambda x: x * 255.))
class HyperAnalysisTransform(tf.keras.Sequential):
  """The analysis transform for the entropy model parameters.

  Three `tfc.SignalConv2D` layers: a 3x3 convolution with stride 1, followed
  by two 5x5 convolutions with `strides_down=2`. The first two use ReLU; the
  last has neither bias nor activation.
  """

  def __init__(self, num_filters):
    """Builds the layer stack.

    Args:
      num_filters: Number of filters used by every convolution.
    """
    super().__init__(name="hyper_analysis")

    # One row per layer: (kernel support, strides_down, use_bias, activation).
    layer_specs = (
        ((3, 3), 1, True, tf.nn.relu),
        ((5, 5), 2, True, tf.nn.relu),
        ((5, 5), 2, False, None),
    )
    for i, (support, stride, bias, act) in enumerate(layer_specs):
      self.add(tfc.SignalConv2D(
          num_filters, support, name=f"layer_{i}", corr=True,
          strides_down=stride, padding="same_zeros", use_bias=bias,
          activation=act))
class HyperSynthesisTransform(tf.keras.Sequential):
  """The synthesis transform for the entropy model parameters.

  Three `tfc.SignalConv2D` layers created with `kernel_parameter="variable"`:
  two 5x5 convolutions with `strides_up=2` and ReLU, followed by a 3x3
  convolution with stride 1 and no activation.
  """

  def __init__(self, num_filters):
    """Builds the layer stack.

    Args:
      num_filters: Number of filters used by every convolution.
    """
    super().__init__(name="hyper_synthesis")

    # One row per layer: (kernel support, strides_up, activation).
    layer_specs = (
        ((5, 5), 2, tf.nn.relu),
        ((5, 5), 2, tf.nn.relu),
        ((3, 3), 1, None),
    )
    for i, (support, stride, act) in enumerate(layer_specs):
      self.add(tfc.SignalConv2D(
          num_filters, support, name=f"layer_{i}", corr=False,
          strides_up=stride, padding="same_zeros", use_bias=True,
          kernel_parameter="variable", activation=act))
class BMSHJ2018Model(tf.keras.Model):
  """Main model class.

  Owns the analysis/synthesis transforms, their hyperprior counterparts and
  the hyperprior distribution. Calling the model returns the training loss;
  `compress` and `decompress` become usable once `fit` has created the
  entropy models with `compression=True`.
  """

  def __init__(self, lmbda, num_filters, num_scales, scale_min, scale_max):
    """Creates the transforms and the hyperprior, then builds the model.

    Args:
      lmbda: Weight of the MSE term in the loss returned by `call`.
      num_filters: Number of filters passed to every transform; also the
        batch shape of the hyperprior.
      num_scales: Number of scales handed to the indexed entropy model.
      scale_min: Scale that index 0 maps to.
      scale_max: Scale that index `num_scales - 1` maps to.
    """
    super().__init__()
    self.lmbda = lmbda
    self.num_scales = num_scales

    # Index i maps to exp(log_min + log_step * i), i.e. the scales are spaced
    # evenly in the log domain between scale_min and scale_max.
    log_min = tf.math.log(scale_min)
    log_max = tf.math.log(scale_max)
    log_step = (log_max - log_min) / (num_scales - 1.)
    self.scale_fn = lambda i: tf.math.exp(log_min + log_step * i)

    self.analysis_transform = AnalysisTransform(num_filters)
    self.synthesis_transform = SynthesisTransform(num_filters)
    self.hyper_analysis_transform = HyperAnalysisTransform(num_filters)
    self.hyper_synthesis_transform = HyperSynthesisTransform(num_filters)
    self.hyperprior = tfc.NoisyDeepFactorized(batch_shape=(num_filters,))
    self.build((None, None, None, 3))

  def call(self, x, training):
    """Computes rate and distortion losses.

    Args:
      x: Batch of images.
      training: Forwarded to both entropy models.

    Returns:
      A `(loss, bpp, mse)` tuple, where `loss = bpp + lmbda * mse`.
    """
    latent_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn, coding_rank=3,
        compression=False)
    side_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=False)

    x = tf.cast(x, self.compute_dtype)  # TODO(jonycgn): Why is this necessary?
    y = self.analysis_transform(x)
    z = self.hyper_analysis_transform(abs(y))
    z_hat, side_bits = side_model(z, training=training)
    indexes = self.hyper_synthesis_transform(z_hat)
    y_hat, bits = latent_model(y, indexes, training=training)
    x_hat = self.synthesis_transform(y_hat)

    # Rate: total number of bits divided by total number of pixels.
    total_bits = tf.reduce_sum(bits) + tf.reduce_sum(side_bits)
    num_pixels = tf.cast(tf.reduce_prod(tf.shape(x)[:-1]), bits.dtype)
    bpp = total_bits / num_pixels

    # Distortion: mean squared error across pixels, in the dtype of the rate.
    mse = tf.cast(
        tf.reduce_mean(tf.math.squared_difference(x, x_hat)), bpp.dtype)

    # The rate-distortion Lagrangian.
    loss = bpp + self.lmbda * mse
    return loss, bpp, mse

  def train_step(self, x):
    """Runs one optimizer step on `x` and returns the running metrics."""
    with tf.GradientTape() as tape:
      loss, bpp, mse = self(x, training=True)
    weights = self.trainable_variables
    grads = tape.gradient(loss, weights)
    self.optimizer.apply_gradients(zip(grads, weights))

    updates = ((self.loss, loss), (self.bpp, bpp), (self.mse, mse))
    for metric, value in updates:
      metric.update_state(value)
    return {metric.name: metric.result() for metric, _ in updates}

  def test_step(self, x):
    """Evaluates `x` without training and returns the running metrics."""
    loss, bpp, mse = self(x, training=False)

    updates = ((self.loss, loss), (self.bpp, bpp), (self.mse, mse))
    for metric, value in updates:
      metric.update_state(value)
    return {metric.name: metric.result() for metric, _ in updates}

  def predict_step(self, x):
    """Not supported; always raises `NotImplementedError`."""
    raise NotImplementedError("Prediction API is not supported.")

  def compile(self, **kwargs):
    """Compiles the model and creates the `loss`, `bpp` and `mse` metrics.

    Args:
      **kwargs: Forwarded to `tf.keras.Model.compile`. The loss and metric
        arguments are always passed as `None`.
    """
    super().compile(
        loss=None,
        metrics=None,
        loss_weights=None,
        weighted_metrics=None,
        **kwargs,
    )
    # Running means, updated by train_step / test_step.
    for metric_name in ("loss", "bpp", "mse"):
      setattr(self, metric_name, tf.keras.metrics.Mean(name=metric_name))

  def fit(self, *args, **kwargs):
    """Trains the model, then creates the entropy models used for coding.

    Args:
      *args: Forwarded to `tf.keras.Model.fit`.
      **kwargs: Forwarded to `tf.keras.Model.fit`.

    Returns:
      Whatever `tf.keras.Model.fit` returns.
    """
    history = super().fit(*args, **kwargs)
    # After training, fix range coding tables.
    self.entropy_model = tfc.LocationScaleIndexedEntropyModel(
        tfc.NoisyNormal, self.num_scales, self.scale_fn, coding_rank=3,
        compression=True)
    self.side_entropy_model = tfc.ContinuousBatchedEntropyModel(
        self.hyperprior, coding_rank=3, compression=True)
    return history

  @tf.function(input_signature=[
      tf.TensorSpec(shape=(None, None, 3), dtype=tf.uint8),
  ])
  def compress(self, x):
    """Compresses an image.

    Args:
      x: A `uint8` tensor of shape `(height, width, 3)`.

    Returns:
      A `(string, side_string, x_shape, y_shape, z_shape)` tuple: the two
      compressed strings and the spatial shapes of the image and both latents.
    """
    # Add batch dimension and cast to float.
    x = tf.cast(tf.expand_dims(x, 0), dtype=self.compute_dtype)
    y = self.analysis_transform(x)
    z = self.hyper_analysis_transform(abs(y))

    # Preserve spatial shapes of image and latents.
    x_shape = tf.shape(x)[1:-1]
    y_shape = tf.shape(y)[1:-1]
    z_shape = tf.shape(z)[1:-1]

    z_hat = self.side_entropy_model.quantize(z)
    # Crop the predicted indexes to the spatial size of y.
    indexes = self.hyper_synthesis_transform(z_hat)
    indexes = indexes[:, :y_shape[0], :y_shape[1], :]

    side_string = self.side_entropy_model.compress(z)
    string = self.entropy_model.compress(y, indexes)
    return string, side_string, x_shape, y_shape, z_shape

  @tf.function(input_signature=[
      tf.TensorSpec(shape=(1,), dtype=tf.string),
      tf.TensorSpec(shape=(1,), dtype=tf.string),
      tf.TensorSpec(shape=(2,), dtype=tf.int32),
      tf.TensorSpec(shape=(2,), dtype=tf.int32),
      tf.TensorSpec(shape=(2,), dtype=tf.int32),
  ])
  def decompress(self, string, side_string, x_shape, y_shape, z_shape):
    """Decompresses an image.

    Args:
      string: Compressed string of the latents, shape `(1,)`.
      side_string: Compressed string of the hyper-latents, shape `(1,)`.
      x_shape: Spatial shape of the image, shape `(2,)`.
      y_shape: Spatial shape of the latents, shape `(2,)`.
      z_shape: Spatial shape of the hyper-latents, shape `(2,)`.

    Returns:
      The reconstructed image as a `uint8` tensor without batch dimension.
    """
    z_hat = self.side_entropy_model.decompress(side_string, z_shape)
    # Crop the predicted indexes to the spatial size of y, as in `compress`.
    indexes = self.hyper_synthesis_transform(z_hat)
    indexes = indexes[:, :y_shape[0], :y_shape[1], :]
    y_hat = self.entropy_model.decompress(string, indexes)
    x_hat = self.synthesis_transform(y_hat)
    # Remove batch dimension, and crop away any extraneous padding.
    x_hat = x_hat[0, :x_shape[0], :x_shape[1], :]
    # Then cast back to 8-bit integer.
    return tf.saturate_cast(tf.round(x_hat), tf.uint8)
def check_image_size(image, patchsize):
  """Checks that an image is large enough to crop a patch from.

  Args:
    image: Image tensor whose first two dimensions and last dimension are
      inspected.
    patchsize: Minimum size required of each of the first two dimensions.

  Returns:
    Whether the first two dimensions are both at least `patchsize` and the
    last dimension equals 3.
  """
  dims = tf.shape(image)
  return dims[0] >= patchsize and dims[1] >= patchsize and dims[-1] == 3
def crop_image(image, patchsize):
  """Takes a random square crop and casts it to the compute dtype.

  Args:
    image: Image tensor to crop.
    patchsize: Height and width of the crop.

  Returns:
    A `(patchsize, patchsize, 3)` crop, cast to the compute dtype of the
    global Keras mixed-precision policy.
  """
  compute_dtype = tf.keras.mixed_precision.global_policy().compute_dtype
  patch = tf.image.random_crop(image, (patchsize, patchsize, 3))
  return tf.cast(patch, compute_dtype)
def get_dataset(name, split, args):
  """Creates input data pipeline from a TF Datasets dataset.

  Args:
    name: Name of the dataset to load with `tfds.load`.
    split: Dataset split to load. The "train" split is repeated.
    args: Parsed arguments; `patchsize` and `batchsize` are read.

  Returns:
    A dataset of batches of random image patches.
  """

  def large_enough(example):
    return check_image_size(example["image"], args.patchsize)

  def to_patch(example):
    return crop_image(example["image"], args.patchsize)

  with tf.device("/cpu:0"):
    dataset = tfds.load(name, split=split, shuffle_files=True)
    if split == "train":
      dataset = dataset.repeat()
    # Drop images that are too small, crop the rest, then batch.
    dataset = (
        dataset
        .filter(large_enough)
        .map(to_patch)
        .batch(args.batchsize, drop_remainder=True))
  return dataset
def get_custom_dataset(split, args):
  """Creates input data pipeline from custom PNG images.

  Args:
    split: Dataset split. The "train" split is repeated.
    args: Parsed arguments; `train_glob`, `patchsize`, `batchsize` and
      `preprocess_threads` are read.

  Returns:
    A dataset of batches of random patches cropped from the matched files.

  Raises:
    RuntimeError: If `args.train_glob` matches no files.
  """

  def load_patch(path):
    return crop_image(read_png(path), args.patchsize)

  with tf.device("/cpu:0"):
    files = glob.glob(args.train_glob)
    if not files:
      raise RuntimeError(f"No training images found with glob "
                         f"'{args.train_glob}'.")
    dataset = tf.data.Dataset.from_tensor_slices(files)
    # Shuffle over the full file list, with a new order on every iteration.
    dataset = dataset.shuffle(len(files), reshuffle_each_iteration=True)
    if split == "train":
      dataset = dataset.repeat()
    dataset = dataset.map(
        load_patch, num_parallel_calls=args.preprocess_threads)
    dataset = dataset.batch(args.batchsize, drop_remainder=True)
  return dataset
def train(args):
  """Instantiates and trains the model.

  Args:
    args: Parsed arguments of the 'train' subcommand. The trained model is
      saved to `args.model_path`.
  """
  if args.precision_policy:
    tf.keras.mixed_precision.set_global_policy(args.precision_policy)
  if args.check_numerics:
    tf.debugging.enable_check_numerics()

  model = BMSHJ2018Model(
      args.lmbda, args.num_filters, args.num_scales, args.scale_min,
      args.scale_max)
  model.compile(
      optimizer=tf.keras.optimizers.Adam(learning_rate=1e-4),
  )

  # Custom images if a glob was given, otherwise CLIC from TF Datasets.
  if args.train_glob:
    def make_dataset(split):
      return get_custom_dataset(split, args)
  else:
    def make_dataset(split):
      return get_dataset("clic", split, args)

  train_dataset = make_dataset("train")
  validation_dataset = make_dataset("validation")
  validation_dataset = validation_dataset.take(args.max_validation_steps)

  callbacks = [
      tf.keras.callbacks.TerminateOnNaN(),
      tf.keras.callbacks.TensorBoard(
          log_dir=args.train_path,
          histogram_freq=1, update_freq="epoch"),
      tf.keras.callbacks.BackupAndRestore(args.train_path),
  ]
  model.fit(
      train_dataset.prefetch(8),
      epochs=args.epochs,
      steps_per_epoch=args.steps_per_epoch,
      validation_data=validation_dataset.cache(),
      validation_freq=1,
      callbacks=callbacks,
      verbose=int(args.verbose),
  )
  model.save(args.model_path)
def compress(args):
  """Compresses an image.

  Loads the model from `args.model_path`, compresses `args.input_file` and
  writes the packed result to `args.output_file`. With `args.verbose` set,
  the image is also decompressed again and quality/rate metrics are printed.

  Args:
    args: Parsed arguments of the 'compress' subcommand.
  """
  # Load model and use it to compress the image.
  model = tf.keras.models.load_model(args.model_path)
  x = read_png(args.input_file)
  tensors = model.compress(x)

  # Write a binary file with the shape information and the compressed string.
  packed = tfc.PackedTensors()
  packed.pack(tensors)
  with open(args.output_file, "wb") as f:
    f.write(packed.string)

  if not args.verbose:
    return

  # Decompress the image again and measure performance.
  x_hat = model.decompress(*tensors)

  # Cast to float in order to compute metrics.
  x = tf.cast(x, tf.float32)
  x_hat = tf.cast(x_hat, tf.float32)
  mse = tf.reduce_mean(tf.math.squared_difference(x, x_hat))
  psnr = tf.squeeze(tf.image.psnr(x, x_hat, 255))
  msssim = tf.squeeze(tf.image.ssim_multiscale(x, x_hat, 255))
  msssim_db = -10. * tf.math.log(1 - msssim) / tf.math.log(10.)

  # The actual bits per pixel including entropy coding overhead.
  num_pixels = tf.reduce_prod(tf.shape(x)[:-1])
  bpp = len(packed.string) * 8 / num_pixels

  print(f"Mean squared error: {mse:0.4f}")
  print(f"PSNR (dB): {psnr:0.2f}")
  print(f"Multiscale SSIM: {msssim:0.4f}")
  print(f"Multiscale SSIM (dB): {msssim_db:0.2f}")
  print(f"Bits per pixel: {bpp:0.4f}")
def decompress(args):
  """Decompresses an image.

  Loads the model from `args.model_path`, unpacks `args.input_file` and
  writes the reconstruction to `args.output_file` as a PNG.

  Args:
    args: Parsed arguments of the 'decompress' subcommand.
  """
  model = tf.keras.models.load_model(args.model_path)
  # The input signature of `model.decompress` gives the dtypes needed to
  # unpack the file contents.
  dtypes = [spec.dtype for spec in model.decompress.input_signature]

  # Read the shape information and compressed string from the binary file.
  with open(args.input_file, "rb") as f:
    packed = tfc.PackedTensors(f.read())
  tensors = packed.unpack(dtypes)

  # Decompress with the model and write the result out as a PNG file.
  write_png(args.output_file, model.decompress(*tensors))
def parse_args(argv):
  """Parses command line arguments.

  Args:
    argv: Full argument vector; the first element is skipped.

  Returns:
    The parsed argument namespace. If no subcommand was given, the usage is
    printed and the process exits with status 2.
  """
  formatter = argparse.ArgumentDefaultsHelpFormatter
  parser = argparse_flags.ArgumentParser(formatter_class=formatter)

  # High-level options.
  parser.add_argument(
      "--verbose", "-V", action="store_true",
      help="Report progress and metrics when training or compressing.")
  parser.add_argument(
      "--model_path", default="bmshj2018",
      help="Path where to save/load the trained model.")
  subparsers = parser.add_subparsers(
      title="commands", dest="command",
      help="What to do: 'train' loads training data and trains (or continues "
           "to train) a new model. 'compress' reads an image file (lossless "
           "PNG format) and writes a compressed binary file. 'decompress' "
           "reads a binary file and reconstructs the image (in PNG format). "
           "input and output filenames need to be provided for the latter "
           "two options. Invoke '<command> -h' for more information.")

  # 'train' subcommand.
  train_cmd = subparsers.add_parser(
      "train",
      formatter_class=formatter,
      description="Trains (or continues to train) a new model. Note that this "
                  "model trains on a continuous stream of patches drawn from "
                  "the training image dataset. An epoch is always defined as "
                  "the same number of batches given by --steps_per_epoch. "
                  "The purpose of validation is mostly to evaluate the "
                  "rate-distortion performance of the model using actual "
                  "quantization rather than the differentiable proxy loss. "
                  "Note that when using custom training images, the validation "
                  "set is simply a random sampling of patches from the "
                  "training set.")
  # Options of 'train', registered in this order: (flag, add_argument kwargs).
  train_options = (
      ("--lambda", dict(
          type=float, default=0.01, dest="lmbda",
          help="Lambda for rate-distortion tradeoff.")),
      ("--train_glob", dict(
          type=str, default=None,
          help="Glob pattern identifying custom training data. This pattern must "
               "expand to a list of RGB images in PNG format. If unspecified, the "
               "CLIC dataset from TensorFlow Datasets is used.")),
      ("--num_filters", dict(
          type=int, default=192,
          help="Number of filters per layer.")),
      ("--num_scales", dict(
          type=int, default=64,
          help="Number of Gaussian scales to prepare range coding tables for.")),
      ("--scale_min", dict(
          type=float, default=.11,
          help="Minimum value of standard deviation of Gaussians.")),
      ("--scale_max", dict(
          type=float, default=256.,
          help="Maximum value of standard deviation of Gaussians.")),
      ("--train_path", dict(
          default="/tmp/train_bmshj2018",
          help="Path where to log training metrics for TensorBoard and back up "
               "intermediate model checkpoints.")),
      ("--batchsize", dict(
          type=int, default=8,
          help="Batch size for training and validation.")),
      ("--patchsize", dict(
          type=int, default=256,
          help="Size of image patches for training and validation.")),
      ("--epochs", dict(
          type=int, default=1000,
          help="Train up to this number of epochs. (One epoch is here defined as "
               "the number of steps given by --steps_per_epoch, not iterations "
               "over the full training dataset.)")),
      ("--steps_per_epoch", dict(
          type=int, default=1000,
          help="Perform validation and produce logs after this many batches.")),
      ("--max_validation_steps", dict(
          type=int, default=16,
          help="Maximum number of batches to use for validation. If -1, use one "
               "patch from each image in the training set.")),
      ("--preprocess_threads", dict(
          type=int, default=16,
          help="Number of CPU threads to use for parallel decoding of training "
               "images.")),
      ("--precision_policy", dict(
          type=str, default=None,
          help="Policy for `tf.keras.mixed_precision` training.")),
      ("--check_numerics", dict(
          action="store_true",
          help="Enable TF support for catching NaN and Inf in tensors.")),
  )
  for flag, options in train_options:
    train_cmd.add_argument(flag, **options)

  # 'compress' subcommand.
  compress_cmd = subparsers.add_parser(
      "compress",
      formatter_class=formatter,
      description="Reads a PNG file, compresses it, and writes a TFCI file.")

  # 'decompress' subcommand.
  decompress_cmd = subparsers.add_parser(
      "decompress",
      formatter_class=formatter,
      description="Reads a TFCI file, reconstructs the image, and writes back "
                  "a PNG file.")

  # Arguments for both 'compress' and 'decompress'.
  for cmd, ext in ((compress_cmd, ".tfci"), (decompress_cmd, ".png")):
    cmd.add_argument(
        "input_file",
        help="Input filename.")
    cmd.add_argument(
        "output_file", nargs="?",
        help=f"Output filename (optional). If not provided, appends '{ext}' to "
             f"the input filename.")

  # Parse arguments; a subcommand is mandatory.
  args = parser.parse_args(argv[1:])
  if args.command is None:
    parser.print_usage()
    sys.exit(2)
  return args
def main(args):
  """Invokes the handler of the parsed subcommand.

  Args:
    args: Parsed arguments. For 'compress' and 'decompress', a missing
      `output_file` is set to the input filename plus ".tfci" or ".png".
  """
  if args.command == "train":
    train(args)
    return

  if args.command == "compress":
    args.output_file = args.output_file or args.input_file + ".tfci"
    compress(args)
  elif args.command == "decompress":
    args.output_file = args.output_file or args.input_file + ".png"
    decompress(args)
# Script entry point: absl's `app.run` obtains the arguments from `parse_args`
# (used as the flags parser) and passes the result to `main`.
if __name__ == "__main__":
  app.run(main, flags_parser=parse_args)