
RuntimeError torch.cat(): expected a non-empty list of Tensors #25

Open
advancing-panda opened this issue Mar 12, 2024 · 2 comments

@advancing-panda

advancing-panda commented Mar 12, 2024

When I train with the command python tools/train.py configs/tr3d/tr3d_s3dis-3d-5class.py, the following error is raised at line 189 of tr3d_head.py:

RuntimeError
torch.cat(): expected a non-empty list of Tensors
  File "/mmdetection3d/mmdet3d/models/dense_heads/tr3d_head.py", line 189, in _loss
    bbox_loss=torch.mean(torch.cat(bbox_losses)),
  File "/mmdetection3d/mmdet3d/models/dense_heads/tr3d_head.py", line 195, in forward_train
    gt_bboxes, gt_labels, img_metas)
  File "/mmdetection3d/mmdet3d/models/detectors/mink_single_stage.py", line 88, in forward_train
    img_metas)
  File "/mmdetection3d/mmdet3d/models/detectors/base.py", line 60, in forward
    return self.forward_train(**kwargs)
  File "/mmdetection3d/mmdet3d/apis/train.py", line 319, in train_detector
    runner.run(data_loaders, cfg.workflow)
  File "/mmdetection3d/mmdet3d/apis/train.py", line 351, in train_model
    meta=meta)
  File "/mmdetection3d/tools/train.py", line 259, in main
    meta=meta)
  File "/mmdetection3d/tools/train.py", line 263, in <module>
    main()
RuntimeError: torch.cat(): expected a non-empty list of Tensors

After debugging, I found that bbox_losses and gt_bboxes were both empty, which causes this error. How should this be fixed?
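The failure can be reproduced outside mmdetection3d. This sketch assumes, as the thread suggests, that the head collects a bbox-loss tensor for a scene only when that scene has ground-truth boxes. collect_bbox_losses and try_mean_cat are hypothetical helpers, not TR3D code, and the per-box losses are placeholders.

```python
import torch


def collect_bbox_losses(gt_counts):
    """Hypothetical stand-in for the per-scene loop in the head's loss:
    a bbox loss tensor is appended only for scenes that have ground-truth boxes."""
    bbox_losses = []
    for n_boxes in gt_counts:
        if n_boxes > 0:
            # Placeholder per-box losses; the real head computes these from predictions.
            bbox_losses.append(torch.ones(n_boxes))
    return bbox_losses


def try_mean_cat(bbox_losses):
    """Return the mean loss, or the RuntimeError message if torch.cat() fails."""
    try:
        return torch.mean(torch.cat(bbox_losses))
    except RuntimeError as err:
        return str(err)


# A batch of 16 scenes where at least one scene has objects works...
print(try_mean_cat(collect_bbox_losses([0] * 15 + [3])))
# ...but a batch of 16 empty scenes hits the same torch.cat() error as the traceback.
print(try_mean_cat(collect_bbox_losses([0] * 16)))
```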

@filaPro
Contributor

filaPro commented Mar 12, 2024

Hi @advancing-panda,
Did you make any modifications to the code?

It looks like this error can happen, with very low probability, if none of the 16 scenes in a batch contains any ground-truth objects. In that case you can check whether len(bbox_losses) == 0 and, if so, just return a zero tensor here.
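A minimal sketch of that guard, assuming it replaces the torch.mean(torch.cat(bbox_losses)) expression in _loss. mean_bbox_loss is a hypothetical helper, and reference stands for any tensor from the same forward pass (for example a bbox prediction tensor); it is an assumption, used only so the zero loss lives on the right device and stays connected to the autograd graph.

```python
import torch


def mean_bbox_loss(bbox_losses, reference):
    """Average the concatenated bbox losses, or return a zero loss when the
    whole batch had no ground-truth boxes.

    `reference` (hypothetical parameter) is any tensor from the current
    forward pass; it only supplies device/dtype and an autograd connection.
    """
    if len(bbox_losses) == 0:
        # torch.cat([]) would raise, so return a scalar zero that still
        # belongs to the autograd graph (its gradient is all zeros).
        return reference.sum() * 0.0
    return torch.mean(torch.cat(bbox_losses))
```

Multiplying a real tensor by zero, instead of creating torch.zeros(()), keeps the result on the correct device and attached to the graph, which may help avoid unused-parameter errors in distributed training. For single-GPU training, torch.zeros((), device=reference.device) would likely work just as well.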

@advancing-panda
Author

advancing-panda commented Mar 15, 2024

Thank you for your reply! Following your suggestion, I changed that line to the following, and training now runs normally:

bbox_loss=torch.mean(torch.cat(bbox_losses)) if bbox_losses != [] else bbox_losses
