How to use ViTs in PASSL

PASSL provides implementations of several Transformer-based classification models for the vision domain. Each model can be built from one of PASSL's configuration files, so you can set up research experiments quickly. Pre-trained weights are also provided and can be used to fine-tune the models on your own datasets.

Included Models

Weights Download

| Arch | Weight | Top-1 Acc (%) | Top-5 Acc (%) | Crop ratio | # Params |
| --- | --- | --- | --- | --- | --- |
| cait_s24_224 | pretrain 1k | 83.45 | 96.57 | 1.0 | 46.8M |
| cait_xs24_384 | pretrain 1k | 84.06 | 96.89 | 1.0 | 26.5M |
| cait_s24_384 | pretrain 1k | 85.05 | 97.34 | 1.0 | 46.8M |
| cait_s36_384 | pretrain 1k | 85.45 | 97.48 | 1.0 | 68.1M |
| cait_m36_384 | pretrain 1k | 86.06 | 97.73 | 1.0 | 270.7M |
| cait_m48_448 | pretrain 1k | 86.49 | 97.75 | 1.0 | 355.8M |
| t2t_vit_14 | pretrain 1k | 81.50 | 95.67 | 0.9 | 21.5M |
| t2t_vit_19 | pretrain 1k | 81.93 | 95.74 | 0.9 | 39.1M |
| t2t_vit_24 | pretrain 1k | 82.28 | 95.89 | 0.9 | 64.0M |
| t2t_vit_t_14 | pretrain 1k | 81.69 | 95.85 | 0.9 | 21.5M |
| t2t_vit_t_19 | pretrain 1k | 82.44 | 96.08 | 0.9 | 39.1M |
| t2t_vit_t_24 | pretrain 1k | 82.55 | 96.07 | 0.9 | 64.0M |
| cvt_13_224 | pretrain 1k | 81.59 | 95.67 | 0.875 | 20.0M |
| cvt_13_384 | ft 22k to 1k | 82.90 | 96.92 | 1.0 | 20.0M |
| cvt_21_224 | pretrain 1k | 82.46 | 96.00 | 0.875 | 31.6M |
| cvt_21_384 | ft 22k to 1k | 84.63 | 97.54 | 1.0 | 31.6M |
| cvt_w24_384 | ft 22k to 1k | 87.39 | 98.37 | 1.0 | 277.3M |
| beit_base_p16_224 | ft 22k to 1k | 85.21 | 97.66 | 0.9 | 87M |
| beit_base_p16_384 | ft 22k to 1k | 86.81 | 98.14 | 1.0 | 87M |
| beit_large_p16_224 | ft 22k to 1k | 87.48 | 98.30 | 0.9 | 304M |
| beit_large_p16_384 | ft 22k to 1k | 88.40 | 98.60 | 1.0 | 304M |
| beit_large_p16_512 | ft 22k to 1k | 88.60 | 98.66 | 1.0 | 304M |
| mlp_mixer_b16_224 | pretrain 1k | 76.60 | 92.23 | 0.875 | 60.0M |
| mlp_mixer_l16_224 | pretrain 1k | 72.06 | 87.67 | 0.875 | 208.2M |
| xcit_nano_12_p8_224 | pretrain 1k | 73.90 | 92.13 | 1.0 | 3.05M |
| xcit_nano_12_p8_224_dist | pretrain 1k | 77.28 | 93.25 | 1.0 | 3.05M |
| xcit_tiny_12_p8_224 | pretrain 1k | 79.68 | 95.04 | 1.0 | 6.71M |
| xcit_tiny_24_p8_224 | pretrain 1k | 81.87 | 95.97 | 1.0 | 12.11M |
| xcit_small_12_p8_224 | pretrain 1k | 83.36 | 96.51 | 1.0 | 26.21M |
| xcit_small_24_p8_224 | pretrain 1k | 83.82 | 96.65 | 1.0 | 47.63M |
| xcit_medium_24_p8_224 | pretrain 1k | 83.73 | 96.39 | 1.0 | 84.32M |
| xcit_large_24_p8_224 | pretrain 1k | 84.42 | 96.65 | 1.0 | 188.93M |
| xcit_nano_12_p16_224 | pretrain 1k | 70.01 | 89.82 | 1.0 | 3.05M |
| xcit_tiny_12_p16_224 | pretrain 1k | 77.15 | 93.72 | 1.0 | 6.72M |
| xcit_tiny_24_p16_224 | pretrain 1k | 79.42 | 94.86 | 1.0 | 12.12M |
| xcit_small_12_p16_224 | pretrain 1k | 81.89 | 95.83 | 1.0 | 26.25M |
| xcit_small_24_p16_224 | pretrain 1k | 82.51 | 95.97 | 1.0 | 47.67M |
| xcit_medium_24_p16_224 | pretrain 1k | 82.67 | 95.91 | 1.0 | 84.40M |
| xcit_large_24_p16_224 | pretrain 1k | 82.89 | 95.89 | 1.0 | 189.10M |

The above metrics were tested on the ImageNet 2012 dataset.
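For readers unfamiliar with the Top-1 and Top-5 columns, the following is a minimal sketch of how top-k accuracy is usually computed: a sample counts as correct when its true label is among the k highest-scoring classes. The helper `topk_accuracy` and the toy scores are illustrative assumptions, not PASSL code.

```python
from typing import Sequence


def topk_accuracy(scores: Sequence[Sequence[float]],
                  labels: Sequence[int],
                  k: int = 1) -> float:
    """Percentage of samples whose true label is among the k top-scoring classes."""
    if not labels or len(scores) != len(labels):
        raise ValueError("scores and labels must be non-empty and the same length")
    hits = 0
    for row, label in zip(scores, labels):
        # Indices of the k largest scores for this sample.
        topk = sorted(range(len(row)), key=lambda i: row[i], reverse=True)[:k]
        hits += label in topk
    return 100.0 * hits / len(labels)


# Toy example: 3 samples, 4 classes (made-up scores, not model output).
scores = [
    [0.1, 0.7, 0.1, 0.1],  # best class is 1
    [0.4, 0.3, 0.2, 0.1],  # best class is 0, class 1 is second
    [0.1, 0.2, 0.3, 0.4],  # best class is 3, class 2 is second
]
labels = [1, 1, 0]
print(topk_accuracy(scores, labels, k=1))  # 1 of 3 samples correct
print(topk_accuracy(scores, labels, k=2))  # 2 of 3 samples correct
```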

Note: pretrain 1k means that the model was trained directly on ImageNet-1k; ft 22k to 1k means that the model was pre-trained on ImageNet-22k and then fine-tuned on ImageNet-1k.
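The Crop ratio column in the table is commonly read as the fraction of the resized image that the evaluation center crop keeps. The sketch below shows that arithmetic under an assumed convention (resize the shorter side to `crop_size / crop_ratio`, rounded down, then center-crop to `crop_size`). `resize_size_for` is a hypothetical helper, and PASSL's own evaluation transforms may round or resize differently.

```python
import math


def resize_size_for(crop_size: int, crop_ratio: float) -> int:
    """Shorter-side resize target so that a center crop of `crop_size`
    keeps `crop_ratio` of the resized image (assumed convention)."""
    if not 0.0 < crop_ratio <= 1.0:
        raise ValueError("crop_ratio must be in (0, 1]")
    return int(math.floor(crop_size / crop_ratio))


print(resize_size_for(224, 0.875))  # 256
print(resize_size_for(224, 0.9))    # 248
print(resize_size_for(384, 1.0))    # 384
```

Under this reading, a crop ratio of 1.0 means the image is resized directly to the model's input resolution with no border discarded.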

Usage

Install the required packages first so that the code can run; see INSTALL.md.

Run the following code from the ./PASSL directory. Change cfg_file to select the model you want.

To load pre-trained weights, download the weights file that matches the selected model.

import paddle
from passl.modeling.backbones import build_backbone
from passl.modeling.heads import build_head
from passl.utils.config import get_config


class CreateModel(paddle.nn.Layer):
    """Backbone plus classification head, both built from a PASSL config file."""

    def __init__(self, cfg_file):
        super().__init__()
        cfg = get_config(cfg_file)
        self.backbone = build_backbone(cfg.model.architecture)
        self.head = build_head(cfg.model.head)

    def forward(self, x):
        x = self.backbone(x)  # extract features
        x = self.head(x)      # map features to class scores
        return x


# Select the model by pointing cfg_file at its config.
cfg_file = "configs/cvt/cvt_13_224.yaml"
model = CreateModel(cfg_file)

# Load the downloaded pre-trained weights that match the config.
model_state_dict = paddle.load('cvt_13_224.pdparams')
model.set_state_dict(model_state_dict)  # set_dict is the older alias
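Once the model is loaded, its output for one image is typically a vector of class scores. The following sketch, which is not part of PASSL, turns such a vector into the k most likely classes with a softmax. `topk_predictions` is a hypothetical helper, and the logits are made up so that the example runs without Paddle.

```python
import math
from typing import List, Sequence, Tuple


def topk_predictions(logits: Sequence[float], k: int = 5) -> List[Tuple[int, float]]:
    """Return the k most likely (class_index, probability) pairs."""
    # Numerically stable softmax: subtract the max before exponentiating.
    m = max(logits)
    exps = [math.exp(v - m) for v in logits]
    total = sum(exps)
    probs = [e / total for e in exps]
    # Class indices sorted by descending probability.
    order = sorted(range(len(probs)), key=lambda i: probs[i], reverse=True)
    return [(i, probs[i]) for i in order[:k]]


# Made-up logits for a 4-class problem, standing in for one row of model output.
preds = topk_predictions([2.0, 0.5, -1.0, 3.0], k=2)
print(preds)  # class 3 first, then class 0
```

With the model from the code above, a row of logits could be obtained with something like `model(paddle.randn([1, 3, 224, 224])).numpy()[0].tolist()` after switching to evaluation mode with `model.eval()`. This assumes the head returns a `[batch, num_classes]` tensor, which the source does not state explicitly.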

If you need to fine-tune the model, you can modify the config file, for example by changing the number of classes.

# configs/cvt/cvt_13_224.yaml
...

model:
  name: CvTWrapper
  architecture:
    name: CvT
    embed_dim: [64, 192, 384]
    depth: [1, 2, 10]
    num_heads: [1, 3, 6]
  head:
    name: CvTClsHead
    num_classes: 10   # Set this to the number of classes in your dataset
    in_channels: 384
    
...
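After changing num_classes, the classification head no longer has the same shape as the head stored in a 1000-class checkpoint. One common option, sketched below under assumptions, is to drop checkpoint entries whose name or shape does not match the new model before loading. `drop_mismatched` is a hypothetical helper, and the parameter names in the demo are invented; only a `.shape` attribute is assumed on each entry.

```python
from types import SimpleNamespace
from typing import Any, Dict, List, Tuple


def drop_mismatched(pretrained: Dict[str, Any],
                    target: Dict[str, Any]) -> Tuple[Dict[str, Any], List[str]]:
    """Keep only pretrained entries whose name and shape match the target model."""
    kept: Dict[str, Any] = {}
    skipped: List[str] = []
    for name, value in pretrained.items():
        if name in target and list(value.shape) == list(target[name].shape):
            kept[name] = value
        else:
            skipped.append(name)  # e.g. a head sized for a different class count
    return kept, skipped


# Stand-in tensors: only the .shape attribute matters for this sketch.
pretrained = {
    "backbone.proj.weight": SimpleNamespace(shape=[64, 3, 7, 7]),
    "head.fc.weight": SimpleNamespace(shape=[384, 1000]),  # 1000-class head
}
target = {
    "backbone.proj.weight": SimpleNamespace(shape=[64, 3, 7, 7]),
    "head.fc.weight": SimpleNamespace(shape=[384, 10]),    # 10-class head
}
kept, skipped = drop_mismatched(pretrained, target)
print(sorted(kept))  # ['backbone.proj.weight']
print(skipped)       # ['head.fc.weight']
```

In practice the two dictionaries would come from `paddle.load(...)` and `model.state_dict()`, and the filtered result would be passed to `model.set_state_dict(kept)`. The skipped head then keeps its fresh initialization and is learned during fine-tuning.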

Coming Soon

Model training

Model validation

Contact

If you have any questions, please create an issue on our GitHub.