add `DataLoader` related parameters to `fit()` and `predict()` #2295

base: master
Conversation
Force-pushed from b1ec50c to 3503b41.
Codecov Report
All modified and coverable lines are covered by tests ✅

Additional details and impacted files:

```diff
@@            Coverage Diff             @@
##           master    #2295      +/-   ##
==========================================
- Coverage   93.75%   93.74%   -0.01%
==========================================
  Files         138      138
  Lines       14352    14346       -6
==========================================
- Hits        13456    13449       -7
- Misses        896      897       +1
```

☔ View full report in Codecov by Sentry.
Hi @BohdanBilonoh, it looks great. However, to make it easier to maintain and more exhaustive, I think it would be better to just add a single argument called `dataloader_kwargs`. It will allow users to specify more than just `num_loader_workers`.

PS: Apologies for taking so long with the review of this PR.
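A hypothetical usage sketch of the proposed argument (the argument name follows this thread; the model, the series, and the specific DataLoader options shown are illustrative assumptions, not the merged API):

```python
# Hypothetical usage of the proposed `dataloader_kwargs` argument;
# all DataLoader options shown are standard torch DataLoader kwargs.
model.fit(
    series=train_series,
    dataloader_kwargs={
        "num_workers": 4,           # parallel data-loading workers
        "pin_memory": True,         # can speed up host-to-GPU transfers
        "persistent_workers": True, # keep workers alive between epochs
    },
)
```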
Hi @BohdanBilonoh, would you please add the `multiprocessing_context` parameter for the DataLoader? It is useful when using multiple workers for the dataloader. Thanks!
@BohdanBilonoh refer to #2375
@BohdanBilonoh My bad, @madtoinou's idea of adding `dataloader_kwargs` is the better one: it lets users pass any DataLoader parameters they wish, so there is no need to support a special `multiprocessing_context` parameter.
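For illustration, with the generic kwargs argument `multiprocessing_context` needs no special-casing (a hedged sketch; argument name as proposed in this thread, model and series assumed):

```python
# Sketch: passing multiprocessing_context through the generic kwargs argument.
# "spawn" is a standard start method accepted by torch's DataLoader.
model.fit(
    series=train_series,
    dataloader_kwargs={
        "num_workers": 4,
        "multiprocessing_context": "spawn",
    },
)
```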
…predict()` of `TorchForecastingModel`
Force-pushed from 719b7ab to 3a492df.
@madtoinou what do you think about hardcoded parameters like

```python
batch_size=self.batch_size,
shuffle=True,
drop_last=False,
collate_fn=self._batch_collate_fn,
```

Should they stay hardcoded in the new method below, or should the user get full control over them?

```python
def _setup_for_train(
    self,
    train_dataset: TrainingDataset,
    val_dataset: Optional[TrainingDataset] = None,
    trainer: Optional[pl.Trainer] = None,
    verbose: Optional[bool] = None,
    epochs: int = 0,
    dataloader_kwargs: Optional[Dict[str, Any]] = None,
) -> Tuple[pl.Trainer, PLForecastingModule, DataLoader, Optional[DataLoader]]:
    ...
    if dataloader_kwargs is None:
        dataloader_kwargs = {}
    # These assignments overwrite anything the caller passed in.
    dataloader_kwargs["shuffle"] = True
    dataloader_kwargs["batch_size"] = self.batch_size
    dataloader_kwargs["drop_last"] = False
    dataloader_kwargs["collate_fn"] = self._batch_collate_fn
    # Setting drop_last to False makes the model see each sample at least once, and
    # guarantees the presence of at least one batch no matter the chosen batch size.
    train_loader = DataLoader(
        train_dataset,
        **dataloader_kwargs,
    )
    # Prepare validation data (never shuffled)
    dataloader_kwargs["shuffle"] = False
    val_loader = (
        None
        if val_dataset is None
        else DataLoader(
            val_dataset,
            **dataloader_kwargs,
        )
    )
    ...
```
You could extend your suggestion to allow overrides, but with populated defaults.
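A minimal sketch of that pattern, reusing the names from the snippet above (hypothetical, not the merged implementation): populate the defaults first, let user-supplied kwargs override them, and still force the options that must not change:

```python
# Hypothetical sketch of "populated defaults with user overrides".
default_kwargs = {
    "batch_size": self.batch_size,
    "shuffle": True,
    "drop_last": False,
    "collate_fn": self._batch_collate_fn,
}
# User values win over the defaults; building a new dict also avoids
# mutating the caller's dictionary in place.
train_kwargs = {**default_kwargs, **(dataloader_kwargs or {})}
train_loader = DataLoader(train_dataset, **train_kwargs)

# Validation data should never be shuffled, so force it after the merge.
val_kwargs = {**train_kwargs, "shuffle": False}
val_loader = (
    None if val_dataset is None else DataLoader(val_dataset, **val_kwargs)
)
```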
- add predefined defaults
Thanks for the updates @BohdanBilonoh, I took the chance to fix some indentation issues in one of the docstrings and pushed the changes.

If we give the freedom to overwrite our default dataloader params when calling fit(), shouldn't we then also allow that during predict()?

Also, removing the num_loader_workers parameter is a breaking change. Can you document this in CHANGELOG.md?
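For symmetry, the predict-side call could then look like this (hypothetical usage; the exact parameter name varied between `dataloader_kwargs` and `data_loader_kwargs` during review):

```python
# Hypothetical: the same override freedom on the predict side.
forecast = model.predict(
    n=24,
    dataloader_kwargs={"num_workers": 2, "pin_memory": True},
)
```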
```diff
@@ -1487,14 +1484,17 @@ def predict_from_dataset(
             mc_dropout=mc_dropout,
         )

         if data_loader_kwargs is None:
```
Shouldn't we then also allow the liberty to overwrite these defaults here as well (except shuffle)?
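A hedged sketch of what that could look like in the predict path (variable names are assumptions, not the actual darts code; the parameter name follows the diff above):

```python
# Hypothetical predict-path variant: user kwargs override the defaults,
# but shuffle is always forced off for inference.
default_kwargs = {
    "batch_size": batch_size,
    "drop_last": False,
    "collate_fn": self._batch_collate_fn,
}
kwargs = {**default_kwargs, **(data_loader_kwargs or {}), "shuffle": False}
pred_loader = DataLoader(dataset, **kwargs)
```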
```
A larger number of workers can sometimes increase performance, but can also incur extra overheads
and increase memory usage, as more batches are loaded in parallel.

data_loader_kwargs
    Optionally, a dictionary of keyword arguments to pass to the PyTorch DataLoader instances used to load the
```
It's referring to the train and val datasets, but it should refer to the prediction dataset; same for the methods below.
Checklist before merging this PR:
Summary

Add `torch.utils.data.DataLoader` related parameters to `fit()` and `predict()` of `TorchForecastingModel`.