
Has anyone successfully replaced the backbone of Mask2Former with DINOv2 for both training and inference? #413

Closed
Xinge-Zhou opened this issue May 4, 2024 · 3 comments


@Xinge-Zhou

I've ported the code from the segmentation_m2f directory into the Mask2Former repository and initialized Mask2Former's backbone with ViTAdapter. Instance segmentation training runs successfully, but the inference results are poor: the model does not segment correctly.

config:

MODEL:
  BACKBONE:
    NAME: "Dinov2"
  DINOV2:
    # ViT-S (21M params): Patch size 14, embedding dimension 384, 6 heads, MLP FFN
    # ViT-B (86M params): Patch size 14, embedding dimension 768, 12 heads, MLP FFN
    # ViT-L (0.3B params): Patch size 14, embedding dimension 1024, 16 heads, MLP FFN
    # ViT-g (1.1B params): Patch size 14, embedding dimension 1536, 24 heads, SwiGLU FFN
    EMBED_DIM: 1536 
    DEPTH: 12 
    NUM_HEADS: 6 
    IMG_SIZE: 512 
    FFN_TYPE: 'swiglu' 
    PATCH_SIZE: 14 
  WEIGHTS: "dinov2_vitg14_pretrain.pth"
  PIXEL_MEAN: [123.675, 116.280, 103.530]
  PIXEL_STD: [58.395, 57.120, 57.375]
SOLVER:
  IMS_PER_BATCH: 1
  BASE_LR: 0.0001
  MAX_ITER: 20000
TEST:
  EVAL_PERIOD: 5000
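
The MODEL.DINOV2 keys above are not part of detectron2's default config, so they have to be registered on the CfgNode before the YAML is merged. Below is a minimal sketch of such a helper; the name add_dinov2_config is hypothetical, the defaults just mirror the values in this issue, and OUT_FEATURES is included because the backbone code below reads cfg.MODEL.DINOV2.OUT_FEATURES even though it is missing from the YAML:

from detectron2.config import CfgNode as CN

def add_dinov2_config(cfg):
    # Register the custom DINOV2 node so cfg.merge_from_file() accepts the YAML above.
    cfg.MODEL.DINOV2 = CN()
    cfg.MODEL.DINOV2.EMBED_DIM = 1536
    cfg.MODEL.DINOV2.DEPTH = 12
    cfg.MODEL.DINOV2.NUM_HEADS = 6
    cfg.MODEL.DINOV2.IMG_SIZE = 512
    cfg.MODEL.DINOV2.FFN_TYPE = "swiglu"
    cfg.MODEL.DINOV2.PATCH_SIZE = 14
    # Read by the backbone code below but absent from the YAML above.
    cfg.MODEL.DINOV2.OUT_FEATURES = ["res2", "res3", "res4", "res5"]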

backbone:

import torch
from .dinov2.models.backbones.vit_adapter import ViTAdapter
from detectron2.modeling import BACKBONE_REGISTRY, Backbone, ShapeSpec


@BACKBONE_REGISTRY.register()
class Dinov2(ViTAdapter, Backbone):
    def __init__(self, cfg, input_shape):
        embed_dim = cfg.MODEL.DINOV2.EMBED_DIM
        num_heads = cfg.MODEL.DINOV2.NUM_HEADS
        ffn_type = cfg.MODEL.DINOV2.FFN_TYPE
        img_size = cfg.MODEL.DINOV2.IMG_SIZE
        patch_size = cfg.MODEL.DINOV2.PATCH_SIZE
        depth = cfg.MODEL.DINOV2.DEPTH

        pretrained = cfg.MODEL.WEIGHTS

        super().__init__(pretrain_size=224,
                         num_heads=num_heads,
                         conv_inplane=64,
                         n_points=4,
                         deform_num_heads=num_heads,
                         init_values=0.0,
                         interaction_indexes=[[0, 5], [6, 11], [12, 17], [18, 23]],
                         with_cffn=True,
                         cffn_ratio=0.25,
                         deform_ratio=0.5,
                         add_vit_feature=True,
                         pretrained=pretrained,
                         use_extra_extractor=True,
                         freeze_vit=False,
                         use_cls=True,
                         with_cp=False,
                         # TIMMVisionTransformer init params
                         img_size=img_size,
                         patch_size=patch_size,
                         in_chans=3,
                         num_classes=200,
                         embed_dim=embed_dim,
                         depth=depth,
                         # num_heads=num_heads, already defined above
                         ffn_type=ffn_type,
                         drop_path_rate=0.1,
                         # deform_ratio=0.5, already defined above
                         )
        self._out_features = cfg.MODEL.DINOV2.OUT_FEATURES
        self._out_feature_strides = {
            "res2": 4,
            "res3": 8,
            "res4": 16,
            "res5": 32,
        }

        # Per-level channel counts, doubling the embedding dim at each stage (Swin-style).
        num_features = [int(embed_dim * 2 ** i) for i in range(len(self.interaction_indexes))]
        self._out_feature_channels = {
            "res2": num_features[0],
            "res3": num_features[1],
            "res4": num_features[2],
            "res5": num_features[3],
        }

    def forward(self, x):
        """
        Args:
            x: Tensor of shape (N,C,H,W). H, W must be a multiple of ``self.size_divisibility``.
        Returns:
            dict[str->Tensor]: names and the corresponding features
        """
        assert (
            x.dim() == 4
        ), f"Dinov2 takes an input of shape (N, C, H, W). Got {x.shape} instead!"
        outputs = {}
        y = super().forward(x)
        for idx, o in enumerate(y):
            outputs[self._out_features[idx]] = o
        return outputs

    def output_shape(self):
        ret_shape = {name: ShapeSpec(channels=self.embed_dim, stride=self._out_feature_strides[name])
                     for name in self._out_features}
        return ret_shape

    @property
    def size_divisibility(self):
        return 32
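
Before wiring this into Mask2Former, a quick shape smoke test helps confirm the registry wiring and the output strides. This is only a sketch: add_dinov2_config is the hypothetical helper sketched above, the YAML filename is a placeholder, and the input size (896 = 64 × 14 = 28 × 32) is chosen to be divisible by both PATCH_SIZE and size_divisibility.

import torch
from detectron2.config import get_cfg
from detectron2.modeling import build_backbone

cfg = get_cfg()
add_dinov2_config(cfg)                          # hypothetical helper registering MODEL.DINOV2.*
cfg.merge_from_file("mask2former_dinov2.yaml")  # placeholder path to the YAML config above

backbone = build_backbone(cfg)  # resolves MODEL.BACKBONE.NAME = "Dinov2" via BACKBONE_REGISTRY
x = torch.randn(1, 3, 896, 896)  # divisible by PATCH_SIZE (14) and size_divisibility (32)
out = backbone(x)
for name, feat in out.items():
    spec = backbone.output_shape()[name]
    print(name, tuple(feat.shape), "expected stride", spec.stride)
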
@Xinge-Zhou
Author

Does anyone know how to use DINOv2 as the backbone for Mask2Former and train it on custom datasets? I'd really appreciate any suggestions.

@roboyul

roboyul commented May 9, 2024

Hey @Xinge-Zhou,

There's a similar implementation here: https://github.com/facebookresearch/GuidedDistillation

DINOv2 is used as the backbone for Mask2Former there. It's built around distillation, so you may need to omit certain portions.

@Xinge-Zhou
Author

This is exactly what I'm looking for; it's very helpful, and I really appreciate your assistance.

@qasfb closed this as completed May 13, 2024