Commit 90f63a1

Fix CDNA transformation bug and speed up its implementation.
- Fix a CDNA transformation bug where the transformed color channels and masks were combined incorrectly.
- Remove the for loop over the batch size in the CDNA transformation. This speeds up building the graph.
1 parent 44fa1d3 commit 90f63a1

video_prediction/prediction_model.py

Lines changed: 16 additions & 12 deletions
@@ -261,6 +261,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
     List of images transformed by the predicted CDNA kernels.
   """
   batch_size = int(cdna_input.get_shape()[0])
+  height = int(prev_image.get_shape()[1])
+  width = int(prev_image.get_shape()[2])
 
   # Predict kernels using linear function of last hidden layer.
   cdna_kerns = slim.layers.fully_connected(
@@ -276,20 +278,22 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
   norm_factor = tf.reduce_sum(cdna_kerns, [1, 2, 3], keep_dims=True)
   cdna_kerns /= norm_factor
 
-  cdna_kerns = tf.tile(cdna_kerns, [1, 1, 1, color_channels, 1])
-  cdna_kerns = tf.split(axis=0, num_or_size_splits=batch_size, value=cdna_kerns)
-  prev_images = tf.split(axis=0, num_or_size_splits=batch_size, value=prev_image)
+  # Treat the color channel dimension as the batch dimension since the same
+  # transformation is applied to each color channel.
+  # Treat the batch dimension as the channel dimension so that
+  # depthwise_conv2d can apply a different transformation to each sample.
+  cdna_kerns = tf.transpose(cdna_kerns, [1, 2, 0, 4, 3])
+  cdna_kerns = tf.reshape(cdna_kerns, [DNA_KERN_SIZE, DNA_KERN_SIZE, batch_size, num_masks])
+  # Swap the batch and channel dimensions.
+  prev_image = tf.transpose(prev_image, [3, 1, 2, 0])
 
   # Transform image.
-  transformed = []
-  for kernel, preimg in zip(cdna_kerns, prev_images):
-    kernel = tf.squeeze(kernel)
-    if len(kernel.get_shape()) == 3:
-      kernel = tf.expand_dims(kernel, -1)
-    transformed.append(
-        tf.nn.depthwise_conv2d(preimg, kernel, [1, 1, 1, 1], 'SAME'))
-  transformed = tf.concat(axis=0, values=transformed)
-  transformed = tf.split(axis=3, num_or_size_splits=num_masks, value=transformed)
+  transformed = tf.nn.depthwise_conv2d(prev_image, cdna_kerns, [1, 1, 1, 1], 'SAME')
+
+  # Transpose the dimensions to where they belong.
+  transformed = tf.reshape(transformed, [color_channels, height, width, batch_size, num_masks])
+  transformed = tf.transpose(transformed, [3, 1, 2, 0, 4])
+  transformed = tf.unstack(transformed, axis=-1)
   return transformed
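For reference, below is a minimal, self-contained sketch (not part of the commit) of the reshaping trick the new code relies on: by moving the color channels into the batch axis and the batch into the channel axis, a single tf.nn.depthwise_conv2d call applies a different kernel to every sample, which is what the removed per-sample loop did. It assumes TensorFlow 2.x in eager mode; all shapes and variable names are illustrative, not taken from prediction_model.py.

import tensorflow as tf

# Illustrative shapes only (not the values used in the model).
batch, height, width, channels = 4, 8, 8, 3
ksize, num_masks = 5, 2

images = tf.random.normal([batch, height, width, channels])
kerns = tf.random.normal([batch, ksize, ksize, num_masks])  # one kernel stack per sample

# Loop version (what the removed code did): convolve each sample with its own kernels.
looped = []
for b in range(batch):
    k = tf.tile(tf.expand_dims(kerns[b], 2), [1, 1, channels, 1])  # [ksize, ksize, channels, num_masks]
    looped.append(tf.nn.depthwise_conv2d(images[b:b + 1], k, [1, 1, 1, 1], 'SAME'))
looped = tf.reshape(tf.concat(looped, 0), [batch, height, width, channels, num_masks])

# Batched version (what the new code does): color becomes the "batch" axis and
# batch becomes the "channel" axis, so one depthwise_conv2d handles every sample.
imgs_t = tf.transpose(images, [3, 1, 2, 0])        # [channels, height, width, batch]
kerns_t = tf.transpose(kerns, [1, 2, 0, 3])        # [ksize, ksize, batch, num_masks]
batched = tf.nn.depthwise_conv2d(imgs_t, kerns_t, [1, 1, 1, 1], 'SAME')
batched = tf.reshape(batched, [channels, height, width, batch, num_masks])
batched = tf.transpose(batched, [3, 1, 2, 0, 4])   # back to [batch, height, width, channels, num_masks]

print(float(tf.reduce_max(tf.abs(looped - batched))))  # ~0: both versions agree

Because the whole batch goes through one convolution, the graph no longer grows with the batch size, which is where the reported build-time speed-up comes from.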