
Training does not converge when DoReFa is used as a kernel quantizer #487

Open
bferrarini opened this issue May 9, 2020 · 6 comments

@bferrarini

Describe the bug

Hi,
I tried to use the DoReFa quantizer to train a simple model on CIFAR-10, but training failed to converge:

10000/1 - 1s 96us/sample - loss: 2.3026 - accuracy: 0.1000
Test loss: 2.3025851249694824
Test accuracy: 0.1

I found that the problem only occurs when the DoReFa quantizer is used on the kernel (`kernel_quantizer`).

This code reproduces the problem. DoReFa is used for both activations and weights:

from __future__ import print_function
import tensorflow as tf
from tensorflow.keras.datasets import cifar10
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Activation, Flatten, BatchNormalization
from tensorflow.keras.layers import MaxPooling2D
import larq as lq
from larq.layers import QuantConv2D, QuantDense
import time

import os
from larq.quantizers import DoReFaQuantizer

batch_size = 128
num_classes = 10
epochs = 30
num_predictions = 20
save_dir = os.path.join(os.getcwd(), 'saved_models')

clip_value = 1.0
model_name = 'keras_cifar10_trained_model_test_DOREFA'

TEST = False

(x_train, y_train), (x_test, y_test) = cifar10.load_data()
print('x_train shape:', x_train.shape)
print(x_train.shape[0], 'train samples')
print(x_test.shape[0], 'test samples')

y_train = tf.keras.utils.to_categorical(y_train, num_classes)
y_test = tf.keras.utils.to_categorical(y_test, num_classes)


model = Sequential()
model.add(QuantConv2D(32, (3, 3), padding='same',
                      input_quantizer=None,
                      kernel_quantizer=DoReFaQuantizer(),
                      kernel_constraint=lq.constraints.WeightClip(clip_value=1),
                      use_bias=True,
                      input_shape=x_train.shape[1:]))
model.add(BatchNormalization(momentum=0.9))

model.add(QuantConv2D(32, (3, 3), padding='valid',
                      input_quantizer=DoReFaQuantizer(),
                      kernel_quantizer=DoReFaQuantizer(),
                      kernel_constraint=lq.constraints.WeightClip(clip_value=1),
                      use_bias=False))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(BatchNormalization(momentum=0.9))

model.add(QuantConv2D(64, (3, 3), padding='same',
                      input_quantizer=DoReFaQuantizer(),
                      kernel_quantizer=DoReFaQuantizer(),
                      kernel_constraint=lq.constraints.WeightClip(clip_value=1)))
model.add(BatchNormalization(momentum=0.9))

model.add(QuantConv2D(64, (3, 3), padding='valid',
                      input_quantizer=DoReFaQuantizer(),
                      kernel_quantizer=DoReFaQuantizer(),
                      kernel_constraint=lq.constraints.WeightClip(clip_value=1),
                      use_bias=False))
model.add(MaxPooling2D(pool_size=(2, 2)))
model.add(BatchNormalization(momentum=0.9))

model.add(Flatten())

model.add(QuantDense(512,
                     input_quantizer=DoReFaQuantizer(),
                     kernel_quantizer=DoReFaQuantizer(),
                     kernel_constraint=lq.constraints.WeightClip(clip_value=1),
                     use_bias=False))
model.add(BatchNormalization(momentum=0.9))

model.add(QuantDense(num_classes,
                     input_quantizer=DoReFaQuantizer(),
                     kernel_quantizer=DoReFaQuantizer(),
                     kernel_constraint=lq.constraints.WeightClip(clip_value=1),
                     use_bias=False))
model.add(Activation('softmax'))

opt = tf.keras.optimizers.RMSprop(learning_rate=0.0001*5, decay=1e-6)

model.compile(loss='categorical_crossentropy',
              optimizer=opt,
              metrics=['accuracy'])

x_train = x_train.astype('float32')
x_test = x_test.astype('float32')
x_train /= 255
x_test /= 255

start = time.time()

if not os.path.isdir(save_dir):
    os.makedirs(save_dir)
model_path = os.path.join(save_dir, model_name)

if not TEST:

    print('Not using data augmentation.')
    model.fit(x_train, y_train,
              batch_size=batch_size,
              epochs=epochs,
              validation_data=(x_test, y_test),
              shuffle=True)
    
    
    model.save_weights(model_path)
    print('Saved trained model at %s ' % model_path)

else:
    
    model.load_weights(model_path)

scores = model.evaluate(x_test, y_test, verbose=1)
print('Test loss:', scores[0])
print('Test accuracy:', scores[1])

stop = time.time()

print("Training done in {:0.2f} seconds for {} epochs.".format(None, stop-start, epochs))

Expected behavior

Training converges for any precision set in DoReFa (the `k_bit` parameter).

Environment

TensorFlow version: 2.0.0
Larq version: 0.9.4

@AdamHillier
Contributor

@bferrarini does your code work if you replace DoReFa with some other quantizer, such as SteSign?

@bferrarini
Author

bferrarini commented May 11, 2020

Hi @AdamHillier ,

I ran the same code using SteSign on the weights and it worked:

10000/1 - 1s 93us/sample - loss: 1.3581 - accuracy: 0.6566
Test loss: 1.3786100630760192
Test accuracy: 0.6566
Training done in 170.44 seconds for 30 epochs.

According to my understanding, DoReFa should clip the gradient as expected (link).
I implemented a version of DoReFa with gradient clipping. Here are the results:

DoReFa + MyDoReFa w/ clipping (2 bits): converges slowly, but it does converge.
10000/1 - 101us/sample - loss: 2.1806 - accuracy: 0.1226
Test loss: 2.2232713317871093
Test accuracy: 0.1226
Training done in 178.93 seconds for 30 epochs.

DoReFa + MyDoReFa w/ clipping (1 bit): converges.
10000/1 - 1s 97us/sample - loss: 1.3921 - accuracy: 0.4745
Test loss: 1.5290546701431273
Test accuracy: 0.4745
Training done in 178.45 seconds for 30 epochs.

Here is the clipped DoReFa quantizer I implemented for the tests:

import tensorflow as tf
from larq import utils
# BaseQuantizer and _clipped_gradient are internal Larq helpers
from larq.quantizers import BaseQuantizer, _clipped_gradient


@utils.register_keras_custom_object
class MyDoReFaQuantizer(BaseQuantizer):

    precision = None

    def __init__(self, k_bit: int = 2, **kwargs):
        self.precision = k_bit
        super().__init__(**kwargs)

    def call(self, inputs):
        inputs = tf.clip_by_value(inputs, 0, 1.0)

        @tf.custom_gradient
        def _k_bit_with_identity_grad(x):
            def grad(dy):
                return _clipped_gradient(x, dy, clip_value=1.0)

            n = 2 ** self.precision - 1
            return tf.round(x * n) / n, grad

        outputs = _k_bit_with_identity_grad(inputs)
        return super().call(outputs)

    def get_config(self):
        return {**super().get_config(), "k_bit": self.precision}

Kind regards,

Bruno

@jneeven jneeven self-assigned this May 13, 2020
@jneeven
Contributor

jneeven commented May 13, 2020

Hi @bferrarini ,

Thanks for opening the issue and providing the problematic code! I have successfully reproduced the problem: the model indeed does not seem to train when using the DoReFaQuantizer for the weights, and does train when using SteSign instead.

You are correct that the DoReFaQuantizer should clip the gradient, and in fact it does; the plot in our documentation you linked to is actually generated using the quantizer, so it is definitely doing the right thing. Your adapted version, MyDoReFaQuantizer, does exactly the same thing: the gradient was already clipped (because the inputs are clipped in the forward pass, and therefore so is the gradient), so adding another clip doesn't make a difference. To check this for yourself, you can use the following code:

import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf

def calculate_activation(function, x):
    tf_x = tf.Variable(x)
    with tf.GradientTape() as tape:
        activation = function(tf_x)
    return activation.numpy(), tape.gradient(activation, tf_x).numpy()

def plot_quantizer(quantizer):
    x = np.linspace(-2, 2, 500).astype(np.float32)
    y, dy = calculate_activation(quantizer, x)
    plt.subplot(121)
    plt.grid()
    plt.plot(x, y)
    plt.subplot(122)
    plt.grid()
    plt.plot(x, dy)
    plt.show()

plot_quantizer(MyDoReFaQuantizer())

This will plot the following figure:

[figure: the quantizer's forward pass (left) and gradient (right) over the range [-2, 2]]

which is indeed the same as that of the normal DoReFaQuantizer.

Nevertheless, I'm surprised that the model does not train. Even when I use your MyDoReFaQuantizer with k_bit=1 or k_bit=2, it never obtains more than 10% accuracy.
Could you please send me the code that allowed you to obtain 47% accuracy? I suspect you have accidentally changed something else, which has somehow resolved the problem.

I also had another look at the DoReFa paper, and found that for k_bit > 1 they actually use a different quantization formula for the weights than for the activations (i.e. different from the one currently in Larq):

w_q = 2 * quantize_k( tanh(w) / (2 * max(|tanh(w)|)) + 1/2 ) - 1

Perhaps this explains why the DoReFaQuantizer doesn't work well when using it on weights with k_bit = 2, but I'd still expect it to perform better than random (i.e. obtain an accuracy > 10%)...
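
For reference, here is a minimal forward-pass sketch of that weight formula as I read it from the paper; quantize_k is inlined as round-to-k-bits, and the straight-through gradient is omitted here (it is handled properly in the class further down):

import tensorflow as tf

def dorefa_weight_quantize(w, k_bit=2):
    n = 2 ** k_bit - 1
    t = tf.math.tanh(w)
    w01 = t / (2.0 * tf.reduce_max(tf.abs(t))) + 0.5  # map weights into [0, 1]
    return 2.0 * (tf.round(w01 * n) / n) - 1.0        # k-bit quantize, then map back to [-1, 1]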

@bferrarini
Author

Hi @jneeven,

You are right on both counts.

  1. MyDoReFa is equivalent to DoReFa (many thanks for the plot_quantizer function, it is handy).
  2. I confirm I cannot reproduce the experiment that obtained 47% accuracy. As you suspected, I probably made a mistake in the first run; possibly I did not use the MyDoReFa quantizer in some part of the code.

Kind Regards,

Bruno

@jneeven
Contributor

jneeven commented May 19, 2020

I had another look at the DoReFa paper and have concluded that the issues here stem from the fact that they use a different quantization formula for the weights than for the activations (as I mentioned above). This probably leads to gradient issues, preventing your model from training altogether. In case you want to use binary weights, there is no point in quantizing them with DoReFa, as even the authors themselves just resort to SteSign in that case:
[excerpt from the DoReFa paper: for 1-bit weights, the authors quantize with the sign function]
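
For the binary case you can therefore just use the built-in SteSign directly; a minimal sketch using Larq's registered string aliases:

import larq as lq

layer = lq.layers.QuantConv2D(
    32, (3, 3),
    kernel_quantizer="ste_sign",      # straight-through sign
    kernel_constraint="weight_clip",  # clips latent weights to [-1, 1]
)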

If you want to use DoReFa for weights where k_bit >= 2, you will need to use the formula described in the paper. A hacky implementation could look somewhat like this (I have not verified that everything is correct):

import larq as lq
import tensorflow as tf


def _clipped_gradient(x, dy, clip_min, clip_max):
    """Calculate `clipped_gradient * dy`, zeroing `dy` outside (clip_min, clip_max)."""

    zeros = tf.zeros_like(dy)
    mask = tf.math.logical_and(tf.math.greater(x, clip_min), tf.math.less(x, clip_max))
    return tf.where(mask, dy, zeros)


class CustomDoReFaQuantizer(lq.quantizers.DoReFaQuantizer):
    r"""DoReFa variant that uses the paper's separate quantization formula for weights."""

    def __init__(self, k_bit: int = 2, mode="activations", **kwargs):
        if mode not in ["weights", "activations"]:
            raise ValueError(f"DoReFa received unknown mode: {mode}")
        self.mode = mode
        self.n = 2 ** k_bit - 1

        # Pass k_bit through so the parent doesn't reset `precision` to its default.
        super().__init__(k_bit=k_bit, **kwargs)

    def call(self, inputs):
        if self.mode == "activations":

            @tf.custom_gradient
            def quantize_k(x):
                x = tf.clip_by_value(x, 0, 1.0)
                return (
                    tf.round(x * self.n) / self.n,
                    lambda dy: _clipped_gradient(inputs, dy, 0, 1.0),
                )

            return quantize_k(inputs)
        else:

            @tf.custom_gradient
            def quantize_k(x):
                # Straight-through estimator: pass the gradient through unchanged.
                return tf.round(x * self.n) / self.n, lambda dy: dy

            tan = tf.math.tanh(inputs)
            fraction = tan / (2 * tf.math.reduce_max(tf.math.abs(tan)))
            return 2.0 * quantize_k(fraction + 0.5) - 1
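
As a hypothetical example, dropping it into one of the layers from your script would look like this (the mode choices are the important part; I haven't tuned anything else):

model.add(QuantConv2D(
    32, (3, 3), padding='same',
    input_quantizer=CustomDoReFaQuantizer(k_bit=2, mode="activations"),
    kernel_quantizer=CustomDoReFaQuantizer(k_bit=2, mode="weights"),
    kernel_constraint=lq.constraints.WeightClip(clip_value=1),
))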

I tried this with your code and although the accuracy is still very low (I think it was around 14%), it does at least train.

That being said, DoReFa is somewhat outdated at this point, and a much more interesting quantizer to use would be LSQ from the paper Learned Step Size Quantization. It does basically the same thing as DoReFa (the scalar multiplier can be fused with the batch norm), but has much better gradients. A hacky implementation of that (which I already had lying around for some toy examples) could look like this:

import numpy as np
import tensorflow as tf
from larq import utils


@tf.custom_gradient
def scaled_gradient(x: tf.Tensor, scale: float = 1.0) -> tf.Tensor:
    def grad(dy):
        # We don't return a gradient for `scale` as it isn't trainable
        return (dy * scale, 0.0)

    return x, grad


@utils.register_alias("lsq")
@utils.register_keras_custom_object
class LSQ(tf.keras.layers.Layer):
    r"""Instantiates a serializable k_bit quantizer as in the LSQ paper.

    # Arguments
    k_bit: number of bits for the quantization.
    mode: either "signed" or "unsigned", reflects the activation quantization scheme to
        use. When using this for weights, use mode "weights" instead.
    metrics: An array of metrics to add to the layer. If `None` the metrics set in
        `larq.context.metrics_scope` are used. Currently only the `flip_ratio` metric is
        available.

    # Returns
    Quantization function

    # References
    - [Learned Step Size Quantization](https://arxiv.org/abs/1902.08153)
    """
    precision = None

    def __init__(self, k_bit: int = 2, mode="unsigned", **kwargs):
        self.precision = k_bit
        self.mode = mode

        if mode == "unsigned":
            self.q_n = 0.00
            self.q_p = float(2 ** self.precision - 1)
        elif mode in ["signed", "weights"]:
            self.q_p = float(2 ** (self.precision - 1)) - 1

            # For signed, we can use the full signed range, e.g. [-2, 1]
            if mode == "signed":
                self.q_n = -float(2 ** (self.precision - 1))
            # For weights, we use a symmetric range, e.g. [-1, 1]
            else:
                self.q_n = -float(2 ** (self.precision - 1) - 1)

        else:
            raise ValueError(f"LSQ received unknown mode: {mode}")

        super().__init__(**kwargs)

    def build(self, input_shape):
        self.s = self.add_weight(
            name="s",
            initializer="ones",
            trainable=True,
            aggregation=tf.VariableAggregation.MEAN,
        )
        self._initialized = self.add_weight(
            name="initialized",
            initializer="zeros",
            dtype=tf.dtypes.bool,
            trainable=False,
            aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA,
        )

        # Assuming that by num_features they mean all the individual pixels.
        # You can also try the number of feature maps instead.
        self.g = float(1.0 / np.sqrt(np.prod(input_shape[1:]) * self.q_p))

        super().build(input_shape)

    def call(self, inputs):
        # Calculate initial value for the scale using the first batch
        self.add_update(
            self.s.assign(
                tf.cond(
                    self._initialized,
                    lambda: self.s,  # If already initialized, just use current value
                    # Otherwise, use the value below as initialization
                    lambda: (2.0 * tf.reduce_mean(tf.math.abs(inputs)))
                    / tf.math.sqrt(self.q_p),
                )
            )
        )
        self.add_update(self._initialized.assign(True))
        s = scaled_gradient(self.s, self.g)
        rescaled_inputs = inputs / s
        clipped_inputs = tf.clip_by_value(rescaled_inputs, self.q_n, self.q_p)

        @tf.custom_gradient
        def _round_ste(x):
            return tf.round(x), lambda dy: dy

        return _round_ste(clipped_inputs) * s

    def get_config(self):
        return {**super().get_config(), "k_bit": self.precision, "mode": self.mode}
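
As a sketch of how this might be dropped into your model (the k_bit and mode choices here are just an example, not something I have verified):

model.add(lq.layers.QuantDense(
    512,
    input_quantizer=LSQ(k_bit=2, mode="signed"),    # activations can be negative here
    kernel_quantizer=LSQ(k_bit=2, mode="weights"),  # symmetric range for weights
    use_bias=False,
))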

I can't make any guarantees regarding LSQ or this code either, but it's definitely worth a try.
Thank you for pointing out this issue; I will update the DoReFa documentation to make it clear that, in its current state, it is intended only for activations.

@bferrarini
Author

Many thanks for the effort in addressing the issue and for providing the code for LSQ. I will try it.
