@@ -261,6 +261,8 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
261
261
List of images transformed by the predicted CDNA kernels.
262
262
"""
263
263
batch_size = int (cdna_input .get_shape ()[0 ])
264
+ height = int (prev_image .get_shape ()[1 ])
265
+ width = int (prev_image .get_shape ()[2 ])
264
266
265
267
# Predict kernels using linear function of last hidden layer.
266
268
cdna_kerns = slim .layers .fully_connected (
@@ -276,20 +278,22 @@ def cdna_transformation(prev_image, cdna_input, num_masks, color_channels):
276
278
norm_factor = tf .reduce_sum (cdna_kerns , [1 , 2 , 3 ], keep_dims = True )
277
279
cdna_kerns /= norm_factor
278
280
279
- cdna_kerns = tf .tile (cdna_kerns , [1 , 1 , 1 , color_channels , 1 ])
280
- cdna_kerns = tf .split (axis = 0 , num_or_size_splits = batch_size , value = cdna_kerns )
281
- prev_images = tf .split (axis = 0 , num_or_size_splits = batch_size , value = prev_image )
281
+ # Treat the color channel dimension as the batch dimension since the same
282
+ # transformation is applied to each color channel.
283
+ # Treat the batch dimension as the channel dimension so that
284
+ # depthwise_conv2d can apply a different transformation to each sample.
285
+ cdna_kerns = tf .transpose (cdna_kerns , [1 , 2 , 0 , 4 , 3 ])
286
+ cdna_kerns = tf .reshape (cdna_kerns , [DNA_KERN_SIZE , DNA_KERN_SIZE , batch_size , num_masks ])
287
+ # Swap the batch and channel dimensions.
288
+ prev_image = tf .transpose (prev_image , [3 , 1 , 2 , 0 ])
282
289
283
290
# Transform image.
284
- transformed = []
285
- for kernel , preimg in zip (cdna_kerns , prev_images ):
286
- kernel = tf .squeeze (kernel )
287
- if len (kernel .get_shape ()) == 3 :
288
- kernel = tf .expand_dims (kernel , - 1 )
289
- transformed .append (
290
- tf .nn .depthwise_conv2d (preimg , kernel , [1 , 1 , 1 , 1 ], 'SAME' ))
291
- transformed = tf .concat (axis = 0 , values = transformed )
292
- transformed = tf .split (axis = 3 , num_or_size_splits = num_masks , value = transformed )
291
+ transformed = tf .nn .depthwise_conv2d (prev_image , cdna_kerns , [1 , 1 , 1 , 1 ], 'SAME' )
292
+
293
+ # Transpose the dimensions to where they belong.
294
+ transformed = tf .reshape (transformed , [color_channels , height , width , batch_size , num_masks ])
295
+ transformed = tf .transpose (transformed , [3 , 1 , 2 , 0 , 4 ])
296
+ transformed = tf .unstack (transformed , axis = - 1 )
293
297
return transformed
294
298
295
299
0 commit comments