Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Reproduce SE-VGG16 results #89

Open
eslambakr opened this issue Feb 5, 2020 · 0 comments
Open

Reproduce SE-VGG16 results #89

eslambakr opened this issue Feb 5, 2020 · 0 comments

Comments

@eslambakr
Copy link

Hello @hujie-frank

First of all, thanks for sharing your amazing work.

I am trying to reproduce your results using VGG-16 but on cifar10 & cifar100, but unfortunately, I couldn't increase the accuracy.
I run two experiments the first one is the base line in which I trained the original VGG-16 without adding the SE block and the second experiment I added the SE block, I expect the validation accuracy to increase but unfortunately it doesn't.

Here you are my training details:
1- Learning rate starts from 1e-4 and decays to 1e-5.
2- I resized the input size to be 224.
3- I used ADAM optimizer.
4- I construct the original VGG as follows
conv = Conv2D(channels, kernel_size=kernel_size, padding='same', activation='relu', use_bias=False , kernel_regularizer=regularizers.l2(0.0005), name="conv_" + str(block_number))(input)
conv = BatchNormalization()(conv)
conv = Dropout(rate=drop)(conv)

5- I construct a second version of VGG
conv = Conv2D(channels, kernel_size=kernel_size, padding='same', activation='relu', use_bias=False , kernel_regularizer=regularizers.l2(0.0005), name="conv_" + str(block_number))(input)
conv = Dropout(rate=drop)(conv)
conv = BatchNormalization()(conv)

6- I construct the SE-VGG as follows
conv = Conv2D(channels, kernel_size=kernel_size, padding='same', activation='relu', use_bias=False , kernel_regularizer=regularizers.l2(0.0005), name="conv_" + str(block_number))(input)
conv = Dropout(rate=drop)(conv)
conv = SE_Layer(name=str(block_number), input_layer=conv_layer)(conv)
conv = BatchNormalization()(conv)

and here you are my implementation for the SE_Layer:

class SE_Layer(Layer):
    def __init__(self, input_layer, **kwargs):
        #self.scaling = scaling
        self.input_layer = input_layer
        self.prev_layer = None
        self.ratio = 8
        self.x = None
        super(SE_Layer, self).__init__(**kwargs)
def build(self, input_shape):
    self.dense_1_weights = self.add_weight(name='dense_1_weights',
                                           shape=(self.input_layer.output_shape[-1], int(self.input_layer.output_shape[-1]/self.ratio)),
                                           initializer='he_normal',
                                           trainable=True)
    self.dense_2_weights = self.add_weight(name='dense_2_weights',
                                           shape=(int(self.input_layer.output_shape[-1]/self.ratio), self.input_layer.output_shape[-1]),
                                           initializer='he_normal',
                                           trainable=True)
    super(SE_Layer, self).build(input_shape)

def call(self, conv):
    c = int(conv.shape[-1])
    x = conv
    x = GlobalAveragePooling2D(data_format='channels_last')(x)
    x = K.mean(x, axis=[0], keepdims=True)
    x = _normalize(x)
    x = Reshape([1, 1, c], name=self.name + "_reshape")(x)
    x = tf.matmul(x, self.dense_1_weights)
    x = relu(x)
    x = tf.matmul(x, self.dense_2_weights)
    x = sigmoid(x)
    self.x = x
    y = multiply([conv, x], name=self.name + "_mul")
    return y

def get_scaling(self):
    return self.x

def compute_output_shape(self, input_shape):
    return input_shape`

Here you are my results, the 3 experiments (changing the architecture as described in steps 4,5 and 6) achieved the same accuracy (94.1 %)
The training stops when the model is over-fitting as I made a patient for 20 epoch to guarantee the model is over-fitting.

Thanks in advance, I hope you could help me to reproduce your results.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

No branches or pull requests

1 participant