Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[WIP] Add SwinTransformerV2 (CVPR'22) #2986

Open
wants to merge 9 commits into
base: dev-1.x
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
62 changes: 62 additions & 0 deletions configs/_base_/models/upernet_swinv2.py
@@ -0,0 +1,62 @@
# model settings
norm_cfg = dict(type='SyncBN', requires_grad=True)
backbone_norm_cfg = dict(type='LN', requires_grad=True)
data_preprocessor = dict(
type='SegDataPreProcessor',
mean=[123.675, 116.28, 103.53],
std=[58.395, 57.12, 57.375],
bgr_to_rgb=True,
pad_val=0,
seg_pad_val=255)
model = dict(
type='EncoderDecoder',
data_preprocessor=data_preprocessor,
pretrained=None,
backbone=dict(
type='SwinTransformerV2',
pretrain_img_size=224,
embed_dims=96,
patch_size=4,
window_size=7,
mlp_ratio=4,
depths=[2, 2, 6, 2],
num_heads=[3, 6, 12, 24],
strides=(4, 2, 2, 2),
out_indices=(0, 1, 2, 3),
qkv_bias=True,
patch_norm=True,
drop_rate=0.,
attn_drop_rate=0.,
drop_path_rate=0.1,
use_abs_pos_embed=False,
act_cfg=dict(type='GELU'),
norm_cfg=backbone_norm_cfg,
pretrained_window_sizes=[0, 0, 0, 0]),
decode_head=dict(
type='UPerHead',
in_channels=[96, 192, 384, 768],
in_index=[0, 1, 2, 3],
pool_scales=(1, 2, 3, 6),
channels=512,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=1.0)),
auxiliary_head=dict(
type='FCNHead',
in_channels=384,
in_index=2,
channels=256,
num_convs=1,
concat_input=False,
dropout_ratio=0.1,
num_classes=19,
norm_cfg=norm_cfg,
align_corners=False,
loss_decode=dict(
type='CrossEntropyLoss', use_sigmoid=False, loss_weight=0.4)),
# model training and testing settings
train_cfg=dict(),
test_cfg=dict(mode='whole'))
1 change: 1 addition & 0 deletions configs/swinv2/README.md
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might add more information to README.md.

@@ -0,0 +1 @@
# SwinTransformerV2
Empty file added configs/swinv2/metafile.yaml
Empty file.
@@ -0,0 +1,53 @@
_base_ = [
'../_base_/models/upernet_swinv2.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
data_preprocessor = dict(size=crop_size)
checkpoint_file = './swinv2_base_patch4_window16_256_22k.pth' # noqa
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
pretrain_img_size=256,
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=16,
use_abs_pos_embed=False,
drop_path_rate=0.2,
patch_norm=True),
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
@@ -0,0 +1,53 @@
_base_ = [
'../_base_/models/upernet_swinv2.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
data_preprocessor = dict(size=crop_size)
checkpoint_file = './swinv2_base_patch4_window24_384_22k.pth' # noqa
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
pretrain_img_size=384,
embed_dims=128,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=24,
use_abs_pos_embed=False,
drop_path_rate=0.2,
patch_norm=True),
decode_head=dict(in_channels=[128, 256, 512, 1024], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
@@ -0,0 +1,53 @@
_base_ = [
'../_base_/models/upernet_swinv2.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
data_preprocessor = dict(size=crop_size)
checkpoint_file = './swinv2_large_patch4_window16_256_22k.pth' # noqa
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
pretrain_img_size=256,
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=16,
use_abs_pos_embed=False,
drop_path_rate=0.2,
patch_norm=True),
decode_head=dict(in_channels=[192, 384, 768, 1536], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
@@ -0,0 +1,53 @@
_base_ = [
'../_base_/models/upernet_swinv2.py', '../_base_/datasets/ade20k.py',
'../_base_/default_runtime.py', '../_base_/schedules/schedule_160k.py'
]
crop_size = (640, 640)
data_preprocessor = dict(size=crop_size)
checkpoint_file = './swinv2_large_patch4_window24_384_22k.pth' # noqa
model = dict(
data_preprocessor=data_preprocessor,
backbone=dict(
init_cfg=dict(type='Pretrained', checkpoint=checkpoint_file),
pretrain_img_size=384,
embed_dims=192,
depths=[2, 2, 18, 2],
num_heads=[4, 8, 16, 32],
window_size=16,
use_abs_pos_embed=False,
drop_path_rate=0.2,
patch_norm=True),
decode_head=dict(in_channels=[192, 384, 768, 1536], num_classes=150),
auxiliary_head=dict(in_channels=512, num_classes=150))

# AdamW optimizer, no weight decay for position embedding & layer norm
# in backbone
optim_wrapper = dict(
_delete_=True,
type='OptimWrapper',
optimizer=dict(
type='AdamW', lr=0.00006, betas=(0.9, 0.999), weight_decay=0.01),
paramwise_cfg=dict(
custom_keys={
'absolute_pos_embed': dict(decay_mult=0.),
'relative_position_bias_table': dict(decay_mult=0.),
'norm': dict(decay_mult=0.)
}))

param_scheduler = [
dict(
type='LinearLR', start_factor=1e-6, by_epoch=False, begin=0, end=1500),
dict(
type='PolyLR',
eta_min=0.0,
power=1.0,
begin=1500,
end=160000,
by_epoch=False,
)
]

# By default, models are trained on 8 GPUs with 2 images per GPU
train_dataloader = dict(batch_size=2)
val_dataloader = dict(batch_size=1)
test_dataloader = val_dataloader
3 changes: 2 additions & 1 deletion mmseg/models/backbones/__init__.py
Expand Up @@ -19,6 +19,7 @@
from .resnext import ResNeXt
from .stdc import STDCContextPathNet, STDCNet
from .swin import SwinTransformer
from .swinv2 import SwinTransformerV2
from .timm_backbone import TIMMBackbone
from .twins import PCPVT, SVT
from .unet import UNet
Expand All @@ -27,7 +28,7 @@
__all__ = [
'ResNet', 'ResNetV1c', 'ResNetV1d', 'ResNeXt', 'HRNet', 'FastSCNN',
'ResNeSt', 'MobileNetV2', 'UNet', 'CGNet', 'MobileNetV3',
'VisionTransformer', 'SwinTransformer', 'MixVisionTransformer',
'VisionTransformer', 'SwinTransformer', 'SwinTransformerV2', 'MixVisionTransformer',
'BiSeNetV1', 'BiSeNetV2', 'ICNet', 'TIMMBackbone', 'ERFNet', 'PCPVT',
'SVT', 'STDCNet', 'STDCContextPathNet', 'BEiT', 'MAE', 'PIDNet', 'MSCAN',
'DDRNet'
Expand Down