Keras implementation of MLP-Mixer, ResMLP, and gMLP, with ImageNet / ImageNet-21k pretrained weights reloaded.

Keras_mlp


Usage

  • This repo can be installed as a pip package.
    pip install -U git+https://github.com/leondgarse/keras_mlp
    or just git clone it.
    git clone https://github.com/leondgarse/keras_mlp.git
    cd keras_mlp && pip install .
  • Basic usage
    import keras_mlp
    # Will download and load `imagenet` pretrained weights.
    # Model weight is loaded with `by_name=True, skip_mismatch=True`.
    mm = keras_mlp.MLPMixerB16(num_classes=1000, pretrained="imagenet")
    
    # Run prediction
    import tensorflow as tf
    from tensorflow import keras
    from skimage.data import chelsea # Chelsea the cat
    imm = keras.applications.imagenet_utils.preprocess_input(chelsea(), mode='tf') # mode='tf' or 'torch'
    pred = mm(tf.expand_dims(tf.image.resize(imm, mm.input_shape[1:3]), 0)).numpy()
    print(keras.applications.imagenet_utils.decode_predictions(pred)[0])
    # [('n02124075', 'Egyptian_cat', 0.9568315), ('n02123045', 'tabby', 0.017994137), ...]
    For "imagenet21k" pre-trained models, the actual num_classes is 21843.
  • Exclude the model's top layers by setting num_classes=0.
    import keras_mlp
    mm = keras_mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
    print(mm.output_shape)
    # (None, 784, 768)
    
    mm.save('resmlp_b24_imagenet22k-notop.h5')
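A no-top model like this can serve as a backbone for transfer learning: its `(None, 784, 768)` output is a token sequence that can be pooled and fed to a fresh classifier. A minimal sketch, where `add_classifier_head` is a hypothetical helper for illustration, not part of keras_mlp:

```python
from tensorflow import keras

def add_classifier_head(backbone, num_classes):
    """Attach a fresh pooling + Dense head to a num_classes=0 backbone.

    `backbone` is assumed to output a (None, tokens, channels) sequence,
    as the no-top models above do.
    """
    # Average over the token dimension: (None, tokens, channels) -> (None, channels)
    x = keras.layers.GlobalAveragePooling1D()(backbone.output)
    outputs = keras.layers.Dense(num_classes, activation="softmax", name="predictions")(x)
    return keras.Model(backbone.input, outputs)

# Usage with keras_mlp (downloads pretrained weights):
#   mm = keras_mlp.ResMLP_B24(num_classes=0, pretrained="imagenet22k")
#   model = add_classifier_head(mm, num_classes=10)
```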

MLP mixer

  • Paper: arXiv 2105.01601 MLP-Mixer: An all-MLP Architecture for Vision.

  • Github google-research/vision_transformer.

  • Model Top1 Acc is the ImageNet-1K accuracy of models pre-trained on JFT-300M, as reported in the paper.

    | Model       | Params | Top1 Acc | ImageNet        | Imagenet21k        | ImageNet SAM        |
    | ----------- | ------ | -------- | --------------- | ------------------ | ------------------- |
    | MLPMixerS32 | 19.1M  | 68.70    |                 |                    |                     |
    | MLPMixerS16 | 18.5M  | 73.83    |                 |                    |                     |
    | MLPMixerB32 | 60.3M  | 75.53    |                 |                    | b32_imagenet_sam.h5 |
    | MLPMixerB16 | 59.9M  | 80.00    | b16_imagenet.h5 | b16_imagenet21k.h5 | b16_imagenet_sam.h5 |
    | MLPMixerL32 | 206.9M | 80.67    |                 |                    |                     |
    | MLPMixerL16 | 208.2M | 84.82    | l16_imagenet.h5 | l16_imagenet21k.h5 |                     |
    | - input 448 | 208.2M | 86.78    |                 |                    |                     |
    | MLPMixerH14 | 432.3M | 86.32    |                 |                    |                     |
    | - input 448 | 432.3M | 87.94    |                 |                    |                     |
    | Specification        | S/32  | S/16  | B/32  | B/16  | L/32  | L/16  | H/14  |
    | -------------------- | ----- | ----- | ----- | ----- | ----- | ----- | ----- |
    | Number of layers     | 8     | 8     | 12    | 12    | 24    | 24    | 32    |
    | Patch resolution P×P | 32×32 | 16×16 | 32×32 | 16×16 | 32×32 | 16×16 | 14×14 |
    | Hidden size C        | 512   | 512   | 768   | 768   | 1024  | 1024  | 1280  |
    | Sequence length S    | 49    | 196   | 49    | 196   | 49    | 196   | 256   |
    | MLP dimension DC     | 2048  | 2048  | 3072  | 3072  | 4096  | 4096  | 5120  |
    | MLP dimension DS     | 256   | 256   | 384   | 384   | 512   | 512   | 640   |
  • Parameter pretrained accepts one of [None, "imagenet", "imagenet21k", "imagenet_sam"]. Default is "imagenet".

  • Pre-training details

    • We pre-train all models using Adam with β1 = 0.9, β2 = 0.999, and batch size 4096, using weight decay, and gradient clipping at global norm 1.
    • We use a linear learning rate warmup of 10k steps and linear decay.
    • We pre-train all models at resolution 224.
    • For JFT-300M, we pre-process images by applying the cropping technique from Szegedy et al. [44] in addition to random horizontal flipping.
    • For ImageNet and ImageNet-21k, we employ additional data augmentation and regularization techniques.
    • In particular, we use RandAugment [12], mixup [56], dropout [42], and stochastic depth [19].
    • This set of techniques was inspired by the timm library [52] and Touvron et al. [46].
    • More details on these hyperparameters are provided in Supplementary B.
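The specification table above maps directly onto the Mixer block: each block applies a token-mixing MLP of width DS across the sequence dimension (transposing so the MLP acts on tokens) and a channel-mixing MLP of width DC across the hidden dimension, each wrapped in LayerNorm and a skip connection. A minimal Keras sketch of one block, shapes matching the B/16 row; layer names are illustrative and this is not the repo's actual implementation:

```python
from tensorflow import keras

def mlp_block(x, hidden_dim, name):
    # Two Dense layers with GELU, projecting back to the input width
    out_dim = x.shape[-1]
    x = keras.layers.Dense(hidden_dim, activation="gelu", name=name + "_dense1")(x)
    return keras.layers.Dense(out_dim, name=name + "_dense2")(x)

def mixer_block(x, ds, dc, name):
    # Token mixing: LayerNorm, transpose to (channels, tokens), MLP of width DS
    y = keras.layers.LayerNormalization(name=name + "_ln1")(x)
    y = keras.layers.Permute((2, 1))(y)
    y = mlp_block(y, ds, name + "_token")
    y = keras.layers.Permute((2, 1))(y)
    x = keras.layers.Add()([x, y])  # skip connection
    # Channel mixing: LayerNorm, MLP of width DC on the hidden dimension
    y = keras.layers.LayerNormalization(name=name + "_ln2")(x)
    y = mlp_block(y, dc, name + "_channel")
    return keras.layers.Add()([x, y])  # skip connection
```

For B/16 at resolution 224, the sequence has S = 196 tokens of C = 768 channels, so `mixer_block(x, ds=384, dc=3072, ...)` preserves the `(None, 196, 768)` shape while mixing across both axes.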

ResMLP

GMLP