
DataLoader num_workers > 0 causes CPU memory from parent process to be replicated in all worker processes #13246

Open
bfreskura opened this issue Oct 29, 2018 · 141 comments
Labels
high priority module: dataloader Related to torch.utils.data.DataLoader and Sampler module: dependency bug Problem is not caused by us, but caused by an upstream library we use module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: molly-guard Features which help prevent users from committing common mistakes module: multiprocessing Related to torch.multiprocessing triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments

@bfreskura

bfreskura commented Oct 29, 2018

Editor note: There is a known workaround further down in this issue: do NOT use Python lists, but use something else instead, e.g., torch.tensor directly. See #13246 (comment). You can use a numpy array, but it only fixes the issue for the fork start method. See #13246 (comment) for more details.

🐛 Bug

CPU memory will leak if the DataLoader num_workers > 0.

To Reproduce

Run the following snippet:

from torch.utils.data import Dataset, DataLoader
from PIL import Image
from torchvision import transforms
import os

class DataIter(Dataset):
    def __init__(self):
        path = "path/to/data"
        self.data = []

        for cls in os.listdir(path):
            for img in os.listdir(os.path.join(path, cls)):
                self.data.append(os.path.join(path, cls, img))

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with Image.open(self.data[idx]) as img:
            img = img.convert('RGB')
            return transforms.functional.to_tensor(img)


train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=False,
                          num_workers=18)

for i, item in enumerate(train_loader):
    if i % 200 == 0:
        print(i)

Expected behavior

CPU memory gradually increases, eventually filling up all available RAM. E.g., the process starts at around 15GB and fills up the whole 128GB available on the system.
With num_workers=0, RAM usage stays constant.

Environment

PyTorch version: 1.0.0.dev20181028
Is debug build: No
CUDA used to build PyTorch: 9.0.176

OS: Ubuntu 16.04.4 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.0.176
GPU models and configuration: 
GPU 0: GeForce GTX 1080 Ti
GPU 1: GeForce GTX 1080 Ti
GPU 2: GeForce GTX 1080 Ti

Nvidia driver version: 390.67
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.4

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect

PIL.__version__
'5.3.0'

Additional info

There are around 24 million images in the dataset and all image paths are loaded into a single list as presented in the above code snippet.

I have also tried multiple PyTorch versions (0.4.0 and 0.4.1) and the effect is the same.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @anjali411 @ssnl @VitalyFedyunin @ejguan

@ssnl
Collaborator

ssnl commented Oct 29, 2018

Do you see memory usage increasing when iterating, or before you even start to iterate?

@bfreskura
Author

@ssnl During the iteration only.

@ezyang
Contributor

ezyang commented Oct 29, 2018

When we fix #13243 we should check if this one gets fixed too.

@samgd

samgd commented Oct 31, 2018

I've been experiencing something similar where memory usage continuously climbs until an OOM is triggered when using a batch_sampler with num_workers > 0.

To Reproduce

import math

from torch.utils.data import DataLoader


class Sampler:
    def __init__(self, n=100000, batch_size=32):
        self.n = n
        self.batch_size = batch_size

    def __len__(self):
        return math.ceil(float(self.n)/self.batch_size)

    def __iter__(self):
        batch = []
        for i in range(self.n):
            batch.append(i)
            if len(batch) == self.batch_size:
                yield batch
                batch = []
        if batch:
            yield batch

            
N = 100000000
train_data = list(range(N))

            
def ok():
    train_sampler = Sampler(len(train_data))
    train_loader = DataLoader(train_data,
                              num_workers=0,
                              batch_sampler=train_sampler)
    
    for i, item in enumerate(train_loader):
        if i % 10000 == 0:
            print(i)
            
            
def leaky():
    train_sampler = Sampler(len(train_data))
    train_loader = DataLoader(train_data,
                              num_workers=8,
                              batch_sampler=train_sampler)

    for i, item in enumerate(train_loader):
        if i % 10000 == 0:
            print(i)
            
            
print('Starting ok')
ok()
print('ok done, starting leaky()')
leaky()
print('leaky done')

Environment

$ python3 collect_env.py
Collecting environment information...
PyTorch version: 0.4.0
Is debug build: No
CUDA used to build PyTorch: 9.1.85

OS: Ubuntu 16.04.5 LTS
GCC version: (Ubuntu 5.4.0-6ubuntu1~16.04.10) 5.4.0 20160609
CMake version: version 3.5.1

Python version: 3.5
Is CUDA available: Yes
CUDA runtime version: 9.1.85
GPU models and configuration: GPU 0: GeForce GTX 1050 Ti with Max-Q Design
Nvidia driver version: 390.77
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.1.2
/usr/lib/x86_64-linux-gnu/libcudnn_static_v7.a

Versions of relevant libraries:
[pip] Could not collect
[conda] Could not collect

@bfreskura
Author

@ezyang

When we fix #13243 we should check if this one gets fixed too.

The issue is still present in 1.0.0.dev20181105, where #13243 is fixed.

@bfreskura
Author

bfreskura commented Nov 7, 2018

After some more investigation, I have found an exact scenario when the leak occurs. Consider the code example below:

from torch.utils.data import Dataset, DataLoader
import numpy as np
import torch


class DataIter(Dataset):
    def __init__(self):
        self.data_np = np.array([x for x in range(24000000)])
        self.data = [x for x in range(24000000)]

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        data = self.data[idx]
        data = np.array([data], dtype=np.int64)
        return torch.tensor(data)


train_data = DataIter()
train_loader = DataLoader(train_data, batch_size=300,
                          shuffle=True,
                          drop_last=True,
                          pin_memory=False,
                          num_workers=18)

for i, item in enumerate(train_loader):
    if i % 1000 == 0:
        print(i)

If we use the self.data variable, which is a standard Python list of ints, the leak will occur. However, if the self.data_np variable is used, which holds the same data in the form of a NumPy array, the leak will not occur.
Another observation is that the leak is significantly less severe if shuffle=False is set in the DataLoader.

@svishnu88

I face a similar issue, but in my case it occurs with a numpy array too. I am using Python 3.7 and a PyTorch nightly release.

@mprostock

I don't know how multiprocessing really works under the hood of pytorch, but we have extensively discussed this "Memory Leak" issue (which probably isn't a memory leak!) on the fast.ai forums (https://forums.fast.ai/t/runtimeerror-dataloader-worker-is-killed-by-signal/31277/55?u=marcmuc). Preliminary findings which hopefully add some insight here (if this does NOT apply, please comment!):

Python multiprocessing: There is no way of storing arbitrary Python objects (even simple lists) in shared memory in Python without triggering copy-on-write behaviour, because refcounts are updated every time something reads from these objects. The refcount updates dirty the memory page by page, which is why consumption grows slowly. The worker processes end up having all/most of the memory copied over bit by bit, which is why we get the memory overflow problem. The best description of this behaviour is here (SO).

Possible Solution:
Keep using multiprocessing as it is now: in order for Python multiprocessing to work without these refcount effects, the objects have to be made "compatible with" and wrapped in multiprocessing.Array before the process pool is created and workers are forked. This supposedly ensures that the memory will really be shared and no copy-on-write happens. This explains how to do it for numpy arrays, and this explains the reasoning behind it again. Don't get confused by some false statements, even by the authors of these otherwise good answers, claiming that copy-on-write makes all of this unnecessary, which is not true. One comment also points to this:

“Just to note, on Python fork() actually means copy on access (because just accessing the object will change its ref-count).”

I am not familiar with the torch.multiprocessing drop-in replacement that I understand pytorch uses, but I would assume it will also not be able to remove the core python refcount issue.
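
For illustration, a minimal sketch of that multiprocessing.Array-style approach (the class and sizes below are made up, and it assumes integer data and the default fork start method on Linux): the shared block is allocated before the DataLoader forks its workers, so reads in __getitem__ never touch per-element Python refcounts.

import ctypes
import multiprocessing as mp

import numpy as np
import torch
from torch.utils.data import Dataset

class SharedArrayDataset(Dataset):
    def __init__(self, n=1_000_000):
        # Allocate one lock-free shared block *before* workers are forked.
        shared = mp.RawArray(ctypes.c_int64, n)
        # View the shared buffer as a numpy array (no copy is made).
        self.data = np.frombuffer(shared, dtype=np.int64)
        self.data[:] = np.arange(n, dtype=np.int64)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        # Reading a scalar from the shared block does not bump refcounts on
        # millions of separate Python objects, so no pages get copied.
        return torch.tensor(self.data[idx])

With the spawn start method the dataset would have to be pickled into each worker, and shared ctypes arrays can only be shared through inheritance, so this sketch applies to fork only.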

@soumith
Member

soumith commented Dec 9, 2018

@mprostock torch.multiprocessing is simply Python multiprocessing with a custom pickler. The custom pickler, whenever it encounters a torch.tensor, will automatically move it to shared memory, and hence, at least for torch.tensor objects, no copy-on-write happens.
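
A small hand-rolled illustration of that behaviour (a sketch, not the DataLoader code path; share_memory_() is called explicitly here so the example does not depend on how the tensor gets pickled): a child process writes into the tensor's shared storage and the change is visible in the parent, i.e. nothing is copied.

import torch
import torch.multiprocessing as mp

def worker(t):
    t[0] = 42.0  # writes into the shared storage, not a private copy

if __name__ == "__main__":
    t = torch.zeros(4)
    t.share_memory_()                    # move the storage into shared memory
    p = mp.Process(target=worker, args=(t,))
    p.start()
    p.join()
    print(t)                             # tensor([42.,  0.,  0.,  0.])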

@mprostock

Thanks for the explanation! I have experimented with @bfreskura 's reproduction example and I think I can now pinpoint the problem:

The reproduction example by bfreskura above showed the difference between a regular Python list and a numpy array. But the problem is not (only) the Python list itself; the same happens with a numpy array of dtype object. Python lists store only references to the objects; the objects themselves are kept separately in memory. Every object has a refcount, therefore every item in the list has a refcount.

Numpy arrays (of standard np dtypes) are stored as contiguous blocks in memory and are only ONE object with one refcount.

This changes if you make the numpy array explicitly of dtype object, which makes it start behaving like a regular Python list (storing only references to (string) objects). The same "problems" with memory consumption now appear.

This would explain why, with regular lists (or numpy arrays of dtype object), we see the "memory leak", which is actually the copy-on-access problem of forked Python processes due to changing refcounts, not a memory leak.

So the problem probably (often) has nothing to do with tensors or actual torch objects, but rather with the lists of filenames and dicts of labels that are generally used within dataloaders/datasets.

I have created a notebook gist if someone wants to quickly try it.
Look at the memory consumption (a quick-and-dirty measurement of total system memory, so there are minor influences from other processes; I tried to keep the system clean).

Memory consumption in GB with fixed-length string array:
[plot omitted]

Memory consumption in GB with object array (only change!):
[plot omitted]
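
A tiny sketch of the difference described above (hypothetical paths, sizes chosen only for illustration): the fixed-width array is one contiguous buffer with a single refcount, while the object array stores a pointer to a separately refcounted str per element.

import numpy as np

paths = [f"/data/class_{i % 10}/img_{i}.jpg" for i in range(1_000_000)]

fixed = np.array(paths)                 # fixed-width '<U...': one contiguous block, one object
boxed = np.array(paths, dtype=object)   # one separately refcounted Python str per element

print(fixed.dtype, fixed.nbytes)        # a single big buffer holding all characters
print(boxed.dtype, boxed.nbytes)        # only the pointers; the strings live elsewhere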

@aurooj

aurooj commented Jan 15, 2019

I am facing the same issue. It fills up my RAM very fast if num_workers > 0.
I delete the variables that I think are no longer needed and call gc.collect() on every iteration, but nothing helps.
Any workarounds?

@NProkoptsev

Switching from dict to pandas and from lists to numpy arrays helped me.


@aurooj

aurooj commented Jan 18, 2019

Thanks for the reply. I will try that and hopefully, it works.

@Godricly

May I ask for the solution to this issue? I tried @samgd's code on the latest nightly PyTorch build, and it was still leaking.

@ssnl
Collaborator

ssnl commented Jan 25, 2019

@Godricly See @mprostock's and @soumith's comments above. This is not really a leak, but an unfortunate behaviour of using a native Python list. Using either a torch tensor or an np array will solve this memory problem.
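
Applied to the original image-path example, that advice could look like the sketch below (a sketch only, reusing the directory layout from the first snippet; per the editor note at the top, a numpy path array helps with the fork start method).

import os

import numpy as np
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

class PathArrayDataset(Dataset):
    def __init__(self, path="path/to/data"):
        paths = [os.path.join(path, cls, img)
                 for cls in os.listdir(path)
                 for img in os.listdir(os.path.join(path, cls))]
        # One contiguous fixed-width '<U...' array instead of millions of
        # separately refcounted str objects.
        self.data = np.array(paths)

    def __len__(self):
        return len(self.data)

    def __getitem__(self, idx):
        with Image.open(str(self.data[idx])) as img:
            return transforms.functional.to_tensor(img.convert("RGB"))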

@ys198918

@mprostock Do you mean that it is the copies created by copy-on-access that use up the memory, not something else? And isn't the copy released after use?

@1e100

1e100 commented Feb 24, 2019

Someone needs to step up and write a proper augmentation op for image datasets at least. The whole reason for all of these multiprocessing shenanigans is that vision datasets have to decode and crop images on multiple cores. If there were an op that took care of decoding and geometric image transforms (resize, crop, flip, shear, affine) and produced batch tensors directly, there would be no need to use multiprocessing at all; further, non-geometric augmentation steps (colors, whitening/normalization, noise) could use intra-op parallelism to rip through the entire tensor. Care needs to be taken when designing such an op to expose the transform parameters for each sample in the tensor to the outside, in order to enable parallel transformation of annotations (bounding boxes, masks, keypoints, etc.).
Or better yet, make this a server, so that multiple processes (as well as other DL frameworks) could use it as well.

@smolendawid

@mprostock thank you for the great explanation!

However, no solution has been proposed yet. Storing lists of filenames in a Dataset object seems fair, so how can one use them? Did anyone figure it out?

@ezyang ezyang added module: memory usage PyTorch is using more memory than it should, or it is leaking memory module: dataloader Related to torch.utils.data.DataLoader and Sampler module: molly-guard Features which help prevent users from committing common mistakes and removed high priority labels Apr 2, 2019
johny-b added a commit to johny-b/dds-transformer that referenced this issue Aug 4, 2023
* Uses `ray` to make it faster
* Data is stored as tensors because of
  pytorch/pytorch#13246 (comment)
@ATheCoder

ATheCoder commented Aug 21, 2023

What if I want to load my dataset completely as an attribute of the dataset object and then use that attribute inside __getitem__? I tried doing this with a list of pytorch_geometric Data objects and I still got the growing memory problem. Does this mean that I should load them as a 1D tensor?

@vadimkantorov
Contributor

For all people who stumbled on this recently, please consider upvoting proposal #101699, which in part asks for a tensor-backed array of strings in core (at least a read-only one).
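
As a rough user-space sketch of that idea (the class below is hypothetical, not a PyTorch API, and assumes a PyTorch version that has torch.frombuffer): all string bytes live in one uint8 tensor plus an int64 offsets tensor, so the whole collection is two refcounted objects no matter how many strings it holds.

import torch

class StringArray:
    """Hypothetical read-only string array backed by two tensors."""

    def __init__(self, strings):
        encoded = [s.encode("utf-8") for s in strings]
        lengths = torch.tensor([len(b) for b in encoded], dtype=torch.int64)
        # offsets[i]..offsets[i+1] delimit the bytes of string i.
        self.offsets = torch.zeros(len(encoded) + 1, dtype=torch.int64)
        self.offsets[1:] = torch.cumsum(lengths, dim=0)
        self.data = torch.frombuffer(bytearray(b"".join(encoded)), dtype=torch.uint8)

    def __len__(self):
        return self.offsets.numel() - 1

    def __getitem__(self, idx):
        start, end = self.offsets[idx].item(), self.offsets[idx + 1].item()
        return self.data[start:end].numpy().tobytes().decode("utf-8")

# Usage sketch:
# paths = StringArray(["/data/a.jpg", "/data/b.jpg"])
# paths[1]  # -> "/data/b.jpg"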

@bhack
Contributor

bhack commented Mar 3, 2024

Has upstream CPython improved anything on this topic?
python/cpython#84436

@Weifeng-Chen

Weifeng-Chen commented Mar 3, 2024 via email

@underwoodnoble

I tried to init the data in __getitem__() instead of __init__(), and it works.

I deferred the initialization of the data list into __getitem__(), which partially solved the problem for me. However, every worker needs to run init_data_list once, and if the DataLoader's persistent_workers is set to False (I recommend setting it to True), every worker re-runs init_data_list at the start of every epoch, which is very time-consuming. If you apply this lazy-initialization idea to an LMDB dataset, though, each worker only needs to open its own transaction on the LMDB database, which should work out more cleanly. I also tried the deepcopy approach mentioned above, but unfortunately it had no effect on my machine. If you run into this problem, it's worth trying both methods.


import torch

class my_dataset(torch.utils.data.Dataset):
    def __init__(self):
        self.data_list = []

    def init_data_list(self):
        # Deferred until the first __getitem__ call, so it runs inside each worker.
        self.data_list = torch.randn([100000])

    def __getitem__(self, index):
        if len(self.data_list) == 0:
            self.init_data_list()
        return self.data_list[index]

Why not use pandas to store paths directly?

@Weifeng-Chen

Weifeng-Chen commented Apr 18, 2024 via email

@zjuAJW

zjuAJW commented Apr 28, 2024

So, should using a pandas DataFrame solve this problem? I read csv files into a pandas DataFrame and use iloc to index data in __getitem__, but the leaking problem still exists. Here is my code.

class MyDataset(Dataset):
    def __init__(self, record_dir):
        dfs = []
        for i in os.listdir(record_dir):
            df = pd.read_csv(os.path.join(record_dir, i))
            dfs.append(df)
        self.df = pd.concat(dfs)

    def __len__(self):
        return len(self.df)

    def __getitem__(self, item):
        row = self.df.iloc[item]
        data = ...  # some processing of row
        return data

The problem exists even when I just read a row from the DataFrame but do not use it:

    def __getitem__(self, item):
       row = self.df.iloc[item]  # if this line is commented, the cpu memory doesn't leak
       return 1
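
One possible explanation is the same refcount/copy-on-access issue discussed above: a DataFrame read from CSV typically keeps string columns as object dtype (one Python str per cell), and .iloc materializes Python objects on every access. A sketch of a workaround (the column names "path" and "label" are assumed for illustration only) is to convert everything to primitive-dtype arrays in __init__ and drop the DataFrame:

import os

import numpy as np
import pandas as pd
import torch
from torch.utils.data import Dataset

class MyDataset(Dataset):
    def __init__(self, record_dir):
        df = pd.concat(pd.read_csv(os.path.join(record_dir, f))
                       for f in os.listdir(record_dir))
        # Keep only primitive-dtype arrays; the object-backed DataFrame is dropped.
        self.paths = np.asarray(df["path"].tolist())                   # fixed-width '<U...'
        self.labels = torch.from_numpy(df["label"].to_numpy(np.int64))
        del df

    def __len__(self):
        return len(self.labels)

    def __getitem__(self, item):
        return str(self.paths[item]), self.labels[item]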

@egorchakov

FWIW we've (seemingly so far) worked around this issue by storing dataset state in a polars DataFrame and reading from it in Dataset.__getitems__.

@zichunxx

zichunxx commented May 25, 2024

Hi! To avoid memory increases when using an IterableDataset, is there any alternative to a Python dict for storing and loading data dynamically? Thanks.

@Weifeng-Chen

Weifeng-Chen commented May 25, 2024 via email
