webdataset cannot stop cycling at end of epoch #5441

Open
CoinCheung opened this issue Apr 20, 2024 · 11 comments
Labels: enhancement (New feature or request)

CoinCheung commented Apr 20, 2024

Version

1.31.0

Describe the bug.

I use a dataset of about 20,000 samples, so the data iteration should stop at around iteration 700. However, the dataloader keeps feeding batches after that, and the training never stops.

Minimum reproducible example

Here is a piece of my code, which is the main part of dataloader:


import os
import os.path as osp
import re
import time
import random

import numpy as np

from nvidia.dali.plugin.pytorch import DALIClassificationIterator, LastBatchPolicy, DALIGenericIterator
from nvidia.dali.pipeline import pipeline_def
import nvidia.dali.types as types
import nvidia.dali.fn as fn




@pipeline_def
def create_dali_pipeline_segment(wds_paths, shard_id, num_shards, dali_cpu=False,
                                 scales=[0.75, 2], cropsize=[1024, 1024],
                                 mean=[0.3257, 0.3690, 0.3223],
                                 std=[0.2112, 0.2148, 0.2115],
                                 ):

    wds_index_paths = [re.sub('tar$', 'idx', el) for el in wds_paths]
    images = fn.readers.webdataset(
        paths=wds_paths,
        index_paths=wds_index_paths,
        ext=['jpg',], missing_component_behavior="error",
        dtypes=[types.UINT8, ],
        random_shuffle=True,
        pad_last_batch=False,
        prefetch_queue_depth=4,
        shard_id=shard_id,
        num_shards=num_shards,
        read_ahead=True,
        device='cpu'
    )


    dali_device = 'cpu' if dali_cpu else 'gpu'
    decoder_device = 'cpu' if dali_cpu else 'mixed'
    # ask nvJPEG to preallocate memory for the biggest sample in ImageNet for CPU and GPU to avoid reallocations in runtime
    device_memory_padding = 211025920 if decoder_device == 'mixed' else 0
    host_memory_padding = 140544512 if decoder_device == 'mixed' else 0
    # ask HW NVJPEG to allocate memory ahead for the biggest image in the data set to avoid reallocations in runtime
    preallocate_width_hint = 5980 if decoder_device == 'mixed' else 0
    preallocate_height_hint = 6430 if decoder_device == 'mixed' else 0



    ## decode and switch to gpu
    shape = fn.peek_image_shape(images)
    images = fn.decoders.image(images, device='mixed', output_type=types.RGB)
    #  shape = fn.shapes(images)
    images = images.gpu()

    # random resize
    scale = fn.random.uniform(range=(min(scales), max(scales)))
    new_size = shape[0:2] * scale
    images = fn.resize(images, size=new_size,
                       interp_type=types.DALIInterpType.INTERP_LINEAR, antialias=False)

    # random crop
    crop_pos_x = fn.random.uniform(range=(0, 1))
    crop_pos_y = fn.random.uniform(range=(0, 1))
    images = fn.crop(images, crop=cropsize, crop_pos_x=crop_pos_x, crop_pos_y=crop_pos_y, out_of_bounds_policy="pad", fill_values=0)


    images = fn.transpose(images, perm=[2, 0, 1])
    images = fn.normalize(
        images,
        dtype=types.FLOAT,
        mean=255 * np.array(mean).reshape(-1, 1, 1),
        stddev=255 * np.array(std).reshape(-1, 1, 1))


    return images,


class OneEpochWrapper(object):

    def __init__(self, dl, n_epochs):
        self.dl = iter(dl)
        self.n_epochs = n_epochs
        self.epoch = 0
        self.it = 0

    def __iter__(self):
        self.epoch += 1
        return self

    def __next__(self):
        print('iter: ', self.it)
        self.it += 1
        try:
            return next(self.dl)
        except StopIteration:
            print('epoch done: ', self.epoch)
            #  self.dl = iter(dl)
            #  if self.epoch >= self.n_epochs:
            #      raise StopIteration
            #  return next(self.dl)



def create_dali_loader(cfg, mode='train'):

    rank = int(os.environ['RANK'])
    local_rank = int(os.environ['LOCAL_RANK'])
    world_size = int(os.environ['WORLD_SIZE'])
    local_world_size = int(os.environ['LOCAL_WORLD_SIZE'])

    im_root = cfg.im_root
    im_anno = cfg.train_im_anns

    batchsize = cfg.global_batchsize #// world_size
    n_epochs = cfg.n_epochs

    dali_num_threads = 8

    saroot = '../../../datasets_share_to_all/SA-1B/raw/'
    wds_paths = [
        osp.join(saroot, 'sa_000198.coin.tar'),
        osp.join(saroot, 'sa_000199.coin.tar'),
    ]

    pipe = create_dali_pipeline_segment(batch_size=batchsize,
                                wds_paths=wds_paths,
                                num_threads=dali_num_threads,
                                device_id=local_rank,
                                seed=12 + local_rank,
                                dali_cpu=False,

                                prefetch_queue_depth=16,
                                shard_id=local_rank,
                                num_shards=world_size,
                                **cfg.dali_pipe_kwargs,
                                         )
    pipe.build()
    data_loader = DALIGenericIterator(pipe,
                                      ['data_0', ],
                                      last_batch_policy=LastBatchPolicy.DROP,
                                      auto_reset=False)

    n_samples = pipe.epoch_size()['__Webdataset_0']
    n_iters = (n_samples // cfg.global_batchsize) * cfg.n_epochs ## this is 700

    #  data_iter = OneEpochWrapper(data_loader, cfg.n_epochs)
    data_iter = data_loader

    return data_iter, n_iters

dl, _ = create_dali_loader(...)

for it, data in enumerate(dl):
    print(it)  # this does not stop at 700; it keeps going well past 1000


### Relevant log output

_No response_

### Other/Misc.

_No response_

### Check for duplicates

- [X] I have searched the [open bugs/issues](https://github.com/NVIDIA/DALI/issues) and have found no duplicates for this bug report
CoinCheung added the bug (Something isn't working) label Apr 20, 2024
JanuszL (Contributor) commented Apr 22, 2024

Hi @CoinCheung,

Thank you for reaching out.
DALI readers are infinite data sources, and the DALIGenericIterator is not aware of how many samples the pipeline provides.
To connect the reader and the iterator, please use the reader_name argument, like in this toy example, so the iterator can query the reader for the number of samples and stop iterating accordingly.
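
For reference, here is a minimal, untested sketch of how that could look with a stripped-down version of the pipeline from your report (the reader name "webdataset_reader", the resize size, and the batch parameters are arbitrary choices):

```python
from nvidia.dali.pipeline import pipeline_def
from nvidia.dali.plugin.pytorch import DALIGenericIterator, LastBatchPolicy
import nvidia.dali.types as types
import nvidia.dali.fn as fn

@pipeline_def
def wds_pipeline(wds_paths, wds_index_paths, shard_id, num_shards):
    images = fn.readers.webdataset(
        paths=wds_paths,
        index_paths=wds_index_paths,
        ext=['jpg'],
        missing_component_behavior="error",
        dtypes=[types.UINT8],
        random_shuffle=True,
        pad_last_batch=False,
        shard_id=shard_id,
        num_shards=num_shards,
        name="webdataset_reader",  # give the reader an explicit name
        device='cpu')
    images = fn.decoders.image(images, device='mixed', output_type=types.RGB)
    # fixed output shape so the batches can be returned as dense tensors
    return fn.resize(images, size=[1024, 1024])

wds_paths = ['sa_000198.coin.tar', 'sa_000199.coin.tar']        # adjust paths
wds_index_paths = ['sa_000198.coin.idx', 'sa_000199.coin.idx']

pipe = wds_pipeline(wds_paths=wds_paths, wds_index_paths=wds_index_paths,
                    shard_id=0, num_shards=1,
                    batch_size=32, num_threads=8, device_id=0)
pipe.build()

# reader_name lets the iterator ask the named reader for its epoch size,
# so iteration stops after one pass over the shard instead of cycling forever.
loader = DALIGenericIterator(pipe, ['data_0'],
                             reader_name="webdataset_reader",
                             last_batch_policy=LastBatchPolicy.DROP)

for it, data in enumerate(loader):
    pass  # stops at the end of the epoch
```

With reader_name set, last_batch_policy together with pad_last_batch controls what happens to the final, partial batch of the epoch.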

CoinCheung (Author) commented

@JanuszL

Thanks for replying !!

I have one more question. If more than one tar file is given to webdataset, how will DALI shuffle the samples? Is the shuffling carried out across tar files, or is the order within each tar file left unchanged?

JanuszL (Contributor) commented Apr 22, 2024

> I have one more question. If more than one tar file is given to webdataset, how will DALI shuffle the samples? Is the shuffling carried out across tar files, or is the order within each tar file left unchanged?

You can find more details in this answer.
Long story short, DALI uses an internal buffer of fixed size (the initial_fill parameter) into which data is read sequentially; when a batch is created, the buffer is sampled randomly. If multiple files are used, they are read in sequence, one after another.
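
For illustration, the size of that buffer is controlled on the reader itself. A sketch of the relevant arguments, dropped into the pipeline above (the value 4096 is arbitrary; 1024 is, as far as I know, the default):

```python
# Sketch: initial_fill sets the size of the internal shuffling buffer.
# A larger buffer mixes samples from a wider stretch of the sequentially
# read tar files, at the cost of more memory and a longer warm-up.
images = fn.readers.webdataset(
    paths=wds_paths,
    index_paths=wds_index_paths,
    ext=['jpg'],
    dtypes=[types.UINT8],
    random_shuffle=True,   # each batch is drawn randomly from the buffer
    initial_fill=4096,     # buffer size in samples (default: 1024)
    shard_id=shard_id,
    num_shards=num_shards,
    name="webdataset_reader",
    device='cpu')
```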

CoinCheung (Author) commented

Hi @JanuszL ,

Just to make sure I understand: does this mean that different tar files are loaded sequentially, but the samples within each tar file are shuffled?

JanuszL (Contributor) commented Apr 22, 2024

> Just to make sure I understand: does this mean that different tar files are loaded sequentially, but the samples within each tar file are shuffled?

Samples are shuffled inside an internal buffer that is filled sequentially. When DALI finishes reading one tar it moves on to the next, so samples from different tars can land in the same batch, but the larger the distance between two samples across the tars, the less likely that is.

CoinCheung (Author) commented

Thanks for telling me this!!!

I have a suggestion: maybe at the beginning of each epoch the order of the tar files could be shuffled. Once the tar file order is shuffled, the sequential-read-plus-shuffle-buffer loading described above would proceed as before. This would add more randomness to the batches.

If this feature sounds reasonable, please consider adding it in a future version.
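
Until something like that exists in the reader, a rough approximation on my side (sketch only; the seed value is arbitrary, and every rank has to use the same order so that sharding stays consistent) would be to shuffle the list of tar paths before the pipeline is built:

```python
import random
import re
import os.path as osp

# Partial workaround sketch: randomize the order in which the tar files are
# handed to the reader. This fixes one order for the lifetime of the pipeline
# rather than reshuffling per epoch, and the same seed must be used on every
# rank so that all shards see the tars in the same order.
saroot = '../../../datasets_share_to_all/SA-1B/raw/'
wds_paths = [
    osp.join(saroot, 'sa_000198.coin.tar'),
    osp.join(saroot, 'sa_000199.coin.tar'),
]
random.Random(1234).shuffle(wds_paths)   # 1234: arbitrary seed shared by all ranks
wds_index_paths = [re.sub('tar$', 'idx', p) for p in wds_paths]
```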

JanuszL (Contributor) commented Apr 22, 2024

> If this feature sounds reasonable, please consider adding it in a future version.

Thank you for your suggestion.
We will add it to our ToDo list. If you feel you could contribute this functionality, we would be more than happy to assist with a corresponding PR.

JanuszL added this to ToDo in Users requests via automation Apr 22, 2024
JanuszL added the enhancement (New feature or request) label and removed the bug (Something isn't working) label Apr 22, 2024
CoinCheung (Author) commented

I am closing this since my question has been answered. I am sorry that I am not able to contribute right now. Thanks again for your help to the community!

Users requests automation moved this from ToDo to Done Apr 22, 2024
CoinCheung reopened this Apr 22, 2024
Users requests automation moved this from Done to ToDo Apr 22, 2024
JanuszL (Contributor) commented May 8, 2024

Hi @CoinCheung,

> When I try to train the model with the webdataset dataloader, I have to wait about 4 hours before the first batch is fetched.

The reader first fills its internal buffer of initial_fill size. In your case that means reading 1024 files of 2 MB each; it may take a while depending on the storage IO speed, but hardly 4 hours. Can you capture an Nsight profile so that, based on the DALI annotations, we can see what is going on?

CoinCheung (Author) commented

Hi @JanuszL, I found the reason. My platform has 1 TB of memory, but the dataset is 10 TB, and I had set read_ahead=True, which caused the long wait. After switching to read_ahead=False, the first batch arrives much faster and the problem no longer appears.

JanuszL (Contributor) commented May 8, 2024

Yes, I missed the use of read_ahead in your code. By default DALI mmaps the files it reads, and read_ahead makes sure that the whole file content is loaded into RAM before the data is available (this translates to the MAP_POPULATE option of mmap). That certainly slows down the first iteration.
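
For anyone landing here later, a sketch of the relevant knob in the reader call (read_ahead=False is, to my knowledge, the default):

```python
# Sketch: with read_ahead=False the mmapped tar contents are paged in lazily
# as samples are requested, instead of the whole archive being populated into
# RAM up front, which cannot work when the data set is larger than memory.
images = fn.readers.webdataset(
    paths=wds_paths,
    index_paths=wds_index_paths,
    ext=['jpg'],
    dtypes=[types.UINT8],
    read_ahead=False,   # True forces a full pre-load (mmap MAP_POPULATE)
    device='cpu')
```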
