batch normalization layer for data_format == 'channels_last' #1103

Open
2 tasks
edwardzcl opened this issue Oct 4, 2020 · 1 comment

Comments

@edwardzcl
New Issue Checklist

Issue Description

According to the batch normalization implementation in TL, which can be found at "https://github.com/tensorlayer/tensorlayer/blob/v2.2.0/tensorlayer/layers/normalization.py", the mean and var are computed over the whole input rather than channel-wise when the layer is initialized with data_format='channels_last'.

You can refer to line 220, self.channel_axis = -1 if data_format == 'channels_last' else 1, and line 282, self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]. Because channel_axis is -1 and range(len(inputs.shape)) only yields non-negative indices, self.axes ends up containing every axis, including the channel axis, so tf.nn.moments reduces over the whole tensor.
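The effect is easy to reproduce in isolation. Below is a minimal sketch (assuming TF 2.x and a hypothetical NHWC input shape) of the axes computed by line 282 when channel_axis is -1:

import tensorflow as tf

# hypothetical NHWC input: (batch, height, width, channels)
inputs = tf.random.normal([8, 50, 50, 32])

channel_axis = -1  # what line 220 sets for data_format='channels_last'
axes = [i for i in range(len(inputs.shape)) if i != channel_axis]
print(axes)  # [0, 1, 2, 3] -> the channel axis is NOT excluded

mean, var = tf.nn.moments(inputs, axes, keepdims=False)
print(mean.shape, var.shape)  # () () -> scalar statistics over the whole input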

Reproducible Code

  • Which OS are you using?
    Python 3, with TL 2.2 and TF 2.0.
  • Please provide reproducible code for your issue. Without it, you will probably not receive any help.
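The class below is the relevant excerpt from tensorlayer/layers/normalization.py (v2.2.0), with my proposed changes marked by ## add ##; it is not standalone and relies on names imported or defined in that module (tf, tl, logging, Layer, moving_averages, batch_normalization).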
class BatchNorm(Layer):
    """
    The :class:`BatchNorm` is a batch normalization layer for both fully-connected and convolution outputs.
    See ``tf.nn.batch_normalization`` and ``tf.nn.moments``.
    Parameters
    ----------
    decay : float
        A decay factor for `ExponentialMovingAverage`.
        A large value is suggested for a large dataset.
    epsilon : float
        Epsilon, a small float added to the variance to avoid dividing by zero.
    act : activation function
        The activation function of this layer.
    is_train : boolean
        Is being used for training or inference.
    beta_init : initializer or None
        The initializer for initializing beta, if None, skip beta.
        Usually you should not skip beta unless you know what you are doing.
    gamma_init : initializer or None
        The initializer for initializing gamma, if None, skip gamma.
        When the batch normalization layer is used instead of 'biases', or the next layer is linear, this can be
        disabled since the scaling can be done by the next layer. See `Inception-ResNet-v2 <https://github.com/tensorflow/models/blob/master/research/slim/nets/inception_resnet_v2.py>`__
    moving_mean_init : initializer or None
        The initializer for initializing moving mean, if None, skip moving mean.
    moving_var_init : initializer or None
        The initializer for initializing moving var, if None, skip moving var.
    num_features : int
        Number of features for input tensor. Useful to build layer if using BatchNorm1d, BatchNorm2d or BatchNorm3d,
        but should be left as None if using BatchNorm. Default None.
    data_format : str
        'channels_last' (default) or 'channels_first'.
    name : None or str
        A unique layer name.
    Examples
    ---------
    With TensorLayer
    >>> net = tl.layers.Input([None, 50, 50, 32], name='input')
    >>> net = tl.layers.BatchNorm()(net)
    Notes
    -----
    The :class:`BatchNorm` is universally suitable for 3D/4D/5D input in a static model, but should not be used
    in a dynamic model, where the layer is built upon class initialization. So the argument 'num_features' should only be used
    for the subclasses :class:`BatchNorm1d`, :class:`BatchNorm2d` and :class:`BatchNorm3d`. All three subclasses are
    suitable under all kinds of conditions.
    References
    ----------
    - `Source <https://github.com/ry/tensorflow-resnet/blob/master/resnet.py>`__
    - `stackoverflow <http://stackoverflow.com/questions/38312668/how-does-one-do-inference-with-batch-normalization-with-tensor-flow>`__
    """

    def __init__(
            self,
            decay=0.9,
            epsilon=0.00001,
            act=None,
            is_train=False,
            beta_init=tl.initializers.zeros(),
            gamma_init=tl.initializers.random_normal(mean=1.0, stddev=0.002),
            moving_mean_init=tl.initializers.zeros(),
            moving_var_init=tl.initializers.zeros(),
            num_features=None,
            data_format='channels_last',
            name=None,
    ):
        super(BatchNorm, self).__init__(name=name, act=act)
        self.decay = decay
        self.epsilon = epsilon
        self.data_format = data_format
        self.beta_init = beta_init
        self.gamma_init = gamma_init
        self.moving_mean_init = moving_mean_init
        self.moving_var_init = moving_var_init
        self.num_features = num_features

        #self.channel_axis = -1 if data_format == 'channels_last' else 1
        ## add ##
        self.data_format = data_format
        self.axes = None

        if num_features is not None:
            self.build(None)
            self._built = True

        if self.decay < 0.0 or 1.0 < self.decay:
            raise ValueError("decay should be between 0 to 1")

        logging.info(
            "BatchNorm %s: decay: %f epsilon: %f act: %s is_train: %s" %
            (self.name, decay, epsilon, self.act.__name__ if self.act is not None else 'No Activation', is_train)
        )

             
    ## skip ## (intermediate methods omitted, unchanged)

    def forward(self, inputs):
        self._check_input_shape(inputs)
        ## add ##
        self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1

        if self.axes is None:
            self.axes = [i for i in range(len(inputs.shape)) if i != self.channel_axis]

        mean, var = tf.nn.moments(inputs, self.axes, keepdims=False)
        if self.is_train:
            # update moving_mean and moving_var
            self.moving_mean = moving_averages.assign_moving_average(
                self.moving_mean, mean, self.decay, zero_debias=False
            )
            self.moving_var = moving_averages.assign_moving_average(self.moving_var, var, self.decay, zero_debias=False)
            outputs = batch_normalization(inputs, mean, var, self.beta, self.gamma, self.epsilon, self.data_format)
        else:
            outputs = batch_normalization(
                inputs, self.moving_mean, self.moving_var, self.beta, self.gamma, self.epsilon, self.data_format
            )
        if self.act:
            outputs = self.act(outputs)
        return outputs

Proposed fix: delete the line-220 assignment (self.channel_axis = -1 if data_format == 'channels_last' else 1), keep self.data_format = data_format in __init__, and compute self.channel_axis = len(inputs.shape) - 1 if self.data_format == 'channels_last' else 1 at the start of forward(), as shown in the code above. A quick check of the corrected axes is sketched below.
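A minimal sketch (again assuming TF 2.x and the same hypothetical NHWC input) of the corrected axes computation:

import tensorflow as tf

inputs = tf.random.normal([8, 50, 50, 32])  # hypothetical NHWC input

data_format = 'channels_last'
channel_axis = len(inputs.shape) - 1 if data_format == 'channels_last' else 1  # 3 instead of -1
axes = [i for i in range(len(inputs.shape)) if i != channel_axis]
print(axes)  # [0, 1, 2] -> the channel axis is excluded

mean, var = tf.nn.moments(inputs, axes, keepdims=False)
print(mean.shape, var.shape)  # (32,) (32,) -> per-channel statistics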

@Laicheng0830
Member

Thanks! There is indeed a problem with this code; we will fix it immediately.
