Rethinking the Queue class to get full GPU utilization #393

Open
dmus opened this issue Dec 22, 2020 · 17 comments
Labels
enhancement (New feature or request)

Comments

@dmus
Contributor

dmus commented Dec 22, 2020

This is a really nice framework; however, a serious issue for me is the (lack of) GPU utilization. This is an issue even with only a simple ZNormalization and a left-right flip of the data as augmentation. This results in the following GPU utilization:
[Screenshot from 2020-12-22 12-12-31: GPU utilization with the default queue, showing periodic drops to 0%]

This is a training run with 5 subjects, sampling 40 patches per volume, with a batch size of 8. After every 25 iterations (i.e. 200 patches, which equals 5 subjects × 40 patches per volume), there is a gap during which GPU utilization drops to 0.

What I did was build a custom Queue class with which I tried to get full GPU utilization. The result is as follows:
[Screenshot from 2020-12-22 12-05-10: GPU utilization with the custom queue, close to 100% throughout]
As you can see, this has a GPU utilization of ~100% without gaps (the 0% utilization is before the start).

I tried to use the existing PyTorch data functionality as much as possible. BufferedShuffleDataset is not in the current release yet, but it seems it will be in the next one. The idea behind it is the same as for shuffle() in TensorFlow's data API.

Here is the code that I made:

import random
from itertools import islice
from typing import Iterator, List

import torch
torch.multiprocessing.set_sharing_strategy('file_system')
from torch.utils.data import DataLoader, IterableDataset
from torch.utils.data.dataset import T_co
from torchio.data import PatchSampler



# https://github.com/pytorch/pytorch/commit/96540e918c4ca3f0a03866b9d281c34c65bd76a4#diff-425b66e1ff01d191679c386258a7156dfb5aacd64a8e0947b24fbdebcbee8529
class BufferedShuffleDataset(IterableDataset[T_co]):
    r"""Dataset shuffled from the original dataset.
    This class is useful to shuffle an existing instance of an IterableDataset.
    The buffer with `buffer_size` is filled with the items from the dataset first. Then,
    each item will be yielded from the buffer by reservoir sampling via iterator.
    `buffer_size` is required to be larger than 0. For `buffer_size == 1`, the
    dataset is not shuffled. In order to fully shuffle the whole dataset, `buffer_size`
    is required to be greater than or equal to the size of dataset.
    When it is used with :class:`~torch.utils.data.DataLoader`, each item in the
    dataset will be yielded from the :class:`~torch.utils.data.DataLoader` iterator.
    And, the method to set up a random seed is different based on :attr:`num_workers`.
    For single-process mode (:attr:`num_workers == 0`), the random seed is required to
    be set before the :class:`~torch.utils.data.DataLoader` in the main process.
        >>> ds = BufferedShuffleDataset(dataset)
        >>> random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, num_workers=0)))
    For multi-process mode (:attr:`num_workers > 0`), the random seed is set by a callable
    function in each worker.
        >>> ds = BufferedShuffleDataset(dataset)
        >>> def init_fn(worker_id):
        ...     random.seed(...)
        >>> print(list(torch.utils.data.DataLoader(ds, ..., num_workers=n, worker_init_fn=init_fn)))
    Arguments:
        dataset (IterableDataset): The original IterableDataset.
        buffer_size (int): The buffer size for shuffling.
    """
    dataset: IterableDataset[T_co]
    buffer_size: int

    def __init__(self, dataset: IterableDataset[T_co], buffer_size: int) -> None:
        super(BufferedShuffleDataset, self).__init__()
        assert buffer_size > 0, "buffer_size should be larger than 0"
        self.dataset = dataset
        self.buffer_size = buffer_size

    def __iter__(self) -> Iterator[T_co]:
        buf: List[T_co] = []
        for x in self.dataset:
            if len(buf) == self.buffer_size:
                idx = random.randint(0, self.buffer_size - 1)
                yield buf[idx]
                buf[idx] = x
            else:
                buf.append(x)
        random.shuffle(buf)
        while buf:
            yield buf.pop()


class PatchesDataset(IterableDataset):
    def __init__(self, subjects_dataset, sampler, samples_per_volume):
        self.subjects_dataset = subjects_dataset
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume

    def __iter__(self):
        while True:
            idx = random.randint(0, len(self.subjects_dataset) - 1)
            sample = self.subjects_dataset[idx]
            iterable = self.sampler(sample)
            patches = list(islice(iterable, self.samples_per_volume))

            yield patches


class Queue(IterableDataset):
    def __init__(
            self,
            subjects_dataset: PatchesDataset,
            max_length: int,
            samples_per_volume: int,
            sampler: PatchSampler,
            num_workers: int = 0,
            shuffle_subjects: bool = True,
            shuffle_patches: bool = True,
            verbose: bool = False,
    ):
        self.dataset = PatchesDataset(subjects_dataset, sampler, samples_per_volume)
        self.max_length = max_length

        self.loader = DataLoader(
            self.dataset,
            batch_size=None,  # each item from the dataset is already a list of patches
            num_workers=num_workers,
            # persistent_workers requires num_workers > 0, so only enable it then
            persistent_workers=num_workers > 0,
        )

        self.buffer = []

    def __iter__(self):
        # Basically this is an unbatch operation
        for patches_list in self.loader:
            for patch in patches_list:
                yield patch

As you can see I had to do:

import torch
torch.multiprocessing.set_sharing_strategy('file_system')

because otherwise I got this error: RuntimeError: received 0 items of ancdata

This may be specific to the system I use, but it seems to be a more common issue; see:
pytorch/pytorch#973

Using this custom implementation, I could replace the default queue like this:

        # patches_queue = tio.Queue(
        #     self.dataset,
        #     max_length=self.queue_length,
        #     samples_per_volume=self.samples_per_volume,
        #     sampler=sampler,
        #     num_workers=self.num_workers,
        #     verbose=False
        # )

        # use the custom queue instead of the default one
        queue = Queue(self.dataset,
                      max_length=self.queue_length,
                      samples_per_volume=self.samples_per_volume,
                      sampler=sampler,
                      num_workers=self.num_workers,
                      verbose=False)
        patches_queue = BufferedShuffleDataset(queue, self.queue_length)

        patches_loader = DataLoader(patches_queue, batch_size=self.batch_size)

What do you think of this? Could this replace or exist next to the existing tio.Queue?

@dmus dmus added the enhancement New feature or request label Dec 22, 2020
@fepegar
Owner

fepegar commented Dec 22, 2020

Hi, @dmus. I haven't fully understood how this works, but it certainly looks promising. So would the samplers need to be reimplemented as well? How random is this? The way the current queue works is

  1. Shuffle subjects
  2. Extract N patches from subjects until the buffer is full
  3. Shuffle the buffer
  4. Pop patches from the buffer
  5. If the buffer is empty, go to 2.

This ensures that batches of patches come from different subjects. Is this also ensured? And would all the subjects be sampled as well? If this works well, we could synchronize it with the release of PyTorch 1.8.
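
For reference, here is a minimal sketch of that buffering logic (my own simplification, not the actual tio.Queue source; the sampler is assumed to be callable on a subject and yield patches, as in the code above):

import random
from itertools import islice

def iterate_queue(subjects, sampler, samples_per_volume, max_length):
    random.shuffle(subjects)                      # 1. shuffle subjects
    subjects_iter = iter(subjects)
    buffer = []
    while True:
        # 2. extract patches from subjects until the buffer is full (or subjects run out)
        while len(buffer) < max_length:
            subject = next(subjects_iter, None)
            if subject is None:
                break
            buffer.extend(islice(sampler(subject), samples_per_volume))
        if not buffer:
            return
        random.shuffle(buffer)                    # 3. shuffle the buffer
        while buffer:
            yield buffer.pop()                    # 4. pop patches from the buffer
        # 5. the buffer is empty: go back to 2.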


What tool do you use to visualize GPU utilization?

@dmus
Contributor Author

dmus commented Dec 22, 2020

The tool for GPU utilization is Grafana. This uses the existing samplers, no need for reimplementation.

How this custom queue works is:

  1. Select a random subject (in each worker)
  2. Extract N patches from the subject (in this implementation also in the worker)
  3. In the main process: fill the buffer with patches until the buffer is full
  4. Pick a random patch from the buffer and replace it with a new patch

> This ensures that batches of patches come from different subjects. Is this also ensured?

Each worker picks a random subject, so in that way patches come from different subjects. The buffer size should be large enough to contain samples from different subjects.

I also tried a map-style Dataset instead of an iterable dataset for the PatchesDataset. Then each subject is sampled once, but I was getting (small) gaps in GPU utilization with this approach.

> And would all the subjects be sampled as well?

In theory, when you sample long enough you should get about the same number of samples from each subject.

@fepegar
Owner

fepegar commented Dec 22, 2020

Nice. Using a map-style dataset would ensure that patches are extracted from each subject once, and would let us define an epoch, right? I'm not sure why that approach would be less optimized than the iterable dataset.

@dmus
Contributor Author

dmus commented Dec 22, 2020

> Nice. Using a map-style dataset would ensure that patches are extracted from each subject once, and would let us define an epoch, right?

Yes, that is right.

> I'm not sure why that approach would be less optimized than the iterable dataset.

I don't know either. Maybe because after each epoch things have to be initialized again, but I'm not sure.
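
For illustration, a map-style variant along those lines might look like this (a hypothetical sketch, not the code that was actually tested; mapping one index to one patch is an assumption):

from torch.utils.data import Dataset


class MapStylePatchesDataset(Dataset):
    """Hypothetical map-style variant: index i maps to one patch of one subject,
    so a single pass over the dataset defines an epoch."""

    def __init__(self, subjects_dataset, sampler, samples_per_volume):
        self.subjects_dataset = subjects_dataset
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume

    def __len__(self):
        return len(self.subjects_dataset) * self.samples_per_volume

    def __getitem__(self, index):
        subject = self.subjects_dataset[index // self.samples_per_volume]
        # Extracts a single patch per __getitem__ call, so the subject is loaded
        # again for every patch; this is more wasteful than the iterable approach,
        # where one loaded subject yields samples_per_volume patches in a row.
        return next(iter(self.sampler(subject)))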

@fepegar
Owner

fepegar commented Dec 22, 2020

The buffered shuffle dataset seems to be available in the nightly version:

$ pip install --pre torch torchvision -f https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
Looking in links: https://download.pytorch.org/whl/nightly/cpu/torch_nightly.html
Collecting torch
  Downloading https://download.pytorch.org/whl/nightly/cpu/torch-1.8.0.dev20201222-cp38-none-macosx_10_9_x86_64.whl (115.5 MB)
Collecting torchvision
  Downloading https://download.pytorch.org/whl/nightly/torchvision-0.9.0.dev20201222-cp38-cp38-macosx_10_9_x86_64.whl (13.1 MB)
Installing collected packages: typing-extensions, torch, pillow, torchvision
Successfully installed pillow-8.0.1 torch-1.8.0.dev20201222 torchvision-0.9.0.dev20201222 typing-extensions-3.7.4.3

$ ipython
Python 3.8.5 (default, Sep  4 2020, 02:22:02)
Type 'copyright', 'credits' or 'license' for more information
IPython 7.19.0 -- An enhanced Interactive Python. Type '?' for help.

In [1]: import torch

In [2]: torch.__version__
Out[2]: '1.8.0.dev20201222'

In [3]: from torch.utils.data.dataset import BufferedShuffleDataset

I won't have much time in the next two months. Would you like to submit a PR with your proposal? I thought that BufferedShuffleQueue might be appropriate, but I'm not sure because the current queue is also buffered and typically shuffled.

@dmus
Contributor Author

dmus commented Dec 22, 2020

I will test it and work it out a bit more, and then submit a PR.

@dmus
Contributor Author

dmus commented Apr 22, 2021

To come back to this issue: the implementation now looks like this:

import random
from itertools import islice

from torch.utils.data import DataLoader, Dataset, IterableDataset


class BatchedPatchesDataset(IterableDataset):

    def __init__(self, subjects_datasets, weights, sampler, samples_per_volume):
        self.subjects_datasets = subjects_datasets
        self.weights = weights
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume

    def __iter__(self):
        while True:
            # Pick a subjects dataset according to the weights, then a random subject from it
            sampled_dataset = random.choices(population=self.subjects_datasets, weights=self.weights)[0]
            idx = random.randint(0, len(sampled_dataset) - 1)
            sample = sampled_dataset[idx]
            iterable = self.sampler(sample)
            patches = list(islice(iterable, self.samples_per_volume))

            yield patches


class UnbatchDataset(IterableDataset):

    def __init__(
            self,
            dataset: Dataset,
            num_workers: int = 0,
    ):
        self.loader = DataLoader(
            dataset,
            batch_size=None,
            num_workers=num_workers,
        )

    def __iter__(self):
        for batch in self.loader:
            yield from batch

To use it:

# This yields a dataset with non-random batches of size `samples_per_volume`
to_patches = BatchedPatchesDataset(
    subjects_datasets=[dataset],
    weights=[1],  # only relevant when sampling from multiple subject datasets
    sampler=sampler,
    samples_per_volume=samples_per_volume,
)
# Unbatch the batches
patches_unbatched = UnbatchDataset(to_patches, num_workers)

# Shuffle to get the patches in a random order
queue = BufferedShuffleDataset(patches_unbatched, max_queue_length)

patches_loader = DataLoader(queue, batch_size=batch_size)

for i, patches_batch in enumerate(patches_loader):
    inputs = patches_batch['ct'][tio.DATA].numpy()
    targets = patches_batch['labels'][tio.DATA].numpy()

For my use case this gives good GPU utilization and a big speed-up. But it would be good to also know about other use cases. Shall I submit a pull request?

@fepegar
Owner

fepegar commented May 11, 2021

Hi, @dmus. This looks interesting. It's nice to be able to leverage newer PyTorch classes. I have some questions:

  1. Using multiple datasets to sample one randomly seems like a personal use case. Is there a way to generalize this to pass just one?
  2. This system samples a random subject from a random dataset, ignoring the traditional concept of epochs. Would it be possible to ensure that patches from all subjects are extracted once and only once at each for loop over the new class?
  3. I find the nomenclature a bit confusing. If I understand correctly:
    a. BatchedPatchesDataset returns lists of instances of Subject, in which each Subject is actually a patch sampled from the original subject.
    b. What does UnbatchDataset return? If batch_size is None, then I guess batch is actually an instance of Subject, which inherits from dict, so iterating over it would just return the keys -> I'm confused and probably got something wrong. Why is this called UnbatchDataset, what is "unbatching" and why would you do that?

In summary, it would be great to make this sample once from each subject at each epoch and to understand a bit what each class is doing.

@dmus
Contributor Author

dmus commented May 15, 2021

  1. Yes, this can be changed to accept both a single dataset and a list of datasets with corresponding weights.
  2. This would be difficult, I think (and for more complex use cases with datasets from multiple centers it is probably not desired; how would you define a traditional epoch there?). At least it would require using map-style datasets instead of iterable datasets, I think.
  3. a. Correct.
    b. It iterates over the list from a. and returns the individual items (a patch sampled from the original subject, not a list anymore). Unbatch does what is described here too (see the toy example below): https://www.tensorflow.org/api_docs/python/tf/data/Dataset#unbatch
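
To make the unbatch semantics concrete, a toy example (made-up values):

from torch.utils.data import IterableDataset


class Batches(IterableDataset):
    def __iter__(self):
        # Each item is already a "batch" (a list), e.g. the patches of one subject
        yield [1, 2, 3]
        yield [4, 5]


class Unbatch(IterableDataset):
    def __init__(self, dataset):
        self.dataset = dataset

    def __iter__(self):
        for batch in self.dataset:
            yield from batch  # flatten the lists into individual items


print(list(Unbatch(Batches())))  # [1, 2, 3, 4, 5]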

@rousseau

As found by dmus, we have also encountered this "issue" (i.e., not making full use of the GPU).

Is there a way to use TorchIO with queues optimized for the GPU?

@romainVala
Contributor

No solution has been incorporated into the current TorchIO, but did you test dmus's solution?
It is worth a try, although it is not easy to choose the best parameters.

@rousseau

I've tried to use dmus's solution. The GPU is fully used, but the training does not end!

To avoid the gaps mentioned by dmus within the current TorchIO, what do you recommend?

  • As large a queue as possible?
  • As many workers as possible?
  • ...?

@fepegar
Owner

fepegar commented Sep 22, 2021

Hi, @rousseau. I do recommend a long queue, multiple workers (sometimes fewer than the maximum is faster; for example, I typically use 12/40 in a DGX after benchmarking), many samples per volume, and fast transforms (avoid RandomElasticDeformation, RandomMotion, maybe RandomBiasField). If you have a lot of storage, you can store your images uncompressed (e.g. .nii instead of .nii.gz) or, even better, preload the preprocessed images into RAM for faster access. Just watch out for the RAM usage; you can use get_max_memory_pretty. But remember that some operations (e.g. the transforms) will also need RAM, so don't use all of it for the queue.
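
For illustration, a queue configured along those lines might look like this (the values are illustrative only, not recommendations for any specific machine, and subjects_dataset is assumed to be an existing tio.SubjectsDataset):

import torch
import torchio as tio

# Illustrative values only; benchmark on your own hardware.
sampler = tio.UniformSampler(patch_size=64)
patches_queue = tio.Queue(
    subjects_dataset,        # assumed: an existing tio.SubjectsDataset
    max_length=400,          # long queue
    samples_per_volume=40,   # many samples per volume
    sampler=sampler,
    num_workers=12,          # sometimes fewer than the maximum is faster
)
# The loader over the queue should use num_workers=0;
# the queue's internal loader already uses the workers.
patches_loader = torch.utils.data.DataLoader(patches_queue, batch_size=8, num_workers=0)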

@ramonemiliani93
Contributor

ramonemiliani93 commented Oct 20, 2021

Building on top of what @dmus suggested and following what @fepegar implemented in the Queue, I have something like this:

import random
from itertools import islice

import torch
import torchio as tio
from torch.utils.data import IterableDataset


class Queue(IterableDataset):
    def __init__(
        self,
        subjects_dataset: tio.SubjectsDataset,
        sampler: tio.data.PatchSampler,
        samples_per_volume: int,
        shuffle_subjects: bool = True,
        shuffle_patches: bool = True,
        buffer_size: int = 0,
    ):
        if shuffle_patches and not buffer_size:
            m = "The `buffer_size` parameter must be defined when shuffling patches."
            raise ValueError(m)

        self.subjects_dataset = subjects_dataset
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume
        self.shuffle_subjects = shuffle_subjects
        self.shuffle_patches = shuffle_patches
        self.buffer_size = buffer_size

    @property
    def num_subjects(self) -> int:
        return len(self.subjects_dataset)

    @property
    def iterations_per_epoch(self) -> int:
        return self.num_subjects * self.samples_per_volume

    def __len__(self):
        return self.iterations_per_epoch

    def __iter__(self):
        worker_info = torch.utils.data.get_worker_info()

        num_workers, worker_id = 1, 0
        if worker_info is not None:
            num_workers, worker_id = worker_info.num_workers, worker_info.id

        # Subject indices to get from the dataset.
        indices = self.split_among_workers(self.num_subjects, num_workers, worker_id)

        # For the buffer size we split the total buffer size among each worker.
        buffer_rng = self.split_among_workers(self.buffer_size, num_workers, worker_id)
        buffer_size = buffer_rng.stop - buffer_rng.start  # Need the size not range
        buffer = []

        if self.shuffle_subjects:
            indices = random.sample(indices, indices.stop - indices.start)

        for index in indices:
            subject = self.subjects_dataset[index]
            patches = islice(self.sampler(subject), self.samples_per_volume)

            if not self.shuffle_patches:
                yield from patches
                continue

            for patch in patches:
                if len(buffer) == buffer_size:
                    idx = random.randint(0, buffer_size - 1)
                    yield buffer[idx]
                    buffer[idx] = patch
                else:
                    buffer.append(patch)

        random.shuffle(buffer)
        while buffer:
            yield buffer.pop()

    @staticmethod
    def split_among_workers(n: int, num_workers: int, worker_id: int) -> range:
        """
        Generates a range of indices up to `n` assigned to worker with id `worker_id`
        from a total pool of `num_workers`.

        For a single worker it will be composed of the full range:
            >>> Queue.split_among_workers(5, 1, 0)
            range(0, 5)

        For multiple workers it will depend on the id assigned to the worker. In case
        the total number is not divisible by the number of workers the remaining m
        values, for m < num_workers, will be assigned to the first m workers:
            >>> Queue.split_among_workers(5, 2, 0)
            range(0, 3)
            >>> Queue.split_among_workers(5, 2, 1)
            range(3, 5)
        """
        if worker_id >= num_workers:
            m = (
                f"The worker id provided `{worker_id}` must be less than the total "
                f"number of workers `{num_workers}`"
            )
            raise ValueError(m)

        per_worker, remaining = divmod(n, num_workers)
        start = per_worker * worker_id + min(remaining, worker_id)
        # Workers with id < remaining take one extra index each
        end = start + per_worker + (1 if worker_id < remaining else 0)

        return range(start, end)

This implementation will ensure each of the subjects is sampled once and only once per epoch, with the given number of samples_per_volume. This follows what is suggested for the IterableDataset class in PyTorch, where the indices are equally split among the workers without having to share objects between processes. It is worth noting that this approach has some caveats:

  1. The indices are split among the workers, hence if by chance one of the workers finishes before the others it will remain idle until the end of the loop. One option would be to define a torch.multiprocessing.SimpleQueue in the constructor and then poll it from each of the workers (this implies reloading the dataset at each epoch end).

  2. The buffer_size is also equally split among the workers and each worker will have an independent buffer. I first tried to use torch.multiprocessing.Manager to create a shared common buffer (manager.list()) between them, but had memory leaks that made it impossible to go past a certain epoch. I also tried wrapping the code without the buffer around a DataLoader and the ShuffleIterDataPipe, but the GPU utilization was not as good.

Any suggestions are welcome, and if anyone manages to try it out and measure their GPU utilization, let me know 🙌 The other approach with torch.multiprocessing.SimpleQueue is something like this (no patch shuffling):

import random
from itertools import islice

import torchio as tio
from torch.multiprocessing import SimpleQueue
from torch.utils.data import IterableDataset


class Queue(IterableDataset):
    def __init__(
        self,
        subjects_dataset: tio.SubjectsDataset,
        sampler: tio.data.PatchSampler,
        samples_per_volume: int,
        shuffle_subjects: bool = True,
    ):
        self.subjects_dataset = subjects_dataset
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume
        self.shuffle_subjects = shuffle_subjects

        indices = range(len(self.subjects_dataset))
        if self.shuffle_subjects:
            indices = random.sample(indices, indices.stop - indices.start)

        self._indices_queue = SimpleQueue()
        for index in indices:
            self._indices_queue.put(index)

    def __iter__(self):
        while not self._indices_queue.empty():
            index = self._indices_queue.get()
            subject = self.subjects_dataset[index]
            yield from islice(self.sampler(subject), self.samples_per_volume)

Note that the Queue needs to be recreated each epoch for the SimpleQueue to be refilled.
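
A usage sketch under that assumption (names such as subjects_dataset, sampler, samples_per_volume, batch_size, num_workers, and num_epochs are placeholders):

from torch.utils.data import DataLoader

for epoch in range(num_epochs):
    # Rebuild the queue so its internal SimpleQueue is refilled with subject indices
    # (assumes the default fork start method so workers share the SimpleQueue)
    queue = Queue(subjects_dataset, sampler, samples_per_volume)
    loader = DataLoader(queue, batch_size=batch_size, num_workers=num_workers)
    for patches_batch in loader:
        ...  # training step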

@dmus
Contributor Author

dmus commented Jan 27, 2022

The implementation has been updated in the meantime to use Torch DataPipes (see https://github.com/pytorch/data):

from itertools import islice

import torchio as tio
from torch.utils.data import DataLoader
from torch.utils.data import IterDataPipe
from torch.utils.data.datapipes.iter import Shuffler, UnBatcher

batch_size = 32
num_workers = 8
samples_per_volume = 20
sampler = tio.UniformSampler(patch_size=64)


class PatchesSampler(IterDataPipe):

    def __init__(self, datapipe, sampler, samples_per_volume):
        self.datapipe = datapipe
        self.sampler = sampler
        self.samples_per_volume = samples_per_volume

    def __iter__(self):
        for subject in self.datapipe:
            iterable = self.sampler(subject)
            # In my experience this turned out to be faster than using
            # `yield from islice(iterable, self.samples_per_volume)` and removing the UnBatcher
            yield list(islice(iterable, self.samples_per_volume))

datapipe = PatchesSampler(my_dataset, sampler, samples_per_volume)
datapipe = DataLoader(datapipe, batch_size=None, num_workers=num_workers)

datapipe = UnBatcher(datapipe)
datapipe = Shuffler(datapipe, buffer_size=batch_size * samples_per_volume)
dataloader = DataLoader(datapipe, batch_size=batch_size)

while True:
    for batch in dataloader:
        ...

@Paddy-Xu

Paddy-Xu commented Sep 2, 2022

> Implementation is updated in the meantime to use the Torch DataPipes (see pytorch/data): [code quoted above]

Hi, thank you very much! This does increase the utilization. However, it seems that I always have to set num_workers=0 in this case. Could you please explain a bit why the first DataLoader's batch_size is None and the buffer size is batch_size * samples_per_volume? Thanks a lot!

@tiago972

tiago972 commented Sep 2, 2022

First of all, thank you for this wrapper package, and to @dmus for addressing this issue, as it has been a major limitation in my case.

However, using the last proposition, if I use num_workers > 0 in datapipe = DataLoader(datapipe, batch_size=None, num_workers=num_workers), this raises:
RuntimeError: unable to mmap 192 bytes from file </torch_2753772_2515325872_1186>: Cannot allocate memory (12)

However, this is not raised with num_workers = 0, or if num_workers > 0 and pin_memory = True, but the latter leads to a memory leak.

This is not a memory issue, as I have access to 255 CPUs and an NVIDIA GPU (NVIDIA-SMI 510.47.03) with 40 GB of memory.

Have you encountered such an issue?

Cheers,
Tiago

-- Edit --
I found a workaround but didn't find the source of the problem. The server I'm using has 255 CPUs, so I used num_workers = 100, which raised the error. However, after trying progressively increasing values, the error was not raised until num_workers reached 35.
I'm still curious about other configurations, as this parameter actually slows down the training: an epoch took 27 minutes with the GPU at 100% activity the whole time when num_workers is set to 20, versus 4 minutes when num_workers is set to 0, but with spikes of GPU activity to 100% for about 5 seconds and around 30 to 45 seconds of 0% activity.

Cheers, Tiago
