Use ConvTranspose2d instead of Upsample #100

Open

dengxiongshi opened this issue Apr 30, 2024 · 0 comments

Comments

@dengxiongshi
@Gumpest Hi, does yolov5s-pruning.yaml support replacing nn.Upsample with nn.ConvTranspose2d for pruning? After making the replacement and training a model following the provided documentation, running pruneEagleEye.py fails with the following error:

File "/root/.pycharm_helpers/pydev/pydevd.py", line 1491, in _exec
    pydev_imports.execfile(file, globals, locals)  # execute the script
  File "/root/.pycharm_helpers/pydev/_pydev_imps/_pydev_execfile.py", line 18, in execfile
    exec(compile(contents+"\n", file, 'exec'), glob, loc)
  File "/data/yolov5/pruneEagleEye.py", line 153, in <module>
    rand_prune_and_eval(model, ignore_idx, opt)
  File "/data/yolov5/pruneEagleEye.py", line 65, in rand_prune_and_eval
    compact_model = Model(pruned_yaml, pruning=False).to(device)
  File "/data/yolov5/models/yolo_prune.py", line 325, in __init__
    m.stride = torch.tensor([s / x.shape[-2] for x in forward(torch.zeros(1, ch, s, s))])  # forward
  File "/data/yolov5/models/yolo_prune.py", line 324, in <lambda>
    forward = lambda x: self.forward(x)[0] if isinstance(m, Segment) else self.forward(x)
  File "/data/yolov5/models/yolo_prune.py", line 340, in forward
    return self._forward_once(x, profile, visualize)  # single-scale inference, train
  File "/data/yolov5/models/yolo_prune.py", line 250, in _forward_once
    x = m(x)  # run
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/container.py", line 139, in forward
python-BaseException
    input = module(input)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/yolov5/models/common.py", line 805, in forward
    return self.cv3(torch.cat((self.m(self.cv1(x)), self.cv2(x)), 1))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/data/yolov5/models/common.py", line 90, in forward
    return self.act(self.bn(self.conv(x)))
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1051, in _call_impl
    return forward_call(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 443, in forward
    return self._conv_forward(input, self.weight, self.bias)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/conv.py", line 439, in _conv_forward
    return F.conv2d(input, weight, bias, self.stride,
RuntimeError: Given groups=1, weight of size [64, 56, 1, 1], expected input[1, 128, 32, 32] to have 56 channels, but got 128 channels instead
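
For context, a minimal, self-contained sketch of the swap being described (the 256-channel value is taken from layer 11 of the structure below, not from the repository's code). Both modules double the spatial resolution, but nn.ConvTranspose2d carries learnable weights with fixed in/out channel counts, so a pruning search has to update those counts together with the surrounding Conv layers, whereas nn.Upsample is channel-agnostic:

import torch
import torch.nn as nn

# Both modules double the spatial size of a 256-channel feature map.
x = torch.zeros(1, 256, 20, 20)

up = nn.Upsample(scale_factor=2, mode='nearest')  # no weights, channel-agnostic
deconv = nn.ConvTranspose2d(256, 256, 2, 2, 0)    # learnable, fixed in/out channels

print(up(x).shape)      # torch.Size([1, 256, 40, 40])
print(deconv(x).shape)  # torch.Size([1, 256, 40, 40])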

This is the structure used for training:

                 from  n    params  module                                  arguments                     
  0                -1  1      3520  models.common.Conv                      [3, 32, 6, 2, 2, 1, True]     
  1                -1  1     18560  models.common.Conv                      [32, 64, 3, 2, None, 1, True] 
  2                -1  1     18816  models.common.C3_prune                  [64, 64, 64, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
  3                -1  1     73984  models.common.Conv                      [64, 128, 3, 2, None, 1, True]
  4                -1  2    231424  models.common.C3_prune                  [128, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  5                -1  1    295424  models.common.Conv                      [128, 256, 3, 2, None, 1, True]
  6                -1  3   1875456  models.common.C3_prune                  [256, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  7                -1  1   1180672  models.common.Conv                      [256, 512, 3, 2, None, 1, True]
  8                -1  1   1182720  models.common.C3_prune                  [512, 512, 512, 1, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
  9                -1  1    656896  models.common.SPPF_prune                [512, 512, 5, 0.5]            
 10                -1  1    131584  models.common.Conv                      [512, 256, 1, 1, None, 1, True]
 11                -1  1    262400  torch.nn.modules.conv.ConvTranspose2d   [256, 256, 2, 2, 0]           
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    361984  models.common.C3_prune                  [512, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 14                -1  1     33024  models.common.Conv                      [256, 128, 1, 1, None, 1, True]
 15                -1  1     65664  torch.nn.modules.conv.ConvTranspose2d   [128, 128, 2, 2, 0]           
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     90880  models.common.C3_prune                  [256, 128, 128, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 18                -1  1    147712  models.common.Conv                      [128, 128, 3, 2, None, 1, True]
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    296448  models.common.C3_prune                  [256, 256, 256, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 21                -1  1    590336  models.common.Conv                      [256, 256, 3, 2, None, 1, True]
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1   1182720  models.common.C3_prune                  [512, 512, 512, 1, False, 1, [0.5, 0.5], [1.0, 1.0, 1.0]]
 24      [17, 20, 23]  1    229245  models.yolo_prune.Detect                [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [128, 256, 512]]

Below is the sub-network structure found during the search for the optimal sub-net:

                 from  n    params  module                                  arguments                     
  0                -1  1      2640  models.common.Conv                      [3, 24, 6, 2, 2, 1, True]     
  1                -1  1      6976  models.common.Conv                      [24, 32, 3, 2, None, 1, True] 
  2                -1  1      9440  models.common.C3_prune                  [32, 40, 64, 1, True, 1, [0.5, 0.375], [0.5, 1.0, 1.0]]
  3                -1  1     20272  models.common.Conv                      [40, 56, 3, 2, None, 1, True] 
  4                -1  2    212992  models.common.C3_prune                  [56, 128, 128, 2, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  5                -1  1    230800  models.common.Conv                      [128, 200, 3, 2, None, 1, True]
  6                -1  3   1832448  models.common.C3_prune                  [200, 256, 256, 3, True, 1, [0.5, 0.5], [1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0, 1.0]]
  7                -1  1    571888  models.common.Conv                      [256, 248, 3, 2, None, 1, True]
  8                -1  1    488832  models.common.C3_prune                  [248, 392, 512, 1, True, 1, [0.5, 0.484375], [0.25, 1.0, 1.0]]
  9                -1  1    164638  models.common.SPPF_prune                [392, 512, 5, 0.171875]       
 10                -1  1     94576  models.common.Conv                      [512, 184, 1, 1, None, 1, True]
 11                -1  1    188672  torch.nn.modules.conv.ConvTranspose2d   [184, 256, 2, 2, 0]           
 12           [-1, 6]  1         0  models.common.Concat                    [1]                           
 13                -1  1    298048  models.common.C3_prune                  [512, 104, 256, 1, False, 1, [0.5, 0.34375], [1.0, 1.0, 1.0]]
 14                -1  1      9328  models.common.Conv                      [104, 88, 1, 1, None, 1, True]
 15                -1  1     45184  torch.nn.modules.conv.ConvTranspose2d   [88, 128, 2, 2, 0]            
 16           [-1, 4]  1         0  models.common.Concat                    [1]                           
 17                -1  1     72240  models.common.C3_prune                  [256, 88, 128, 1, False, 1, [0.5, 0.3125], [0.875, 1.0, 1.0]]
 18                -1  1     50816  models.common.Conv                      [88, 64, 3, 2, None, 1, True] 
 19          [-1, 14]  1         0  models.common.Concat                    [1]                           
 20                -1  1    111984  models.common.C3_prune                  [152, 96, 256, 1, False, 1, [0.5, 0.28125], [0.375, 1.0, 1.0]]
 21                -1  1    214768  models.common.Conv                      [96, 248, 3, 2, None, 1, True]
 22          [-1, 10]  1         0  models.common.Concat                    [1]                           
 23                -1  1    610000  models.common.C3_prune                  [432, 304, 512, 1, False, 1, [0.5, 0.40625], [0.40625, 1.0, 1.0]]
 24      [17, 20, 23]  1    125205  models.yolo_prune.Detect                [80, [[10, 13, 16, 30, 33, 23], [30, 61, 62, 45, 59, 119], [116, 90, 156, 198, 373, 326]], [88, 96, 304]]
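
For reference, the failure can be reproduced in isolation with the channel counts reported in the traceback (a hypothetical stand-alone snippet, not the repository's code): the pruned 1x1 conv built from the generated yaml expects 56 input channels, but the feature map it actually receives has 128, which points to the channel counts recorded in the pruned yaml diverging from the channels the preceding layers actually produce once ConvTranspose2d is used.

import torch
import torch.nn as nn

# Stand-alone reproduction of the error above (channel numbers copied from the
# traceback, not taken from the repository's code).
x = torch.zeros(1, 128, 32, 32)                # what the layer actually receives
cv1_pruned = nn.Conv2d(56, 64, kernel_size=1)  # what the pruned yaml builds

try:
    cv1_pruned(x)
except RuntimeError as e:
    print(e)  # ... expected input[1, 128, 32, 32] to have 56 channels, but got 128 channels instead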