add DataLoader related parameters to fit() and predict() #2295

Open · wants to merge 5 commits into master
Conversation

BohdanBilonoh
Contributor

Checklist before merging this PR:

  • Mentioned all issues that this PR fixes or addresses.
  • Summarized the updates of this PR under Summary.
  • Added an entry under Unreleased in the Changelog.

Summary

Add torch.utils.data.DataLoader related parameters to fit() and predict() of TorchForecastingModel
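
For illustration, a hypothetical call once such parameters exist. The dict-style data_loader_kwargs shown here follows the naming used later in this PR's diff; the released API may differ:

from darts.datasets import AirPassengersDataset
from darts.models import NBEATSModel

series = AirPassengersDataset().load()

model = NBEATSModel(input_chunk_length=24, output_chunk_length=12, n_epochs=1)

# forward extra torch.utils.data.DataLoader options to training and prediction
model.fit(series, data_loader_kwargs={"num_workers": 2, "pin_memory": True})
pred = model.predict(n=12, data_loader_kwargs={"num_workers": 2})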


codecov bot commented Apr 9, 2024

Codecov Report

All modified and coverable lines are covered by tests ✅

Project coverage is 93.74%. Comparing base (a0cc279) to head (ec7a01b).

Additional details and impacted files
@@            Coverage Diff             @@
##           master    #2295      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.01%     
==========================================
  Files         138      138              
  Lines       14352    14346       -6     
==========================================
- Hits        13456    13449       -7     
- Misses        896      897       +1     


@madtoinou
Collaborator

madtoinou commented May 6, 2024

Hi @BohdanBilonoh,

It looks great. However, to make it easier to maintain and more exhaustive, I think it would be better to just add a single argument called dataloader_kwargs, check that the arguments explicitly used by Darts are not redundant or overwritten, and then pass this argument down to the DataLoader constructor.

It will allow users to specify more than just prefetch_factor, persistent_workers and pin_memory, while limiting copy-pasting from other libraries' documentation (putting a link to the torch DataLoader page does sound like a good idea for this argument, however).
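
A minimal sketch of that pattern (hypothetical helper and reserved-key set, for illustration only; not the actual Darts implementation):

from typing import Any, Dict, Optional

from torch.utils.data import DataLoader

# keys that Darts sets itself; user-supplied kwargs must not override them (hypothetical set)
RESERVED_KEYS = {"batch_size", "shuffle", "drop_last", "collate_fn"}


def build_train_loader(train_dataset, batch_size, collate_fn,
                       dataloader_kwargs: Optional[Dict[str, Any]] = None) -> DataLoader:
    dataloader_kwargs = dict(dataloader_kwargs or {})
    clashing = RESERVED_KEYS & dataloader_kwargs.keys()
    if clashing:
        raise ValueError(
            f"These DataLoader arguments are set by Darts and cannot be overridden: {sorted(clashing)}"
        )
    return DataLoader(
        train_dataset,
        batch_size=batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate_fn,
        **dataloader_kwargs,  # e.g. num_workers, pin_memory, prefetch_factor, ...
    )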

PS: Apologies for taking so long with the review of this PR.

@joshua-xia

Hi @BohdanBilonoh, would you please add the multiprocessing_context parameter for DataLoader? It is useful when using multiple workers for the dataloader. Thanks!

@joshua-xia

@BohdanBilonoh refer to #2375

@joshua-xia

joshua-xia commented May 7, 2024

@BohdanBilonoh My bad, it is a good idea from @madtoinou to add dataloader_kwargs so that users can freely pass whatever DataLoader parameters they wish; there is no need to explicitly support a multiprocessing_context parameter.
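
With the generic kwargs approach, multiprocessing_context can simply be forwarded. A hypothetical example (reusing the model and series from the summary example above; the parameter name follows this PR's diff and may differ in the released API):

import multiprocessing as mp

model.fit(
    series,
    data_loader_kwargs={
        "num_workers": 4,
        # any valid torch DataLoader argument can be forwarded, including the context
        "multiprocessing_context": mp.get_context("spawn"),
    },
)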

@BohdanBilonoh
Contributor Author

BohdanBilonoh commented May 28, 2024

@madtoinou what do you think about hardcoded parameters like

batch_size=self.batch_size,
shuffle=True,
drop_last=False,
collate_fn=self._batch_collate_fn,

Should they stay hard-coded on top of the new dataloader_kwargs, like this:

def _setup_for_train(
    self,
    train_dataset: TrainingDataset,
    val_dataset: Optional[TrainingDataset] = None,
    trainer: Optional[pl.Trainer] = None,
    verbose: Optional[bool] = None,
    epochs: int = 0,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:

        ...

        if dataloader_kwargs is None:
            dataloader_kwargs = {}

        dataloader_kwargs["shuffle"] = True
        dataloader_kwargs["batch_size"] = self.batch_size
        dataloader_kwargs["drop_last"] = False
        dataloader_kwargs["collate_fn"] = self._batch_collate_fn

        # Setting drop_last to False makes the model see each sample at least once, and guarantees the
        # presence of at least one batch no matter the chosen batch size
        train_loader = DataLoader(
            train_dataset,
            **dataloader_kwargs,
        )

        dataloader_kwargs["shuffle"] = False

        # Prepare validation data
        val_loader = (
            None
            if val_dataset is None
            else DataLoader(
                val_dataset,
                **dataloader_kwargs,
            )
        )

        ...

or should the user get full control over dataloader_kwargs?

@tRosenflanz
Contributor

(quoting @BohdanBilonoh's question above about hard-coding the DataLoader defaults vs. giving users full control over dataloader_kwargs)

You could extend your suggestion to allow overrides while still populating the defaults:

defaults = dict(shuffle=True, batch_size=self.batch_size, drop_last=False, collate_fn=self._batch_collate_fn)
# combine with the defaults, letting user-supplied kwargs override them
dataloader_kwargs_train = {**defaults, **dataloader_kwargs}
# force shuffle off for the validation loader
dataloader_kwargs_val = {**dataloader_kwargs_train, **dict(shuffle=False)}
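
Because later entries win in Python's dict unpacking, anything the user passes overrides the defaults, while shuffle is still forced off for validation since it is merged last. A minimal standalone demonstration (values made up for illustration):

defaults = dict(shuffle=True, batch_size=32, drop_last=False)
user_kwargs = dict(num_workers=4, batch_size=64)

# the user's batch_size wins over the default, extra keys are simply added
train_kwargs = {**defaults, **user_kwargs}
# -> {'shuffle': True, 'batch_size': 64, 'drop_last': False, 'num_workers': 4}

# shuffle is forced off for validation regardless of what the user passed
val_kwargs = {**train_kwargs, **dict(shuffle=False)}
# -> {'shuffle': False, 'batch_size': 64, 'drop_last': False, 'num_workers': 4}

print(train_kwargs)
print(val_kwargs)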

Bohdan Bilonoh and others added 2 commits May 31, 2024 13:02
Collaborator

@dennisbader dennisbader left a comment


Thanks for the updates @BohdanBilonoh, I took the chance to fix some indentation issues in one of the docstrings and pushed the changes.

If we give the freedom to overwrite our default dataloader params when calling fit(), shouldn't we then also allow that during predict()?

Also, removing the num_loader_workers parameter is a breaking change. Can you document this in the CHANGELOG.md?

@@ -1487,14 +1484,17 @@ def predict_from_dataset(
mc_dropout=mc_dropout,
)

if data_loader_kwargs is None:
Collaborator


Shouldn't we then also allow here the liberty to overwrite these defaults (except shuffle)?
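
A minimal sketch of how that could look for the prediction loader (hypothetical helper; variable names are assumptions, not the actual diff):

from typing import Any, Dict, Optional

from torch.utils.data import DataLoader


def build_pred_loader(dataset, batch_size, collate_fn,
                      data_loader_kwargs: Optional[Dict[str, Any]] = None) -> DataLoader:
    defaults = dict(batch_size=batch_size, drop_last=False, collate_fn=collate_fn)
    # user-supplied kwargs override the defaults ...
    merged = {**defaults, **(data_loader_kwargs or {})}
    # ... except shuffle, which stays off so predictions keep the order of the input samples
    merged["shuffle"] = False
    return DataLoader(dataset, **merged)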

A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.
data_loader_kwargs
Optionally, a dictionary of keyword arguments to pass to the PyTorch DataLoader instances used to load the
Collaborator


It's referring to the train and val datasets, but it should refer to the prediction dataset; same for the methods below.
