Easy to use batch norm layer. #1122

Closed
cesarsalgado opened this issue Feb 16, 2016 · 127 comments

Labels: stat:contribution welcome (Status - Contributions welcome), type:docs-bug (Document issues)

@cesarsalgado
Contributor

Many non-experts are using the following code http://stackoverflow.com/questions/33949786/how-could-i-use-batch-normalization-in-tensorflow?answertab=votes#tab-top.

It would be nice to have an official batch norm layer given its importance in training DNNs.

@vincentvanhoucke
Contributor

I'm working on some parts of that.

@vincentvanhoucke
Contributor

There is now a batch_norm layer:
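For reference, a minimal usage sketch of the contrib layer discussed throughout this thread (tf.contrib.layers.batch_norm in TF 1.x; the input shape is just for illustration):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64])
# is_training is a Python boolean here; see the discussion below about tensors.
h = tf.contrib.layers.batch_norm(x, is_training=True, scope='bn')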

@Mahdizade

Mahdizade commented Jun 18, 2016

I think something is wrong with this layer. During training everything is OK and the loss decreases nicely, but at test time I get zero accuracy.
Specifically, when I use is_training=False for testing, I get zero accuracy.
I know batch normalization behaves differently in the train and test phases, as described in "How does batch normalization behave differently at training time and test time?" on Quora. I think this implementation is unclear.

@pawni

pawni commented Jun 20, 2016

Same here, I have experienced some unexpected behavior with is_training=False. What is the correct way to change this flag? I am currently using a tf.cond because it does not take tf.placeholders by itself.

@ppwwyyxx
Contributor

@pawni You have to use a Python boolean for is_training. It cannot be a tf.cond.

@pawni

pawni commented Jun 20, 2016

@ppwwyyxx well I am doing tf.cond(placeholder, batch_norm(.., is_training = True), batch_norm(.., is_training = False)) or is one just supposed to do a batch_norm(.., is_training=variable) and change that outside of the graph when needed?

@ppwwyyxx
Contributor

Oh, I thought you were doing batch_norm(.., is_training=tf.cond(placeholder)), which is incorrect.
Your current way might have problems as well. You'll need to double-check that the two batch_norm ops you created share the same scope, otherwise they won't share the underlying mean/variance statistics.

To do this the reuse argument might help, but I'm not sure because I use my own version of the BN layer.

@pawni

pawni commented Jun 20, 2016

I am using the same scope and reuse=True. It seems to work sometimes, but I am not too sure. It would be great if the layer could be added to the documentation, with a short explanation of how to best handle the change from training to test.

vincentvanhoucke removed their assignment Jun 21, 2016
@vincentvanhoucke
Contributor

@sguada FYI

@sguada
Member

sguada commented Jun 21, 2016

Currently batch_norm requires a Python boolean, but we are working on adding the option of passing a Tensor.

@sguada
Member

sguada commented Jun 21, 2016

@pawni If you don't want to worry about updating moving_mean and moving_variance, set updates_collections=None to make sure they are updated in place; otherwise you need to make sure the update_ops added to tf.GraphKeys.UPDATE_OPS are run during training.
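A minimal sketch of the two options described above (TF 1.x contrib API; the placeholder shape, dummy loss, and optimizer are just for illustration):

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 64])

# Option 1: update moving_mean/moving_variance in place (simpler, can be slower).
h1 = tf.contrib.layers.batch_norm(x, is_training=True, updates_collections=None, scope='bn1')

# Option 2: let the updates go to tf.GraphKeys.UPDATE_OPS and run them with the train step.
h2 = tf.contrib.layers.batch_norm(x, is_training=True, scope='bn2')
loss = tf.reduce_mean(tf.square(h2))  # dummy loss for illustration
update_ops = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
with tf.control_dependencies(update_ops):
    train_op = tf.train.GradientDescentOptimizer(0.1).minimize(loss)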

@Mahdizade

I think TensorFlow needs two top-level methods that change the model state, something like Torch's model:training() and model:evaluate(). I think it would be very straightforward.

@brando90

is there a small script with a very simple NN that shows the proper way to use this "official" BN layer? I'd really appreciate it.

@brando90

brando90 commented Jul 11, 2016

sorry if this is a little repetitive, but it seems the API documents BN through a different interface: https://www.tensorflow.org/versions/r0.9/api_docs/python/nn.html#batch_normalization

is that not the official way to use BN? I am confused about how to use it: the SO answer seems to be outdated, and then there is a layer at a different link in the API. Just how exactly does one do this? I am unclear whether to go to SO or ask here.

@brando90

brando90 commented Jul 12, 2016

sorry for the spamming, but what is wrong with just using something like this:

import tensorflow as tf

def standard_batch_norm(l, x, n_out, phase_train, scope='BN'):
    """
    Batch normalization on feedforward maps.
    Args:
        l:           string, layer name suffix used for variable scoping
        x:           Tensor of activations
        n_out:       integer, depth of input maps
        phase_train: boolean tf.Variable, true indicates training phase
        scope:       string, variable scope
    Return:
        normed:      batch-normalized maps
    """
    with tf.variable_scope(scope+l):
        init_beta = tf.constant(0.0, shape=[n_out], dtype=tf.float64)
        init_gamma = tf.constant(1.0, shape=[n_out], dtype=tf.float64)
        beta = tf.get_variable(name='beta'+l, dtype=tf.float64, initializer=init_beta, regularizer=None, trainable=True)
        gamma = tf.get_variable(name='gamma'+l, dtype=tf.float64, initializer=init_gamma, regularizer=None, trainable=True)
        batch_mean, batch_var = tf.nn.moments(x, [0], name='moments')
        ema = tf.train.ExponentialMovingAverage(decay=0.5)

        def mean_var_with_update():
            ema_apply_op = ema.apply([batch_mean, batch_var])
            with tf.control_dependencies([ema_apply_op]):
                return tf.identity(batch_mean), tf.identity(batch_var)

        mean, var = tf.cond(phase_train, mean_var_with_update, lambda: (ema.average(batch_mean), ema.average(batch_var)))
        normed = tf.nn.batch_normalization(x, mean, var, beta, gamma, 1e-3)
    return normed

then it's simple to tell TensorFlow which one to use via a feed dictionary, as in:

feed_dict = {x: Xminibatch, y_: Yminibatch, phase_train: True}
sess.run(fetches=[merged,train_step], feed_dict=feed_dict)

since it's unclear if the implementation will change, I wanted to give a suggestion (note it's easy to extend to convolutions and such; I just didn't paste that code).

@brando90

@pawni @ppwwyyxx did you guys decide whether you had to set reuse to True to solve the scoping issue?

@pawni

pawni commented Jul 12, 2016

@brando90 currently I am doing something like:

def BatchNorm(inputT, is_training=True, scope=None):
    return tf.cond(isTraining,
                lambda: batch_norm(inputT, is_training=True,
                                   center=False, updates_collections=None, scope=scope),
                lambda: batch_norm(inputT, is_training=False,
                                   updates_collections=None, center=False, scope=scope, reuse = True))

However, I think that #3265 basically wants to implement it like this. A reference could be the dropout implementation here: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/contrib/layers/python/layers/layers.py#L433-L435

@sguada
Member

sguada commented Jul 12, 2016

When updates_collections=None, the updates happen in place and it is easier to use a tf.cond() to allow is_training to be a Tensor. It is a bit more complicated when the updates are delayed and the update_ops are run later.
I will try to get the first part in soon.

@nmhkahn

nmhkahn commented Jul 14, 2016

@brando90 @pawni his code works well, but it has to be changed as below:

def BatchNorm(inputT, is_training=True, scope=None):
    # Note: is_training is tf.placeholder(tf.bool) type
    return tf.cond(is_training,  
                lambda: batch_norm(inputT, is_training=True,  
                                   center=False, updates_collections=None, scope=scope),  
                lambda: batch_norm(inputT, is_training=False,  
                                   updates_collections=None, center=False, scope=scope, reuse = True))  

And when run in training or test time,

# when training
sess.run([opt, loss], feed_dict={x: bx, y: by, is_training: True})

# when testing
sess.run([opt, loss], feed_dict={x: bx, y: by, is_training: False})

This code works, but as #3265 says it would be great if tf.contrib.layers.batch_norm could take the is_training variable as a tf.placeholder.

@diegoAtAlpine

diegoAtAlpine commented Jul 21, 2016

@nmhkahn @pawni thanks for the code snippets. They were very useful in adding batch normalization to my convolution network. Training seems to work very well; testing does not. In some versions of the code, training accuracies are much higher than testing accuracies, which probably means I am not sharing the batch normalization parameters. In other versions I get "ValueError: Variable conv1/beta already exists, disallowed. Did you mean to set reuse=True in VarScope?", which seems to indicate that I am trying to relearn the parameter when I was trying to reuse it.

Can someone provide an example of how to call the "def BatchNorm" function during training and testing so that variable sharing happens correctly?

Thanks for any help.

UPDATE July 25, 2016:

@nmhkahn @pawni thanks for your comments. After taking a closer look at the code in contrib I realized what my problem was. During training and testing we are either updating or reusing four variables (beta, gamma, moving_mean and moving_variance). To make those unique I had to set a scope per layer. I did it like this:

conv1 = tf.nn.relu(batch_norm_layer(conv2d_stride2_valid(data, W_conv1) + b_conv1, train_phase, scope="conv1"))

where batch_norm_layer is similar to the examples from @nmhkahn @pawni, conv2d_stride2_valid is just a def to define a convolutional layer, and W_conv1 and b_conv1 are variables holding the weights and biases. I could probably remove the bias term because we are using batch normalization.

The net is working well now. After plotting accuracies in training and test mode, I noticed that the testing accuracies start climbing after the training accuracies. In retrospect it makes sense, since we are collecting dataset statistics for testing. But it appeared as if I was doing something wrong during my initial tests. Thanks for your comments and for making batch normalization available to the community.

@brando90

@nmhkahn how is it different from pawni's suggestion?

@pawni

pawni commented Jul 22, 2016

@brando90 I had a small error in my version which was fixed by nmhkahn (changing isTraining to is_training)

@diegoAtAlpine I found the same problems - not sure why this is the case, though. However, the ValueError should be resolved by the code snippet. Not sure what else you want to see regarding how to call it, as nmhkahn's example seems to do the job?

@brando90

brando90 commented Jul 22, 2016

@nmhkahn @pawni when you do:

sess.run([opt, loss], feed_dict={x: bx, y: by, is_training: True})

doesn't that mean that you're using is_training as a placeholder? People have commented that they want is_training to be a placeholder, but that's what I had for my version of it:

def batch_norm_layer(x, train_phase, scope_bn):
    bn_train = batch_norm(x, decay=0.999, center=True, scale=True,
                          is_training=True,
                          reuse=None,  # is this right?
                          trainable=True,
                          scope=scope_bn)
    bn_inference = batch_norm(x, decay=0.999, center=True, scale=True,
                              is_training=False,
                              reuse=True,  # is this right?
                              trainable=True,
                              scope=scope_bn)
    z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return z

is that not correct?

@sguada
Member

sguada commented Jul 22, 2016

I have already extended tf.contrib.layers.batch_norm to allow passing a Tensor or a Placeholder for is_training. It will be merged in TF contrib soon.

Now available in
9da5fc8#diff-94bbcef0ec8a5cdef55f705e99c2b2ed
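With that change in, a sketch of passing is_training as a placeholder directly (assuming the extended contrib API; names are illustrative):

import tensorflow as tf

is_training = tf.placeholder(tf.bool, name='is_training')
x = tf.placeholder(tf.float32, [None, 64])
h = tf.contrib.layers.batch_norm(x, is_training=is_training,
                                 updates_collections=None, scope='bn')
# sess.run(h, feed_dict={x: batch, is_training: True})   # training
# sess.run(h, feed_dict={x: batch, is_training: False})  # testing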

@brando90

is it just me, or does adding this BN layer noticeably slow down training of a single epoch?

@zmlmanly

@sguada Hi, sguada, I have a problem.
The definition of contrib.layers.batch_norm in tensorflow:
def batch_norm(inputs,
               decay=0.999,
               center=True,
               scale=False,
               epsilon=0.001,
               activation_fn=None,
               param_initializers=None,
               param_regularizers=None,
               updates_collections=ops.GraphKeys.UPDATE_OPS,
               is_training=True,
               reuse=None,
               variables_collections=None,
               outputs_collections=None,
               trainable=True,
               batch_weights=None,
               fused=False,
               data_format=DATA_FORMAT_NHWC,
               zero_debias_moving_mean=False,
               scope=None,
               renorm=False,
               renorm_clipping=None,
               renorm_decay=0.99):

    scale: If True, multiply by gamma. If False, gamma is not used. When the
        next layer is linear (also e.g. nn.relu), this can be disabled since
        the scaling can be done by the next layer.

If I use tf.contrib.layers.batch_norm(input, scale=False), does "scale=False" mean that gamma is zero in "y = gamma*x + beta" during training? Thank you very much.

@ppwwyyxx
Contributor

When scale=False, gamma is a constant 1.
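A quick way to check this (a sketch, TF 1.x contrib API): with center=True (the default) and scale=False, only beta is created as a trainable variable, so the layer computes 1 * x_hat + beta.

import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 10])
_ = tf.contrib.layers.batch_norm(x, scale=False, scope='bn')
print([v.name for v in tf.trainable_variables()])  # expect only beta, e.g. ['bn/beta:0']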

@zmlmanly

@ppwwyyxx Thank you very much for your help. I use tf.contrib.layers.batch_norm(input, scale=False) in TensorFlow, and now I am converting the batchnorm from TensorFlow to Caffe. How should I set the parameters of BatchNormLayer and ScaleLayer in Caffe?
Thank you very much.

@tano297
Contributor

tano297 commented Jul 20, 2017

@MisayaZ I was having the same behavior using Batchnorm with a placeholder for "is_training". I see in the trace that the moments are being calculated even at test time, so I decided to go into the source code and I found this:

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `needs_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
        # here it defines the moments

It looks like when "is_training" is a variable or a placeholder, the moments get defined and also calculated at runtime, even when you set the placeholder to "False". I would have preferred to leave it as a placeholder, because this way I can do periodic testing during training without redefining the graph, but I decided to use it as a constant and define different behaviors for train vs. test, and now the moments are not calculated at test time.
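A tiny illustration of that code path (a sketch, assuming the TF 1.x contrib module layout):

import tensorflow as tf
from tensorflow.contrib.layers.python.layers import utils

print(utils.constant_value(True))                     # True  -> the moments subgraph can be skipped
print(utils.constant_value(tf.placeholder(tf.bool)))  # None  -> the moments subgraph gets built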

@MisayaZ

MisayaZ commented Jul 21, 2017

@tano297 Thank you. I now also use 'is_training' as a constant. Leaving it as a placeholder and doing periodic testing will change the values of the moving mean and moving variance. And the inference time will be longer, because it will calculate the mean and variance of the inputs and update the moving mean and moving variance. The right way to do testing is to define different behaviors for train and test, as you mentioned.

@abred

abred commented Jul 21, 2017

@tano297 @MisayaZ
but doesn't the "smart_cond" in

is_training_value = utils.constant_value(is_training)
need_updates = is_training_value is None or is_training_value
if need_updates:
  ...
  outputs = utils.smart_cond(is_training, _force_updates, no_updates)

make sure that the updates are only calculated and applied if is_training evaluates to True?

@tano297
Contributor

tano297 commented Jul 21, 2017

@abred Yes indeed, but you are referring to line 391, where it does the update of the moving average within _fused_batch_norm():

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `need_updates` will be true.
    is_training_value = utils.constant_value(is_training)
    need_updates = is_training_value is None or is_training_value
    if need_updates:
        ...
        outputs = utils.smart_cond(is_training, _force_updates, no_updates)
        ...

I am talking about line 753 within batch_norm():

    # If `is_training` doesn't have a constant value, because it is a `Tensor`,
    # a `Variable` or `Placeholder` then is_training_value will be None and
    # `needs_moments` will be true.
    is_training_value = utils.constant_value(is_training)
    need_moments = is_training_value is None or is_training_value
    if need_moments:
        ...
        mean, variance = utils.smart_cond(is_training,
                                          _force_updates,
                                          moving_vars_fn) 
        ...

The smart condition in that case (as far as I am concerned) decides whether or not to update the moving averages, but the moments still get calculated.

@abred

abred commented Jul 21, 2017

@tano297 you're right about that, I was in the wrong place, but still:
lines 755-770 calculate the moments, but the moments are only used in _force_updates, which is only executed if is_training evaluates to True, aren't they?
And thus

mean, variance = utils.smart_cond(is_training, _force_updates, moving_vars_fn) 

should be equivalent to line 804:

mean, variance = moving_mean, moving_variance

if is_training evaluates to False, and thus the "moments" part of the graph is never used and shouldn't be executed

but I haven't tested, so I might be wrong about that :)

@MisayaZ

MisayaZ commented Jul 24, 2017

@tano297 @abred you're right. The moving mean and moving variance are changed when I use batchnorm like this:

def batch_norm_layer(self, x, train_phase, scope_bn):
    bn_train = batch_norm(x, decay=0.9, center=False, scale=True,
                          updates_collections=None,
                          is_training=True,
                          reuse=None,
                          variables_collections=[UPDATE_OPS_COLLECTION],
                          trainable=True,
                          scope=scope_bn)
    bn_inference = batch_norm(x, decay=0.9, center=False, scale=True,
                              updates_collections=None,
                              is_training=False,
                              reuse=True,
                              variables_collections=[UPDATE_OPS_COLLECTION],
                              trainable=True,
                              scope=scope_bn)
    z = tf.cond(train_phase, lambda: bn_train, lambda: bn_inference)
    return z

If you use it like the following:

z = batch_norm(x, decay=0.9, center=False, scale=True, updates_collections=None, 
                         is_training=train_phase, scope=scope_bn)

The moving mean and moving variance will not be changed during test, but the speed is very slow.

@tyshiwo

tyshiwo commented Aug 4, 2017

Hi @zhongyuk ,

I also met the problem that I could get good results when using is_training=True for both training and inference, but got bad results when setting is_training=False during inference (worse than the case using is_training=True). According to your analysis, if I understand correctly, simply setting decay=0.9 in BN can solve this problem. Am I right?

BTW, do I need to retrain the model with decay=0.9 from scratch? Or is resuming training from the checkpoint (i.e., trained with decay=0.999) also OK?

Thanks!

@tyshiwo

tyshiwo commented Aug 5, 2017

@nmduc @davek44

Hi, I also met the problem that I could get good results when using is_training=True for both training and inference, but got bad results when setting is_training=False during inference (worse than the case using is_training=True). Have you guys solved this problem? Thanks!

@nmduc

nmduc commented Aug 5, 2017

@tyshiwo I just set decay=0.9 for batch_norm and it works well so far.

@ghost

ghost commented Aug 10, 2017

I was confused after all these comments about how to properly use batch norm, so here is what I have. Please correct me if I'm wrong.

batch_norm = tf.contrib.layers.batch_norm(conv, center=True, scale=True, reuse=phase_train_py, scope='bn', is_training=is_training)

where phase_train_py is a Python boolean variable and is_training is a placeholder taking a boolean value. I guess using tf.cond is wrong; otherwise, why would the function come with boolean parameters? In other words, if tf.cond were needed, we would use one batch_norm function for training and another one for testing. So, the developers allow us to change these boolean variables in order to change the behavior of the function. What I am doing is: setting phase_train_py to False while training and is_training to True, and the opposite while testing. Since we can only change tensors or placeholders with sess.run, I change phase_train_py manually before running the graph. Ex:

if condition:
    phase_train_py = False
    sess.run(to_run_list, feed_dict={phase_train: True})
else:
    phase_train_py = True
    sess.run(to_run_list, feed_dict={phase_train: False})

@zhimengfan1990

zhimengfan1990 commented Sep 16, 2017

+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++
MAYBE YOU NEED TO READ THIS
+++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++

It seems there are still problems with TF v1.3. I'm sure I paid attention to the following details, but I still failed to use the official tf.contrib.layers.batch_norm with is_training=False during evaluation (when I keep is_training=True unchanged during evaluation, it is OK):
1. decay: the exponential moving average is actually an alpha filter in signal processing; the time to converge is approximately 1/(1-decay) training steps. For decay=0.999, you need 1/0.001 = 1000 steps to converge. So set a decay appropriate for your number of training steps.
2. use a placeholder to switch between train and test evaluation
3. use updates_collections=None if you don't want to add control dependencies of the update ops to the train_op
4. set reuse to an appropriate value.

It seems the only way to use the official batch_norm is to build two graphs, one for training and one for evaluation, with is_training=True and is_training=False, respectively. In this way, you don't need to switch dynamically between train and evaluation. But this is a stupid way, since you need to build more than one graph.
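For completeness, a sketch of that two-graph approach with shared variables (TF 1.x; model_fn is a hypothetical wrapper):

import tensorflow as tf

def model_fn(x, is_training):
    return tf.contrib.layers.batch_norm(x, is_training=is_training, scope='bn')

x = tf.placeholder(tf.float32, [None, 32])
with tf.variable_scope('net'):
    train_out = model_fn(x, is_training=True)
with tf.variable_scope('net', reuse=True):
    eval_out = model_fn(x, is_training=False)  # shares beta/gamma and the moving statistics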

Finally, I wrote a moving average myself, and I found it worked! It's as follows (based on code from the web and modified by myself):

import numpy as np
import tensorflow as tf

def bn_layer(x, scope, is_training, epsilon=0.001, decay=0.99, reuse=None):
    """
    Performs a batch normalization layer

    Args:
        x: input tensor
        scope: scope name
        is_training: python boolean value
        epsilon: the variance epsilon - a small float number to avoid dividing by 0
        decay: the moving average decay

    Returns:
        The ops of a batch normalization layer
    """
    with tf.variable_scope(scope, reuse=reuse):
        shape = x.get_shape().as_list()
        # gamma: a trainable scale factor
        gamma = tf.get_variable("gamma", shape[-1], initializer=tf.constant_initializer(1.0), trainable=True)
        # beta: a trainable shift value
        beta = tf.get_variable("beta", shape[-1], initializer=tf.constant_initializer(0.0), trainable=True)
        moving_avg = tf.get_variable("moving_avg", shape[-1], initializer=tf.constant_initializer(0.0), trainable=False)
        moving_var = tf.get_variable("moving_var", shape[-1], initializer=tf.constant_initializer(1.0), trainable=False)
        if is_training:
            # tf.nn.moments == Calculate the mean and the variance of the tensor x
            avg, var = tf.nn.moments(x, np.arange(len(shape)-1), keep_dims=True)
            avg=tf.reshape(avg, [avg.shape.as_list()[-1]])
            var=tf.reshape(var, [var.shape.as_list()[-1]])
            #update_moving_avg = moving_averages.assign_moving_average(moving_avg, avg, decay)
            update_moving_avg=tf.assign(moving_avg, moving_avg*decay+avg*(1-decay))
            #update_moving_var = moving_averages.assign_moving_average(moving_var, var, decay)
            update_moving_var=tf.assign(moving_var, moving_var*decay+var*(1-decay))
            control_inputs = [update_moving_avg, update_moving_var]
        else:
            avg = moving_avg
            var = moving_var
            control_inputs = []
        with tf.control_dependencies(control_inputs):
            output = tf.nn.batch_normalization(x, avg, var, offset=beta, scale=gamma, variance_epsilon=epsilon)

    return output


def bn_layer_top(x, scope, is_training, epsilon=0.001, decay=0.99):
    """
    Returns a batch normalization layer that automatically switch between train and test phases based on the 
    tensor is_training

    Args:
        x: input tensor
        scope: scope name
        is_training: boolean tensor or variable
        epsilon: epsilon parameter - see batch_norm_layer
        decay: epsilon parameter - see batch_norm_layer

    Returns:
        The correct batch normalization layer based on the value of is_training
    """
    #assert isinstance(is_training, (ops.Tensor, variables.Variable)) and is_training.dtype == tf.bool

    return tf.cond(
        is_training,
        lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=True, reuse=None),
        lambda: bn_layer(x=x, scope=scope, epsilon=epsilon, decay=decay, is_training=False, reuse=True),
    )

Just use the bn_layer_top function when building a graph; the is_training parameter is a tf.placeholder. Then you are free to switch the placeholder to True during training and False during evaluation, via feed_dict.
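A hypothetical usage sketch (the shapes and data are illustrative):

import numpy as np
import tensorflow as tf

x = tf.placeholder(tf.float32, [None, 16])
is_training = tf.placeholder(tf.bool)
y = bn_layer_top(x, scope='bn1', is_training=is_training)

with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    batch = np.random.randn(8, 16).astype(np.float32)
    y_train = sess.run(y, feed_dict={x: batch, is_training: True})   # updates the moving stats
    y_test = sess.run(y, feed_dict={x: batch, is_training: False})   # uses the moving stats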

Hope it helps the community.

aribrill added a commit to ctlearn-project/ctlearn that referenced this issue Nov 10, 2017
This commit solves the bug observed in all previous versions of this
code in which validation loss/accuracy are approximately/exactly
constant for every epoch and all validation predictions are of the
same class.

The problem was due to incorrect implementation of batch normalization
in two ways. First, tensorflow does not automatically collect the update
ops for updating the moving_mean and moving_variance. This is now being
done by using slim.learning.create_train_op() instead of native
tf.train.Optimizer().minimize() to create the train op. Second, the
decay parameter has been decreased from the default 0.999 to 0.95, as
with too high of a value batch_norm takes too long to converge on a
small dataset.

For more information, see
tensorflow/tensorflow#1122.
@tasx0823

tasx0823 commented Dec 7, 2017

When you use slim.batch_norm, be sure to use "slim.learning.create_train_op" instead of "tf.train.GradientDescentOptimizer(lr).minimize(loss)" or another optimizer's minimize call. Try it and see if it works!
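A minimal sketch of that pattern (TF 1.x slim API; the network and loss are hypothetical):

import tensorflow as tf
slim = tf.contrib.slim

images = tf.placeholder(tf.float32, [None, 28, 28, 1])
net = slim.conv2d(images, 32, [3, 3],
                  normalizer_fn=slim.batch_norm,
                  normalizer_params={'is_training': True})
loss = tf.reduce_mean(tf.square(net))  # dummy loss for illustration
optimizer = tf.train.GradientDescentOptimizer(0.1)
# create_train_op wires the UPDATE_OPS (moving mean/variance updates) into the train op.
train_op = slim.learning.create_train_op(loss, optimizer)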

@ZahlGraf

@vincentvanhoucke You wrote in another post in this thread:

The slim batch_norm wrapper normalizes over the last dimension of your input tensor. So if it's a 2D input tensor coming from a fully connected layer, it normalizes over batch, and thus performs per-activation normalization. If it's a 4D tensor coming from a convolution, it will normalize over the three first dimensions (batch, width, depth), and thus perform per-feature normalization. @sguada maybe worth being a bit more descriptive about this.

Do you mean by "slim batch_norm wrapper" the function tf.contrib.layers.batch_norm? If so, I would suggest adding this information to the documentation text of that function. That would make it very clear that this function performs batch normalization exactly as described in the paper, for both FC and Conv2D layers. At the moment there is only the text "Can be used as a normalizer function for conv2d and fully_connected.", where it is not clear whether this relates to the normalization-axis topic.
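A small sketch of the axis behavior described in the quote above (TF 1.x contrib API; shapes illustrative):

import tensorflow as tf

fc = tf.placeholder(tf.float32, [None, 128])           # 2D input from a fully connected layer
conv = tf.placeholder(tf.float32, [None, 32, 32, 64])  # 4D input from a convolution

bn_fc = tf.contrib.layers.batch_norm(fc, scope='bn_fc')        # normalizes over batch; beta/gamma shape [128]
bn_conv = tf.contrib.layers.batch_norm(conv, scope='bn_conv')  # normalizes over batch, height, width; beta/gamma shape [64]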

@vincentvanhoucke
Contributor

@ZahlGraf I'll happily consider a PR that clarifies the documentation. We've been at this for so long that I no longer have a good sense of what's obvious or not, and would welcome clarifying documentation for someone with a fresh perspective on the topic.

@Netzeband
Contributor

@vincentvanhoucke
I created a PR with a more detailed description, mainly based on your statement in this thread:
#15653

@tensorflowbutler
Member

Please remove the assignee, as this issue is inviting external contributions. Otherwise, remove the contributions welcome label. Thank you.

@annarev
Contributor

annarev commented Feb 7, 2018

Closing this bug since the original request to add a batch norm layer has been addressed. Some of the more recent issues with documentation seem to have their own PRs.
If you see any issue with batch_norm, please either ask a question on StackOverflow or open another issue.
