DataLoader gives "Broken pipe" error on Linux platform #46802

Open
yxchng opened this issue Oct 24, 2020 · 6 comments
Labels
module: dataloader Related to torch.utils.data.DataLoader and Sampler · module: multiprocessing Related to torch.multiprocessing · triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@yxchng

yxchng commented Oct 24, 2020

🐛 Bug

PyTorch's DataLoader gives a "Broken pipe" error on the Linux platform (it does not happen on Windows). Setting num_workers=0 suppresses the error, but that is more of a workaround than a solution because it greatly reduces the efficiency of the code. If this is not a bug, a guide on how to correct the following code would be appreciated.

To Reproduce

Steps to reproduce the behavior:

import numpy as np
import torch
import torch.utils.data
import torchvision.datasets as dset
import torchvision.transforms as transforms

def _data_transforms_cifar10():
    CIFAR_MEAN = [0.49139968, 0.48215827, 0.44653124]
    CIFAR_STD = [0.24703233, 0.24348505, 0.26158768]

    train_transform = transforms.Compose([
        transforms.RandomCrop(32, padding=4),
        transforms.RandomHorizontalFlip(),
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
    ])

    valid_transform = transforms.Compose([
        transforms.ToTensor(),
        transforms.Normalize(CIFAR_MEAN, CIFAR_STD),
        ])
    return train_transform, valid_transform

def main():

    train_portion = 0.5
    train_transform, valid_transform = _data_transforms_cifar10()

    train_data = dset.CIFAR10(root='.', train=True, download=True, transform=train_transform)
    num_train = len(train_data)

    indices = list(range(num_train))
    split = int(np.floor(train_portion * num_train))


    train_queue = torch.utils.data.DataLoader(
        train_data, batch_size=64,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[:split]),
        pin_memory=True, num_workers=2)

    valid_queue = torch.utils.data.DataLoader(
        train_data, batch_size=64,
        sampler=torch.utils.data.sampler.SubsetRandomSampler(indices[split:num_train]),
        pin_memory=True, num_workers=2)

    train(train_queue, valid_queue)

def train(train_queue, valid_queue):
    for step, (input, target) in enumerate(train_queue):
        # Fetch one validation batch per training step.
        input_search, target_search = next(iter(valid_queue))

if __name__ == '__main__':
    main()

Stack trace:

Traceback (most recent call last):
  File "/usr/local/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
Traceback (most recent call last):
  File "/usr/local/lib/python3.6/multiprocessing/queues.py", line 240, in _feed
    send_bytes(obj)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 200, in send_bytes
    self._send_bytes(m[offset:offset + size])
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 404, in _send_bytes
    self._send(header + buf)
  File "/usr/local/lib/python3.6/multiprocessing/connection.py", line 368, in _send
    n = write(self._handle, buf)
BrokenPipeError: [Errno 32] Broken pipe

Expected behavior

Runs without any error.

Environment

Please copy and paste the output from our environment collection script (or fill out the checklist below manually).

You can get the script and run it with:

wget https://raw.githubusercontent.com/pytorch/pytorch/master/torch/utils/collect_env.py
# For security purposes, please check the contents of collect_env.py before running it.
python collect_env.py

PyTorch version: 1.6.0+cu101
Is debug build: False
CUDA used to build PyTorch: 10.1
ROCM used to build PyTorch: N/A

OS: Debian GNU/Linux 9.9 (stretch) (x86_64)
GCC version: (Debian 6.3.0-18+deb9u1) 6.3.0 20170516
Clang version: Could not collect
CMake version: Could not collect

Python version: 3.6 (64-bit runtime)
Is CUDA available: True
CUDA runtime version: 10.1.243
GPU models and configuration:
GPU 0: Tesla V100-SXM2-32GB
GPU 1: Tesla V100-SXM2-32GB
GPU 2: Tesla V100-SXM2-32GB
GPU 3: Tesla V100-SXM2-32GB
GPU 4: Tesla V100-SXM2-32GB
GPU 5: Tesla V100-SXM2-32GB
GPU 6: Tesla V100-SXM2-32GB
GPU 7: Tesla V100-SXM2-32GB

Nvidia driver version: 418.87.00
cuDNN version: /usr/local/cuda-10.1/targets/x86_64-linux/lib/libcudnn.so.7
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.1
[pip3] torch==1.6.0+cu101
[pip3] torchvision==0.7.0+cu101
[conda] Could not collect

cc @ssnl @VitalyFedyunin @ejguan

@izdeby izdeby added module: dataloader Related to torch.utils.data.DataLoader and Sampler and triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module labels Oct 26, 2020
@VitalyFedyunin
Contributor

Can you please explain what you are trying to achieve with this line: input_search, target_search = next(iter(valid_queue))? It looks faulty to me, as it creates a new iterator on every outer-loop iteration, which causes the validation dataset workers to die sooner.
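
For context, each call to iter() on a DataLoader with num_workers > 0 spins up a fresh set of worker processes. A rough sketch of the difference, using the variable names from the code above:

# Creating a new iterator every step repeatedly spawns worker processes and
# tears down the previous, half-consumed iterator:
for step, (input, target) in enumerate(train_queue):
    input_search, target_search = next(iter(valid_queue))

# Reusing a single iterator keeps the same worker processes alive across steps:
valid_iter = iter(valid_queue)
for step, (input, target) in enumerate(train_queue):
    input_search, target_search = next(valid_iter)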

@VitalyFedyunin VitalyFedyunin added the module: multiprocessing Related to torch.multiprocessing label Oct 26, 2020
@yxchng
Author

yxchng commented Nov 2, 2020

@VitalyFedyunin I am just trying to get the next sample from the validation queue. What would be the correct way of doing this, then?

@yxchng
Author

yxchng commented Nov 12, 2020

@VitalyFedyunin Also, this next(iter(valid_queue)) pattern is suggested here: #1917

@zylprivate

I met the same error, but when I run the same code on CPU, the error doesn't appear.

@ejguan
Contributor

ejguan commented Jan 4, 2021

@yxchng
Can you try creating the iterator outside of the for loop and calling next on it inside the loop?

valid_iter = iter(valid_queue)
for ... in ...:
    input_search, target_search = next(valid_iter)
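
A slightly fuller sketch of the train function with this change; the StopIteration guard is an added assumption, in case the validation queue turns out shorter than the training queue:

def train(train_queue, valid_queue):
    valid_iter = iter(valid_queue)  # create the iterator once, outside the loop
    for step, (input, target) in enumerate(train_queue):
        try:
            input_search, target_search = next(valid_iter)
        except StopIteration:
            # The validation queue is exhausted; restart it instead of
            # building a brand-new iterator on every step.
            valid_iter = iter(valid_queue)
            input_search, target_search = next(valid_iter)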

@cdzhan
Contributor

cdzhan commented Mar 7, 2022

I encountered the same "BrokenPipeError" occasionally when I broke out of the loop in the middle of the iterations. I was using PyTorch 1.6.

                if hasattr(self, '_pin_memory_thread'):
                    # Use hasattr in case error happens before we set the attribute.
                    self._pin_memory_thread_done_event.set()
                    # Send something to pin_memory_thread in case it is waiting
                    # so that it can wake up and check `pin_memory_thread_done_event`
                    self._worker_result_queue.put((None, None))
                    self._pin_memory_thread.join()
                    self._worker_result_queue.cancel_join_thread()
                    self._worker_result_queue.close()    # may close too early

After debugging the code, I found that _worker_result_queue may close its own reader before the write to the pipe has finished. @VitalyFedyunin
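
A possible user-side mitigation, sketched here as an assumption rather than a fix for the underlying race in the shutdown path quoted above, is to avoid tearing the workers down mid-epoch at all, for example with persistent_workers=True (available from PyTorch 1.7):

import torch
import torch.utils.data as data

# Toy dataset and loader, purely to illustrate breaking out of an epoch early.
dataset = data.TensorDataset(torch.arange(1000, dtype=torch.float32))
loader = data.DataLoader(dataset, batch_size=8, num_workers=2,
                         persistent_workers=True)  # requires PyTorch >= 1.7

for step, (batch,) in enumerate(loader):
    if step == 3:
        # With persistent_workers the worker processes stay alive here instead
        # of being shut down, so the early-close path quoted above is not hit.
        break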
