Fix batch normalization during evaluation
This commit fixes the bug, present in all previous versions of this
code, in which the validation loss/accuracy stay approximately or
exactly constant across epochs and every validation prediction falls
into the same class.

The problem was an incorrect implementation of batch normalization,
in two ways. First, TensorFlow does not automatically run the update
ops that maintain the moving_mean and moving_variance. This is now
handled by creating the train op with slim.learning.create_train_op()
instead of the native tf.train.Optimizer().minimize(). Second, the
decay parameter has been decreased from the default 0.999 to 0.95,
since with too high a value the batch_norm moving averages take too
long to converge on a small dataset.

For more information, see
tensorflow/tensorflow#1122.
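
For context, a minimal sketch (not part of this commit) of the two equivalent ways to make the batch_norm update ops run as part of every training step; loss and optimizer refer to the objects defined in scripts/train.py:

import tensorflow as tf
slim = tf.contrib.slim

# Option used in this commit: create_train_op gathers the ops in the
# UPDATE_OPS collection and runs them as part of every training step.
train_op = slim.learning.create_train_op(loss, optimizer)

# Equivalent manual recipe in native TensorFlow 1.x: explicitly add the
# moving_mean/moving_variance assignments as control dependencies.
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = optimizer.minimize(loss)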
aribrill committed Nov 10, 2017
1 parent 280e191 commit 16c2c13
Showing 2 changed files with 16 additions and 11 deletions.
models/mobilenet.py (5 changes: 3 additions & 2 deletions)
@@ -59,7 +59,8 @@ def mobilenet_base(scope, inputs, conv_defs, is_training=True, reuse=None):
     with tf.variable_scope(scope, inputs, reuse=reuse):
         with slim.arg_scope([slim.conv2d, slim.separable_conv2d],
                             padding='SAME'):
-            with slim.arg_scope([slim.batch_norm], is_training=is_training):
+            with slim.arg_scope([slim.batch_norm], is_training=is_training,
+                                decay=0.95):
                 net = inputs
                 for i, conv_def in enumerate(conv_defs):
                     end_point_base = 'Conv2d_%d' % i
@@ -124,7 +125,7 @@ def mobilenet_head(inputs, dropout_keep_prob=0.9, num_classes=2,
                    is_training=True):
     # Define the network
     net, end_points = mobilenet_base("MobileNetHead", inputs, HEAD_CONV_DEFS,
-                                     is_training)
+                                     is_training=is_training)

     with tf.variable_scope('Logits'):
         net = slim.avg_pool2d(net, [15, 15], padding='VALID',
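
For background on the decay change above, slim.batch_norm maintains moving_mean and moving_variance as exponential moving averages, roughly as in this illustrative sketch (not code from this repository):

# Applied after each training batch while is_training=True:
moving_mean = decay * moving_mean + (1 - decay) * batch_mean
moving_variance = decay * moving_variance + (1 - decay) * batch_variance

The averages therefore need on the order of 1 / (1 - decay) batches to converge: roughly a thousand updates at the default decay of 0.999, but only about twenty at 0.95, which is why the lower value suits a small dataset.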
scripts/train.py (22 changes: 13 additions & 9 deletions)
@@ -12,6 +12,7 @@
 os.environ['TF_CPP_MIN_LOG_LEVEL']='2'

 import tensorflow as tf
+slim = tf.contrib.slim
 from tables import *
 import numpy as np

@@ -164,18 +165,21 @@ def load_val_data(index):
 filter_summ_op = tf.summary.image('filter',tf.slice(tf.transpose(kernel, perm=[3, 0, 1, 2]),begin=[0,0,0,0],size=[96,11,11,1]),max_outputs=IMAGE_VIZ_MAX_OUTPUTS)
 activations_summ_op = tf.summary.image('activations',tf.slice(activations,begin=[0,0,0,0],size=[TRAIN_BATCH_SIZE,58,58,1]),max_outputs=IMAGE_VIZ_MAX_OUTPUTS)

-#train op
+# Define the train op
 if args.optimizer == 'adadelta':
-    train_op = tf.train.AdadeltaOptimizer(learning_rate=variable_learning_rate).minimize(loss)
+    optimizer = tf.train.AdadeltaOptimizer(
+        learning_rate=variable_learning_rate)
 elif args.optimizer == 'adam':
-    train_op = tf.train.AdamOptimizer(learning_rate=variable_learning_rate,
-            beta1=0.9,
-            beta2=0.999,
-            epsilon=0.1,
-            use_locking=False,
-            name='Adam').minimize(loss)
+    optimizer = tf.train.AdamOptimizer(
+        learning_rate=variable_learning_rate,
+        beta1=0.9,
+        beta2=0.999,
+        epsilon=0.1,
+        use_locking=False,
+        name='Adam')
 else:
-    train_op = tf.train.GradientDescentOptimizer(variable_learning_rate).minimize(loss)
+    optimizer = tf.train.GradientDescentOptimizer(variable_learning_rate)
+train_op = slim.learning.create_train_op(loss, optimizer)

 #for embeddings visualization
 if embedding:
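
The payoff at evaluation time, sketched below; val_images and the assumption that mobilenet_head returns the network output are hypothetical, and building a second graph would additionally require variable reuse:

# Hypothetical evaluation call: with is_training=False, slim.batch_norm
# normalizes with the stored moving_mean/moving_variance instead of the
# per-batch statistics, so validation predictions are no longer stuck on
# a single class once those averages have converged.
val_output = mobilenet_head(val_images, is_training=False)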
