PViG for object detection #252

Open · yuan0038 opened this issue Apr 14, 2024 · 2 comments
yuan0038 commented Apr 14, 2024

Hi @iamhankai, I'd like to ask about the COCO object detection configuration for PViG.
My environment and setup:

1. Library: mmdet; hardware: 4×A6000 (48 GB)
2. Detection framework: the officially provided Mask R-CNN, with only the backbone replaced by PViG_S
3. Training schedule: 1×
4. 2 images per GPU (total batch size 8), img scale (1333, 800)

Training then fails with:

CUDA error: an illegal memory access was encountered

From what I found online, the likely cause is insufficient GPU memory, so I tried:

  1. Halving the img scale resolution, keeping PViG_S as the backbone (45.8M parameters, matching the paper, so the model should be assembled correctly)
  2. Keeping img scale at (1333, 800) and swapping in a different backbone (also 45.8M parameters)
  3. Keeping img scale at (1333, 800) and switching to PViG_Ti (29.3M parameters)

Setups 1 and 2 train normally; setup 3 fails with the following traceback:
    x = self.grapher(x)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 189, in forward
    x = self.graph_conv(x, relative_pos)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 138, in forward
    x = super(DyGraphConv2d, self).forward(x, edge_index, y)
  File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 111, in forward
    return self.gconv(x, edge_index, y)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guest/workplace/lzy/GraphConvNet/detection/models/gcn_lib/torch_vertex.py", line 34, in forward
    return self.nn(x)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/container.py", line 139, in forward
    input = module(input)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/module.py", line 1130, in _call_impl
    return forward_call(*input, **kwargs)
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/batchnorm.py", line 757, in forward
    world_size,
  File "/home/guest/miniconda3/envs/lzy_mmdet/lib/python3.7/site-packages/torch/nn/modules/_functions.py", line 80, in forward
    count_all = count_all[mask]
RuntimeError: CUDA error: an illegal memory access was encountered
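Side note: CUDA reports errors asynchronously, so the frame shown above may not be the real fault site. A standard way to localize it, sketched here as general practice rather than anything repo-specific, is to force synchronous kernel launches before any CUDA work:

import os
# Must be set before the first CUDA call: kernels then launch synchronously,
# so the Python traceback points at the op that actually faulted.
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"

import torch  # initialize torch/CUDA only after the flag is set

Equivalently, launch training with CUDA_LAUNCH_BLOCKING=1 set in the shell environment.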

⭐️ So my question is: under the standard recipe (my total batch is smaller, but, as in the standard recipe, each GPU still averages 2 images), why can't an A6000 run PViG_S? Any pointers would be much appreciated 🌹

@iamhankai (Member)

Is the dilation parameter set incorrectly, i.e. too large? If so, it needs to be reduced.
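For context, this matches the cap the backbone code below applies (max_dilation = 49 // max(num_knn)): dilated k-NN keeps every dilation-th of the k*dilation nearest nodes, so a stage must contain at least k*dilation nodes. A minimal sketch of the constraint (simplified; names and shapes are illustrative, not the repo's exact implementation):

import torch

def dilated_knn_indices(x, k=9, dilation=3):
    # x: (B, N, C) node features, one node per spatial position (N = H*W).
    # Dilated k-NN takes the k*dilation nearest nodes, then keeps every
    # dilation-th of them, so the graph needs at least k*dilation nodes.
    B, N, C = x.shape
    assert k * dilation <= N, f"k*dilation = {k * dilation} > {N} nodes"
    dist = torch.cdist(x, x)                       # (B, N, N) pairwise distances
    _, idx = dist.topk(k * dilation, dim=-1, largest=False)
    return idx[:, :, ::dilation]                   # (B, N, k) neighbor indices

feats = torch.randn(2, 49, 80)                     # 49 nodes = the 7x7 last stage
ok = dilated_knn_indices(feats, k=9, dilation=5)   # 45 <= 49: fine
# dilated_knn_indices(feats, k=9, dilation=6)      # 54 > 49: assertion fires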

@yuan0038 (Author)

For the detection backbone (e.g. pvig_s), the printed dilation is [1, 1, 1, 1, 2, 2, 2, 2, 3, 3, 3, 3], so the model itself should be fine.
I even tried setting every dilation to 1 and the same error still appears, which is baffling. 😂 My backbone code is below:

# Imports assumed by this snippet (mmcv/mmdet paths may differ by version;
# Stem, Downsample, Block and _cfg come from the repo's pyramid ViG model files):
import torch
from torch import nn
from torch.nn import Sequential as Seq
from mmcv.runner import _load_checkpoint
from mmdet.utils import get_root_logger


class Pyramid_ViG(torch.nn.Module):
    def __init__(self, k, gconv, channels, blocks, n_classes, act, norm, bias,
                 epsilon, use_stochastic, dropout, drop_path,
                 pretrained=None, out_indices=None):
        super().__init__()

        self.pretrained = pretrained
        self.out_indices = out_indices

        self.n_blocks = sum(blocks)
        reduce_ratios = [4, 2, 1, 1]
        dpr = [x.item() for x in torch.linspace(0, drop_path, self.n_blocks)]  # stochastic depth decay rule
        num_knn = [int(x.item()) for x in torch.linspace(k, k, self.n_blocks)]  # number of knn's k
        print(num_knn)
        max_dilation = 49 // max(num_knn)
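        # 49 = 7*7: the node count of the deepest stage for a 224x224 input
        # (224/32 = 7), so k*dilation is capped by the smallest graph.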

        self.stem = Stem(out_dim=channels[0], act=act)
        self.pos_embed = nn.Parameter(torch.zeros(1, channels[0], 224 // 4, 224 // 4))
        HW = 224 // 4 * 224 // 4
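        # pos_embed and HW assume a 224x224 input; interpolate_pos_encoding()
        # rescales pos_embed at forward time, and HW is passed to each Block
        # below as n for its relative_pos table.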

        self.backbone = nn.ModuleList([])


        #dilation=[min(idx // 4 + 1, max_dilation) for idx in range(sum(blocks))]
        dilation = [1 for i in range(sum(blocks))]
        idx = 0
        for i in range(len(blocks)):
            if i > 0:
                self.backbone.append(Downsample(channels[i - 1], channels[i]))
                HW = HW // 4
            for j in range(blocks[i]):
                self.backbone += [
                    Seq(Block(channels[i], num_knn[idx], dilation[idx], gconv,
                              act, norm, bias, use_stochastic, epsilon,
                              reduce_ratios[i], n=HW, drop_path=dpr[idx],
                              relative_pos=True))
                ]

                idx += 1
        self.backbone = Seq(*self.backbone)
        print("\u2b50 dilation:",dilation)
        self.init_weights()
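        # Convert every BatchNorm submodule to SyncBatchNorm (children are
        # replaced in place); SyncBatchNorm shares batch statistics across
        # GPUs and requires an initialized process group at training time.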
        self = torch.nn.SyncBatchNorm.convert_sync_batchnorm(self)

    @torch.no_grad()
    def train(self, mode=True):
        super().train(mode)
        for m in self.modules():
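            # NOTE: after convert_sync_batchnorm, BN layers are SyncBatchNorm,
            # which does not subclass nn.BatchNorm2d, so this freeze no longer
            # matches them and they keep updating stats in training mode.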
            if isinstance(m, nn.BatchNorm2d):
                m.eval()
    def init_weights(self):
        logger = get_root_logger()
        print("Pretrained weights being loaded")
        logger.warning('Pretrained weights being loaded')
        ckpt_path = self.pretrained
        ckpt = _load_checkpoint(
            ckpt_path, logger=logger, map_location='cpu')
        print("ckpt keys: ", ckpt.keys())
        if 'state_dict' in ckpt:
            _state_dict = ckpt['state_dict']
        elif 'model' in ckpt:
            _state_dict = ckpt['model']
        else:
            _state_dict = ckpt
        state_dict = _state_dict
        new_state_dict={}
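        # Strip the '.grapher' segment from checkpoint keys so the
        # classification weights line up with this backbone's module names.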
        for k,v in state_dict.items():
            new_k  = k.replace(".grapher",'')
            new_state_dict[new_k]=v
        print(new_state_dict.keys())
        missing_keys, unexpected_keys = \
            self.load_state_dict(new_state_dict, False)
        print("missing_keys: ", missing_keys)
        print("unexpected_keys: ", unexpected_keys)

    def interpolate_pos_encoding(self, x):
        w, h = x.shape[2], x.shape[3]
        p_w, p_h = self.pos_embed.shape[2], self.pos_embed.shape[3]

        if w * h == p_w * p_h and w == h:
            return self.pos_embed

        w0 = w
        h0 = h
        # we add a small number to avoid floating point error in the interpolation
        # see discussion at https://github.com/facebookresearch/dino/issues/8
        w0, h0 = w0 + 0.1, h0 + 0.1
        patch_pos_embed = nn.functional.interpolate(
            self.pos_embed,
            scale_factor=(w0 / p_w, h0 / p_h),
            mode='bicubic',
        )
        assert int(w0) == patch_pos_embed.shape[-2] and int(h0) == patch_pos_embed.shape[-1]
        return patch_pos_embed

    def forward(self, inputs):
        outs=[]
        B, C, H, W = inputs.shape

        x = self.stem(inputs)

        x = x + self.interpolate_pos_encoding(x)

        for i in range(len(self.backbone)):

            x = self.backbone[i](x)
            if i in self.out_indices:
                outs.append(x)

        return outs
def pvig_s_feat(pretrained=True, **kwargs):
    model = Pyramid_ViG(
        k=9,  # neighbor num (default:9)
        gconv='mr',  # graph conv layer {edge, mr}
        channels=[80, 160, 400, 640],  # number of channels of deep features
        blocks=[2, 2, 6, 2],  # number of basic blocks in the backbone
        n_classes=1000,  # Dimension of out_channels
        act='gelu',  # activation layer {relu, prelu, leakyrelu, gelu, hswish}
        norm='batch',  # batch or instance normalization {batch, instance}
        bias=True,  # bias of conv layer True or False
        epsilon=0.2,  # stochastic epsilon for gcn
        use_stochastic=False,  # stochastic for gcn, True or False
        dropout=0.0,  # dropout rate
        drop_path=0.0,
        pretrained='../ckpt/pvig_s_82.1.pth.tar',
        out_indices=[1, 4, 11, 14])

    model.default_cfg = _cfg()
    return model
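For reference, a backbone like this is typically exposed to mmdet through its model registry; a minimal sketch of that wiring (the wrapper class name and import path are illustrative for mmdet 2.x, not taken from this thread):

from mmdet.models.builder import BACKBONES

@BACKBONES.register_module()
class PyramidViGDet(Pyramid_ViG):
    """Registered so a config can reference the backbone by type name."""
    pass

# In the Mask R-CNN config:
# model = dict(backbone=dict(type='PyramidViGDet', k=9, gconv='mr', ...))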
