
Training abruptly crashes on single GPU #318

Open
pranavsinghps1 opened this issue Aug 17, 2023 · 5 comments
Labels
bug Something isn't working

Comments

@pranavsinghps1

While training a VarNet with PyTorch Lightning on the knee dataset, using the FastMriDataModule data loaders, I observed that training is unstable and crashes fairly often. I looked for similar issues in this repo but couldn't find any. I also checked the PyTorch forums and found that this kind of error can occur when the data loader doesn't play well with multiprocessing (pytorch/pytorch#8976); the recommended workaround of num_workers=0 stabilized training for a while, but eventually it crashed as well.
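
Roughly, the setup with that workaround looks like this (paths are placeholders, and the transform choice and constructor arguments follow my reading of the fastMRI FastMriDataModule and RandomMaskFunc APIs at the versions listed below):

# Sketch only: placeholder paths, illustrative mask/transform settings.
from pathlib import Path

from fastmri.data.subsample import RandomMaskFunc
from fastmri.data.transforms import VarNetDataTransform
from fastmri.pl_modules import FastMriDataModule

mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

data_module = FastMriDataModule(
    data_path=Path("/path/to/knee_singlecoil"),  # placeholder path
    challenge="singlecoil",
    train_transform=VarNetDataTransform(mask_func=mask_func, use_seed=False),
    val_transform=VarNetDataTransform(mask_func=mask_func),
    test_transform=VarNetDataTransform(),
    batch_size=8,
    num_workers=0,  # workaround from pytorch/pytorch#8976; with workers > 0 the crash appears sooner
)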

  • Training on a single GPU with:
backend = "gpu"
num_gpus = 1
batch_size = 8

using the FastMriDataModule on the single-coil knee dataset. Reproduced on a single V100 and on a single RTX8000 GPU.

lightning 1.8.6
torch     2.0.1
  • The full traceback is as follows:

File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/ext3/miniconda3/lib/python3.10/multiprocessing/queues.py", line 122, in get
return _ForkingPickler.loads(res)
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/multiprocessing/reductions.py", line 307, in rebuild_storage_fd
fd = df.detach()
File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 57, in detach
with _resource_sharer.get_connection(self._id) as conn:
File "/ext3/miniconda3/lib/python3.10/multiprocessing/resource_sharer.py", line 86, in get_connection
c = Client(address, authkey=process.current_process().authkey)
File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 508, in Client
answer_challenge(c, authkey)
File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 752, in answer_challenge
message = connection.recv_bytes(256) # reject large message
File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 216, in recv_bytes
buf = self._recv_bytes(maxlength)
File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 414, in _recv_bytes
buf = self._recv(4)
File "/ext3/miniconda3/lib/python3.10/multiprocessing/connection.py", line 379, in _recv
chunk = read(handle, remaining)
ConnectionResetError: [Errno 104] Connection reset by peer

During handling of the above exception, another exception occurred:

Traceback (most recent call last):
File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 192, in
run_cli()
File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 188, in run_cli
cli_main(args)
File "/scratch/ps4364/fmri2020/varnet_l1_2/unet_knee_sc.py", line 72, in cli_main
trainer.fit(model, datamodule=data_module)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
call._call_and_handle_interrupt(
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
return trainer_fn(*args, **kwargs)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
self._run(model, ckpt_path=self.ckpt_path)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
results = self._run_stage()
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
self._run_train()
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/trainer.py", line 1200, in _run_train
self.fit_loop.run()
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/fit_loop.py", line 267, in advance
self._outputs = self.epoch_loop.run(self._data_fetcher)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
self.advance(*args, **kwargs)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/loops/epoch/training_epoch_loop.py", line 188, in advance
batch = next(data_fetcher)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 184, in next
return self.fetching_function()
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 265, in fetching_function
self._fetch_next_batch(self.dataloader_iter)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/utilities/fetching.py", line 280, in _fetch_next_batch
batch = next(iterator)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 568, in next
return self.request_next_batch(self.loader_iters)
File "/ext3/miniconda3/lib/python3.10/site-packages/pytorch_lightning/trainer/supporters.py", line 580, in request_next_batch
return apply_to_collection(loader_iters, Iterator, next)
File "/ext3/miniconda3/lib/python3.10/site-packages/lightning_utilities/core/apply_func.py", line 51, in apply_to_collection
return function(data, *args, **kwargs)
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 633, in next
data = self._next_data()
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1328, in _next_data
idx, data = self._get_data()
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1294, in _get_data
success, data = self._try_get_data()
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/dataloader.py", line 1132, in _try_get_data
data = self._data_queue.get(timeout=timeout)
File "/ext3/miniconda3/lib/python3.10/site-packages/torch/utils/data/_utils/signal_handling.py", line 66, in handler
_error_if_any_worker_fails()
RuntimeError: DataLoader worker (pid 3489789) is killed by signal: Killed.

pranavsinghps1 added the bug label on Aug 17, 2023
@mmuckley
Contributor

Hello @pranavsinghps1, this is a confusing error. I don't see a single line in the trace that mentions fastMRI. Are you sure there isn't an issue with your install?

Also, we don't actually test VarNet with the single-coil data; it's really meant for multi-coil data with a batch size of 1. Is there a reference showing that VarNet works for single-coil data that you're trying to reproduce?

@pranavsinghps1
Author

I see, thank you for your prompt response.
I will try to realign my environment with the requirements listed in https://github.com/facebookresearch/fastMRI/blob/main/setup.cfg.

As for using VarNet for single-coil reconstruction: I did see that [1] uses VarNet exclusively for multi-coil reconstruction, while the U-Net is used for both, and I was trying to figure out the rationale for that. For my VarNet, I removed the sensitivity net and am just using a vanilla VarNet with a ResNet-18 backbone.

[1] Sriram, Anuroop, et al. "End-to-end variational networks for accelerated MRI reconstruction." Medical Image Computing and Computer Assisted Intervention–MICCAI 2020: 23rd International Conference, Lima, Peru, October 4–8, 2020, Proceedings, Part II 23. Springer International Publishing, 2020.

@mmuckley
Contributor

mmuckley commented Aug 18, 2023

Hello @pranavsinghps1, the main innovation of that paper is the end-to-end aspect, where the model estimates both the sensitivity maps and the final image. In non-E2E VarNets, the sensitivity maps are precomputed by another method (such as ESPIRiT), so those pipelines are not end-to-end.

However, in the single-coil case there are no sensitivities, so you just have a regular VarNet.
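
To make that concrete, a single-coil cascade reduces to roughly the following (a sketch of my reading of the paper, not the VarNetBlock code in this repo; the refinement module is a placeholder, and the mask is assumed to be a boolean sampling mask broadcastable to the k-space tensor):

import torch
import torch.nn as nn

import fastmri


class SingleCoilCascade(nn.Module):
    """One variational cascade with the sensitivity expand/reduce operators removed."""

    def __init__(self, refinement: nn.Module):
        super().__init__()
        self.refinement = refinement  # placeholder image-space CNN / U-Net
        self.dc_weight = nn.Parameter(torch.ones(1))

    def forward(self, kspace_pred, ref_kspace, mask):
        # Soft data consistency: pull the sampled k-space locations back toward
        # the measured data, weighted by a learned scalar.
        zero = torch.zeros_like(kspace_pred)
        soft_dc = torch.where(mask, kspace_pred - ref_kspace, zero) * self.dc_weight

        # Learned refinement in image space; no coil combination is needed for
        # single-coil data. Tensors use the (..., H, W, 2) complex-as-real layout.
        image = fastmri.ifft2c(kspace_pred)
        model_term = fastmri.fft2c(self.refinement(image))

        return kspace_pred - soft_dc - model_term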

We never prioritized the development of a single-coil VarNet because, in the real world, all MRI scanners are multi-coil. Multi-coil acquisition has enormous benefits over single-coil in terms of SNR and image quality. The single-coil data is really a toy setting for people getting into the area; only work done on the multi-coil data is likely to have any impact on real-world scanners.

@pranavsinghps1
Author

Thank you @mmuckley for the detailed information. One question: why is multi-coil trained with a batch size of 1?

Update on the issue: rewriting the data loaders around SliceDataset resolved the crashes.
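
Roughly, the rebuilt training loader looks like this (paths are placeholders, the transform and arguments follow my reading of the SliceDataset API, and the worker count is illustrative):

from torch.utils.data import DataLoader

from fastmri.data import SliceDataset
from fastmri.data.subsample import RandomMaskFunc
from fastmri.data.transforms import VarNetDataTransform

mask_func = RandomMaskFunc(center_fractions=[0.08], accelerations=[4])

train_dataset = SliceDataset(
    root="/path/to/knee_singlecoil_train",  # placeholder path
    challenge="singlecoil",
    transform=VarNetDataTransform(mask_func=mask_func, use_seed=False),
)

train_loader = DataLoader(
    train_dataset,
    batch_size=8,
    shuffle=True,
    num_workers=4,  # illustrative; training has been stable with worker processes so far
    pin_memory=True,
)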

@mmuckley
Contributor

mmuckley commented Aug 25, 2023

Hello @pranavsinghps1, the main reason is that many of the multi-coil volumes have different matrix sizes. With the VarNet we need to enforce data consistency on the raw k-space, so there is no simple way to batch them. In the end we made the VarNet large enough that it used all of one GPU's memory, and we found that a batch size of 1 with a large model was the most effective training strategy.
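
As a toy illustration (shapes made up), the default collate already fails as soon as you try to batch slices from volumes with different matrix sizes:

import torch
from torch.utils.data import DataLoader, Dataset


class ToyKspaceDataset(Dataset):
    """Hypothetical stand-in for multi-coil slices with varying matrix sizes."""

    def __init__(self):
        # (coils, rows, cols, 2), with a different matrix size per volume
        self.shapes = [(15, 640, 368, 2), (15, 640, 372, 2), (20, 640, 322, 2)]

    def __len__(self):
        return len(self.shapes)

    def __getitem__(self, idx):
        return torch.zeros(self.shapes[idx])


loader = DataLoader(ToyKspaceDataset(), batch_size=2)
try:
    next(iter(loader))
except RuntimeError as err:
    # default_collate cannot stack tensors of different sizes
    print("collate failed:", err)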

As for the issue itself, could you post more details of your solution? If there is no problem with the core repository, please close the issue.
