Running error on model Swin-ReppointV2 : IndexError: The shape of the mask [3070240] at index 0 does not match the shape of the indexed tensor [383780] at index 0 #195

Open
HUANGL1NJIE opened this issue Oct 23, 2022 · 2 comments

Comments

@HUANGL1NJIE

Prerequisite

  1. I have searched related issues but cannot get the expected help.
  2. I have read the FAQ documentation but cannot get the expected help.
  3. The bug has not been fixed in the latest version.

Describe the bug

When I run the command python tools/train.py configs/swin/reppoitsv2_swin_tiny_patch4_window7_mstrain_480_960_giou_gfocal_bifpn_adamw_3x_coco.py, the following error traceback appears:

Traceback (most recent call last):
  File "tools/train.py", line 194, in <module>
    main()
  File "tools/train.py", line 183, in main
    train_detector(
  File "d:\swin-transformer-object-detection\mmdet\apis\train.py", line 185, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 127, in run
    epoch_runner(data_loaders[i], **kwargs)
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 50, in train
    self.run_iter(data_batch, train_mode=True, **kwargs)
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\mmcv\runner\epoch_based_runner.py", line 29, in run_iter
    outputs = self.model.train_step(data_batch, self.optimizer,
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\mmcv\parallel\data_parallel.py", line 75, in train_step
    return self.module.train_step(*inputs[0], **kwargs[0])
  File "d:\swin-transformer-object-detection\mmdet\models\detectors\base.py", line 247, in train_step
    losses = self(**data)
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\mmcv\runner\fp16_utils.py", line 128, in new_func
    output = old_func(*new_args, **new_kwargs)
  File "d:\swin-transformer-object-detection\mmdet\models\detectors\base.py", line 181, in forward
    return self.forward_train(img, img_metas, **kwargs)
  File "d:\swin-transformer-object-detection\mmdet\models\detectors\reppoints_v2_detector.py", line 34, in forward_train
    losses = self.bbox_head.loss(
  File "d:\swin-transformer-object-detection\mmdet\models\dense_heads\reppoints_v2_head.py", line 1096, in loss
    loss_sem = self.loss_sem(concat_sem_scores, concat_gt_sem_map, concat_gt_sem_weights, avg_factor=(concat_gt_sem_map > 0).sum())
  File "C:\Users\hlj\.conda\envs\torch_110\lib\site-packages\torch\nn\modules\module.py", line 889, in _call_impl
    result = self.forward(*input, **kwargs)
  File "d:\swin-transformer-object-detection\mmdet\models\losses\focal_loss.py", line 237, in forward
    loss_cls = self.loss_weight * separate_sigmoid_focal_loss(
  File "d:\swin-transformer-object-detection\mmdet\models\losses\focal_loss.py", line 75, in separate_sigmoid_focal_loss
    pos_pred = pred_sigmoid[pos_inds]
IndexError: The shape of the mask [3070240] at index 0 does not match the shape of the indexed tensor [383780] at index 0

The only code modifications I made were changing 'num_classes' in the model and 'CLASSES' in the script 'coco.py'.
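The two sizes in the error message are worth comparing, since only class-related settings were changed. The sketch below is plain arithmetic on the numbers from the traceback. The reading that the mask comes from the ground-truth map and the smaller tensor from the predictions follows the traceback (pred_sigmoid[pos_inds]). The idea that the annotation loader still assumed COCO's 80 classes is an assumption, not something stated in this thread.

```python
# Element counts copied from the IndexError message in the traceback.
mask_numel = 3070240   # boolean mask (pos_inds), derived from the ground-truth map
pred_numel = 383780    # flattened prediction tensor (pred_sigmoid) being indexed

# If both tensors cover the same spatial locations and differ only in the
# number of class channels, the sizes should differ by an integer factor.
ratio, remainder = divmod(mask_numel, pred_numel)
print(ratio, remainder)

# ASSUMPTION (not confirmed in the thread): the annotation loader still built
# its map for COCO's 80 classes while the model head used the custom count.
assumed_loader_classes = 80
implied_model_classes = None
if remainder == 0 and assumed_loader_classes % ratio == 0:
    implied_model_classes = assumed_loader_classes // ratio
print(implied_model_classes)
```

An exact integer ratio is at least consistent with a class-count mismatch between the data pipeline and the head, rather than with a spatial-size problem.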

However, when I run train.py to train the Mask R-CNN model with the command below, it works without this error.
python tools/train.py configs/swin/mask_rcnn_swin_tiny_patch4_window7_mstrain_480-800_adamw_1x_coco.py

Environment

sys.platform: win32
Python: 3.8.13 (default, Mar 28 2022, 06:59:08) [MSC v.1916 64 bit (AMD64)]
CUDA available: True
GPU 0: NVIDIA GeForce RTX 3090
CUDA_HOME: C:\Program Files\NVIDIA GPU Computing Toolkit\CUDA\v11.1
NVCC: Build cuda_11.1.relgpu_drvr455TC455_06.29190527_0
GCC: n/a
PyTorch: 1.8.1+cu111
PyTorch compiling details: PyTorch built with:

  • C++ Version: 199711
  • MSVC 192829913
  • Intel(R) Math Kernel Library Version 2020.0.2 Product Build 20200624 for Intel(R) 64 architecture applications
  • Intel(R) MKL-DNN v1.7.0 (Git Hash 7aed236906b1f7a05c0917e5257a1af05e9ff683)
  • OpenMP 2019
  • CPU capability usage: AVX2
  • CUDA Runtime 11.1
  • NVCC architecture flags: -gencode;arch=compute_37,code=sm_37;-gencode;arch=compute_50,code=sm_50;-gencode;arch=compute_60,code=sm_60;-gencode;arch=compute_61,code=sm_61;-gencode;arch=compute_70,code=sm_70;-gencode;arch=compute_75,code=sm_75;-gencode;arch=compute_80,code=sm_80;-gencode;arch=compute_86,code=sm_86;-gencode;arch=compute_37,code=compute_37
  • CuDNN 8.0.5
  • Magma 2.5.4
  • Build settings: BLAS_INFO=mkl, BUILD_TYPE=Release, CUDA_VERSION=11.1, CUDNN_VERSION=8.0.5, CXX_COMPILER=C:/w/b/windows/tmp_bin/sccache-cl.exe, CXX_FLAGS=/DWIN32 /D_WINDOWS /GR /EHsc /w /bigobj -DUSE_PTHREADPOOL -openmp:experimental -DNDEBUG -DUSE_FBGEMM -DUSE_XNNPACK, LAPACK_INFO=mkl, PERF_WITH_AVX=1, PERF_WITH_AVX2=1, PERF_WITH_AVX512=1, TORCH_VERSION=1.8.1, USE_CUDA=ON, USE_CUDNN=ON, USE_EXCEPTION_PTR=1, USE_GFLAGS=OFF, USE_GLOG=OFF, USE_MKL=ON, USE_MKLDNN=ON, USE_MPI=OFF, USE_NCCL=OFF, USE_NNPACK=OFF, USE_OPENMP=ON,

TorchVision: 0.9.1+cu111
OpenCV: 4.6.0
MMCV: 1.4.1
MMCV Compiler: MSVC 192930137
MMCV CUDA Compiler: 11.1
MMDetection: 2.11.0+c7b2011

Looking forward to your reply. Thanks!

@HUANGL1NJIE changed the title from "IndexError: The shape of the mask [3070240] at index 0 does not match the shape of the indexed tensor [383780] at index 0" to "Running error on model Swin-ReppointV2 : IndexError: The shape of the mask [3070240] at index 0 does not match the shape of the indexed tensor [383780] at index 0" on Oct 23, 2022
@geaned

geaned commented Oct 31, 2022

Hi, I found a fix for this problem more or less by accident. You can change the code in mmdet/datasets/pipelines/loading.py similarly to this commit, and then specify the actual number of classes by adding the num_classes parameter to dict(type='LoadRPDV2Annotations') in your configs.
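As a rough illustration of the config side of this fix, the sketch below keeps the class count in one place and passes it to both the head and the loading step. It assumes loading.py has already been patched so that LoadRPDV2Annotations accepts num_classes. NUM_CLASSES = 3 is a placeholder, the bbox_head layout is the usual MMDetection convention rather than something shown in this thread, and check_num_classes is a hypothetical helper, not part of MMDetection.

```python
# Placeholder value: use the length of your own CLASSES tuple here.
NUM_CLASSES = 3

# Only the key relevant to this issue is shown; keep the rest of the
# head settings from the original config.
model = dict(
    bbox_head=dict(num_classes=NUM_CLASSES),
)

train_pipeline = [
    # ... keep the other pipeline steps from the original config here ...
    # num_classes is only accepted after loading.py has been patched.
    dict(type='LoadRPDV2Annotations', num_classes=NUM_CLASSES),
]


def check_num_classes(model_cfg, pipeline):
    """Hypothetical helper: return LoadRPDV2Annotations steps whose
    num_classes is missing or differs from the head's num_classes."""
    expected = model_cfg['bbox_head']['num_classes']
    return [
        step for step in pipeline
        if step.get('type') == 'LoadRPDV2Annotations'
        and step.get('num_classes') != expected
    ]


mismatches = check_num_classes(model, train_pipeline)
print(mismatches)  # an empty list means the two settings agree
```

Defining the count once avoids the situation in this issue, where the model and the data pipeline silently disagree about the number of classes.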

@HUANGL1NJIE
Author

Thanks a lot! The code finally runs successfully after following the solution you provided.
