make_train_dataloader discards custom collate function passed as kwarg #1531

Open
niniack opened this issue Nov 7, 2023 · 5 comments
Labels
bug Something isn't working

Comments

@niniack
Contributor

niniack commented Nov 7, 2023

Describe the bug
Calling

cl_strategy.train(
    experience,
    eval_streams=[val_exp],
    num_workers=4,
    collate_fn=my_custom_collate,
)

should respect all of the keyword arguments I pass in. In this case, my_custom_collate is discarded.

To Reproduce
For debugging, I define a custom strategy to examine what is passed to the dataloader. The make_train_dataloader function is copied verbatim from the 0.4.0 implementation.

(Please note that I set breakpoints with pdb)

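from avalanche.benchmarks.utils.data_loader import TaskBalancedDataLoader
from avalanche.training import Naive


class CustomNaiveStrategy(Naive):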
    def make_train_dataloader(
        self,
        num_workers=0,
        shuffle=True,
        pin_memory=None,
        persistent_workers=False,
        drop_last=False,
        **kwargs
    ):
        assert self.adapted_dataset is not None

        # fmt:off
        import pdb; pdb.set_trace()
        # fmt:on

        other_dataloader_args = self._obtain_common_dataloader_parameters(
            batch_size=self.train_mb_size,
            num_workers=num_workers,
            shuffle=shuffle,
            pin_memory=pin_memory,
            persistent_workers=persistent_workers,
            drop_last=drop_last,
        )

        # fmt:off
        pdb.set_trace()
        # fmt:on

        if "ffcv_args" in kwargs:
            other_dataloader_args["ffcv_args"] = kwargs["ffcv_args"]

        self.dataloader = TaskBalancedDataLoader(
            self.adapted_dataset, oversample_small_groups=True, **other_dataloader_args
        )

Expected behavior

other_dataloader_args should respect the kwargs and pass my_custom_collate along.

Screenshots
[Screenshot: pdb session]

In the screenshot above, p kwargs shows the custom collate function, but it does not show up in other_dataloader_args, which is what gets passed on to TaskBalancedDataLoader.

Additional context

I cannot immediately think of why something like other_dataloader_args.update(kwargs) is a poor idea; I would love to hear thoughts.
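
One possible concern with a blanket update is that every key in kwargs would reach the dataloader, including keys it may not accept. Below is a minimal sketch using plain dicts, with no Avalanche imports. The names merge_dataloader_kwargs, FORWARDED_KEYS, and unrelated_option are hypothetical and only illustrate the difference between a blanket merge and forwarding an explicit set of keys.

```python
from typing import Any, Dict, Iterable

# Hypothetical allow-list; which keys Avalanche should actually forward
# is an assumption made for this sketch.
FORWARDED_KEYS = ("collate_fn", "ffcv_args")


def merge_dataloader_kwargs(
    common_args: Dict[str, Any],
    kwargs: Dict[str, Any],
    forwarded: Iterable[str] = FORWARDED_KEYS,
) -> Dict[str, Any]:
    """Return a copy of common_args plus only the allow-listed kwargs."""
    merged = dict(common_args)  # copy, so the caller's dict is untouched
    for key in forwarded:
        if key in kwargs:
            merged[key] = kwargs[key]
    return merged


def my_custom_collate(batch):
    """Placeholder collate function for the example."""
    return batch


common = {"batch_size": 32, "num_workers": 4, "shuffle": True}
extra = {"collate_fn": my_custom_collate, "unrelated_option": 123}

# Blanket merge: everything in kwargs ends up in the dataloader arguments.
blanket = {**common, **extra}

# Filtered merge: only the allow-listed keys are forwarded.
filtered = merge_dataloader_kwargs(common, extra)
```

Here blanket contains unrelated_option, while filtered carries collate_fn only. Whether the stricter behavior is preferable depends on how the loader treats unknown arguments.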

@niniack niniack added the bug Something isn't working label Nov 7, 2023
@niniack niniack changed the title Make train dataloader discards my collate function make_train_dataloader discards custom collate function passed as kwarg Nov 7, 2023
@niniack
Contributor Author

niniack commented Nov 7, 2023

This was supposedly fixed in #1089, or at least this is mentioned there:

Dataloading in strategies now checks if the dataset has a "collate_fn" function and uses that unless one is specified through kwargs (which takes precedence).

But my experience above doesn't match that. Either way, #1089 seems relevant to the conversation.
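
For reference, the precedence rule quoted from #1089 can be sketched in a few lines. This is only an illustration of the described behavior, not Avalanche's actual code: resolve_collate_fn and DatasetWithCollate are hypothetical names, and the dataset is assumed to expose its collate function as an attribute called collate_fn.

```python
from typing import Any, Callable, Optional


def resolve_collate_fn(dataset: Any, **kwargs: Any) -> Optional[Callable]:
    """Pick the collate function: an explicit kwarg wins over the dataset's."""
    explicit = kwargs.get("collate_fn")
    if explicit is not None:
        return explicit
    # Fall back to the dataset's own collate function, if it has one.
    return getattr(dataset, "collate_fn", None)


class DatasetWithCollate:
    """Stand-in dataset that carries its own collate function."""

    def collate_fn(self, batch):
        return ("dataset", batch)


def my_custom_collate(batch):
    return ("kwarg", batch)


dataset = DatasetWithCollate()

# The kwarg is given, so it takes precedence.
chosen = resolve_collate_fn(dataset, collate_fn=my_custom_collate)

# No kwarg, so the dataset's collate function is used.
fallback = resolve_collate_fn(dataset)
```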

@AntonioCarta
Collaborator

This is definitely a bug. Can you submit a PR that properly adds collate_fn to other_dataloader_args? This should be the only needed change.

@niniack
Contributor Author

niniack commented Nov 8, 2023

Should this fix be done by updating _obtain_common_dataloader_parameters? Or is there another Avalanche-style way of doing this? The hotfix of other_dataloader_args.update(kwargs) doesn't seem very Avalanche-y (but maybe I'm wrong!).

I will also write a test to check whether kwarg collate takes precedence over dataset collate.

Feel free to assign this to me, thanks.
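
A rough shape for the precedence test, assuming the argument-building step can be exercised in isolation. build_loader_args is a hypothetical stand-in for the strategy code under test, not an Avalanche function; a real test would call the strategy's own dataloader construction instead.

```python
import unittest


def build_loader_args(dataset, **kwargs):
    """Hypothetical stand-in for the code that assembles dataloader arguments."""
    args = {"batch_size": kwargs.get("batch_size", 1)}
    # A kwarg collate function takes precedence over the dataset's.
    collate = kwargs.get("collate_fn", getattr(dataset, "collate_fn", None))
    if collate is not None:
        args["collate_fn"] = collate
    return args


def dataset_collate(batch):
    return batch


def kwarg_collate(batch):
    return batch


class FakeDataset:
    """Minimal dataset stand-in exposing a collate_fn attribute."""

    collate_fn = staticmethod(dataset_collate)


class CollatePrecedenceTest(unittest.TestCase):
    def test_kwarg_overrides_dataset(self):
        args = build_loader_args(FakeDataset(), collate_fn=kwarg_collate)
        self.assertIs(args["collate_fn"], kwarg_collate)

    def test_dataset_used_without_kwarg(self):
        args = build_loader_args(FakeDataset())
        self.assertIs(args["collate_fn"], dataset_collate)

    def test_no_collate_when_neither_given(self):
        args = build_loader_args(object())
        self.assertNotIn("collate_fn", args)
```

The cases can be run with python -m unittest against the module that contains them.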

@AntonioCarta
Collaborator

I think updating _obtain_common_dataloader_parameters is the best way.

@lrzpellegrini
Collaborator

Was this fixed in the meantime?
