
Providing a custom collate_fn to DataLoader has no effect #9263

Open
mayabechlerspeicher opened this issue Apr 30, 2024 · 7 comments

@mayabechlerspeicher

🐛 Describe the bug

PyG's DataLoader can receive a custom collate_fn since it extends the torch DataLoader, but its constructor does not use the given collate_fn; instead, it always uses its own Collater.
I'm not sure whether this is a bug or the documentation is wrong: the PyG documentation states that any parameter of torch's DataLoader can also be used with PyG's DataLoader, yet the collate_fn parameter has no effect.
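
For example, the following silently falls back to Collater (MUTAG is used just as an arbitrary example dataset; my_collate is only for illustration):

```python
from torch_geometric.datasets import TUDataset
from torch_geometric.loader import DataLoader

def my_collate(batch):
    # custom collate that simply keeps the graphs in a Python list
    return list(batch)

dataset = TUDataset(root='/tmp/MUTAG', name='MUTAG')
loader = DataLoader(dataset, batch_size=4, collate_fn=my_collate)

batch = next(iter(loader))
print(type(batch))  # expected: list, actual: torch_geometric.data.Batch
```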

So, to actually use a custom collate_fn, do I have to extend DataLoader so that it uses the given collate_fn?

Thanks.

Versions

2.5.3

@rusty1s
Member

rusty1s commented May 2, 2024

Yes, PyG's DataLoader is just a wrapper around torch.utils.data.DataLoader with a custom collate_fn. As such, collate_fn is the only argument that cannot be overridden. Let me clarify this in the documentation.
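
For reference, the constructor does roughly the following (paraphrased for illustration, not the exact source):

```python
import torch
from torch_geometric.loader.dataloader import Collater

class DataLoader(torch.utils.data.DataLoader):
    def __init__(self, dataset, batch_size=1, shuffle=False,
                 follow_batch=None, exclude_keys=None, **kwargs):
        # Any user-supplied collate_fn is dropped here:
        kwargs.pop('collate_fn', None)
        super().__init__(
            dataset, batch_size, shuffle,
            collate_fn=Collater(dataset, follow_batch, exclude_keys),
            **kwargs,
        )
```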

@mayabechlerspeicher
Author

Thank you.
Could you please clarify why it is restricted from being overridden?
I believe it is a typical use case for a Data object to have custom keys that one wants to batch differently than by concatenation (e.g., when their dimensions do not allow concatenation).

@rusty1s
Member

rusty1s commented May 3, 2024

If we allowed overriding collate_fn in PyG's DataLoader, it would boil down to a plain torch.utils.data.DataLoader. In that case, I don't see a good reason not to use the vanilla PyTorch DataLoader in the first place.

Note that you can also customize concatenation behavior by overriding Data.__cat_dim__ (see the advanced mini-batching tutorial in our documentation).
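
A minimal sketch (foo is a placeholder attribute name):

```python
from torch_geometric.data import Data

class MyData(Data):
    def __cat_dim__(self, key, value, *args, **kwargs):
        if key == 'foo':
            # Returning None stacks `foo` along a new batch dimension
            # instead of concatenating along an existing one.
            return None
        return super().__cat_dim__(key, value, *args, **kwargs)
```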

@mayabechlerspeicher
Author

mayabechlerspeicher commented May 12, 2024

Thanks. Nonetheless, the standard DataLoader fails to add a batch dimension to the edge index, since the number of edges differs between graphs.

So suppose I am not interested in batching the edge indices into one huge graph; I just want to wrap multiple graphs together, i.e., stack the keys of the graphs in the batch, even though the tensors of a given key can have different shapes (as with edge indices). The gradient computation would then be done on the loss over the whole batch, but the forward pass would be done on each graph in the batch separately anyway (so GPU-wise it's not the most efficient it could be, but that's OK).
Because the tensors do not have the same dimensions, you cannot concatenate them, so Data.__cat_dim__ would not help. What should I do in that case?
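
Concretely, what I am after is something like this with the vanilla loader (dataset, model, and loss_fn are placeholders):

```python
import torch
from torch.utils.data import DataLoader

def list_collate(batch):
    # keep each graph as a separate Data object instead of merging them
    return list(batch)

loader = DataLoader(dataset, batch_size=32, collate_fn=list_collate)

for batch in loader:
    # forward pass per graph, one backward pass over the mean loss
    losses = [loss_fn(model(data), data.y) for data in batch]
    torch.stack(losses).mean().backward()
```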

@rusty1s
Member

rusty1s commented May 13, 2024

Do you mean you simply want to "batch" tensors together by stacking them in a list? I am not yet sure I understand, sorry.

@mayabechlerspeicher
Author

Yes. I have some custom keys in my Data object that have different dimensions, so I cannot stack them; I just want to put them in a list.

@rusty1s
Member

rusty1s commented May 22, 2024

I see, that's indeed currently not possible. What we could do is provide an option in Data to restrict concatenation of certain attributes. Would this work for your use case?
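
In the meantime, a workaround would be a vanilla torch DataLoader whose collate_fn strips those attributes, batches the rest as usual, and re-attaches the stripped values as Python lists (a sketch; my_key and dataset are placeholders):

```python
from torch.utils.data import DataLoader
from torch_geometric.data import Batch

LIST_KEYS = ['my_key']  # placeholder: attributes that cannot be concatenated

def collate_with_lists(data_list):
    # pull the problematic attributes out before batching
    extras = {k: [data[k] for data in data_list] for k in LIST_KEYS}
    stripped = []
    for data in data_list:
        data = data.clone()
        for k in LIST_KEYS:
            del data[k]
        stripped.append(data)
    batch = Batch.from_data_list(stripped)  # regular PyG batching for the rest
    for k, v in extras.items():
        batch[k] = v  # re-attach as plain Python lists
    return batch

loader = DataLoader(dataset, batch_size=32, collate_fn=collate_with_lists)
```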
