Easier way to use Data Processing steps outside of datamodule #1780

Open
nilsleh opened this issue Dec 18, 2023 · 4 comments
Labels
datamodules PyTorch Lightning datamodules

Comments

@nilsleh
Collaborator

nilsleh commented Dec 18, 2023

Summary

Normalization and augmentations are defined in the on_after_batch_transfer() method of the datamodules so that they run on the GPU, as recommended by Lightning. However, a downside of this is that you always have to pass the datamodule into .fit() and .test(). Especially for the latter, it can be convenient to test on separate dataloaders, but those are then just "raw" dataloaders without normalization etc. applied. It took me a minute to discover that this was the reason for my funky test results. Currently, I write a custom collate_fn and set it on the dataloader that I get from a datamodule, but it would be nice if this could be handled more easily. I'm open to thoughts on this, or suggestions for an easier way to handle it than what I am doing at the moment.

Rationale

Sometimes I would like to test a model on different datasets, and if a torchgeo datamodule is available, it is convenient to just retrieve a configured dataloader from the implemented datamodule.

Implementation

Maybe it would be possible to add a flag that returns a dataloader with a collate function based on the on_after_batch_transfer augmentations.
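
A minimal sketch of what such an API could look like (the apply_transforms flag is hypothetical, not an existing torchgeo argument):

# Hypothetical flag -- not part of torchgeo's current API.
datamodule = ETCI2021DataModule(root=".", download=True, batch_size=32)
datamodule.setup("fit")

# The returned loader would wrap on_after_batch_transfer into its collate_fn,
# so batches come out normalized/augmented even outside trainer.fit/.test.
val_dataloader = datamodule.val_dataloader(apply_transforms=True)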

Alternatives

Currently I am doing something like this:

import torch

from torchgeo.datamodules import ETCI2021DataModule

datamodule = ETCI2021DataModule(root=".", download=True, num_workers=4, batch_size=32)
datamodule.setup("fit")


def collate(batch: list[dict[str, torch.Tensor]]) -> dict[str, torch.Tensor]:
    """Collate fn that applies the datamodule's augmentations."""
    # ETCI2021 samples store the target under the "mask" key.
    images = [item["image"] for item in batch]
    masks = [item["mask"] for item in batch]

    inputs = torch.stack(images)
    targets = torch.stack(masks)
    # on_after_batch_transfer expects the batch plus a dataloader index.
    return datamodule.on_after_batch_transfer({"image": inputs, "mask": targets}, 0)


val_dataloader = datamodule.val_dataloader()
val_dataloader.collate_fn = collate
@adamjstewart
Collaborator

I can understand why you would want to be able to use a dataset if a data module doesn't exist, but why would you want to use a dataset if a data module does exist?

@nilsleh
Collaborator Author

nilsleh commented Dec 18, 2023

In order to call trainer.validate(model, dataloaders=datamodule.val_dataloader()) without having to implement my own normalization scheme as a collate fn for every dataloader from a datamodule I want to use. So, for example, say I train one model and want to validate it on a bunch of datasets: then I could pass multiple dataloaders from different datasets or datamodules to trainer.validate().
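
For instance (a sketch; the import path assumes Lightning 2.x, and model plus the hypothetical datamodule_a / datamodule_b are assumed to be set up already with the collate workaround from above):

from lightning.pytorch import Trainer

trainer = Trainer()
# trainer.validate accepts a list of dataloaders, so one model can be
# checked against several datasets in a single call. Each dataloader still
# needs the collate workaround, because the datamodule's
# on_after_batch_transfer hook only runs when the datamodule itself is passed.
trainer.validate(
    model,
    dataloaders=[datamodule_a.val_dataloader(), datamodule_b.val_dataloader()],
)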

@adamjstewart
Collaborator

But why not use trainer.validate(model, datamodule=datamodule) for all data modules?

@nilsleh
Collaborator Author

nilsleh commented Dec 18, 2023

If you pass a datamodule, it will only select the predefined validation loader and validate on that, but maybe I would like to validate on both the train set and the validation set, for example when taking a pre-trained model and checking performance without training. It might also be relevant for something like cross-validation, where you split your train/val sets. In my case, I am trying conformal prediction, where you need to take a subset of the validation set to create a separate calibration set and use the model with that, so you need to control which split's dataloader validation is applied to.
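
A sketch of the kind of split I mean, reusing the collate workaround from above (this assumes val_dataset is the attribute torchgeo's datamodules populate in setup()):

from torch.utils.data import DataLoader, random_split

# Split the validation set into a calibration set and a reduced validation set.
n_cal = len(datamodule.val_dataset) // 2
cal_dataset, val_dataset = random_split(
    datamodule.val_dataset,
    [n_cal, len(datamodule.val_dataset) - n_cal],
)

# Both splits need the same normalization/augmentations as the original loader,
# so reuse the collate fn that calls on_after_batch_transfer.
cal_loader = DataLoader(cal_dataset, batch_size=32, collate_fn=collate)
val_loader = DataLoader(val_dataset, batch_size=32, collate_fn=collate)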

@adamjstewart adamjstewart added the datamodules PyTorch Lightning datamodules label Dec 18, 2023