
Model Initialization Extremely slow #54

Open
dc250601 opened this issue May 21, 2023 · 2 comments

Comments

@dc250601

Is there a way to speed up the model initialisation process? Every time I initialize the model, it takes over 30 minutes before training starts.

@Gabri95
Collaborator

Gabri95 commented Jun 22, 2023

Hi @dc250601

Unfortunately, this can happen for wide models.
It is due to the slow computation of the variances needed for He weight initialization (in generalized_he_init).

To speed this up, these variances can be cached, so that subsequent layers using the same basisexpansion / basissampler / basismanager do not need to recompute them.
This also helps if you train your model multiple times in a row (the variances only need to be computed the first time).

The R3Conv (and R2Conv) constructor calls this method with cached=False by default, so no caching is performed.
However, you can pass initialize=False to skip the automatic initialization entirely and then call generalized_he_init manually with caching enabled, as in the sketch below.
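A minimal sketch of this pattern, assuming escnn's R2Conv API (the gspace, field types, and kernel size are placeholder choices, and the spelling of the caching keyword should be checked against your installed version):

```python
from escnn import gspaces, nn

# Placeholder setup: a C8-equivariant 2D conv with 64 regular fields in and out.
r2_act = gspaces.rot2dOnR2(N=8)
feat_in = nn.FieldType(r2_act, 64 * [r2_act.regular_repr])
feat_out = nn.FieldType(r2_act, 64 * [r2_act.regular_repr])

# Skip the constructor's automatic, uncached He initialization...
conv = nn.R2Conv(feat_in, feat_out, kernel_size=3, initialize=False)

# ...and run it manually with caching enabled, so later layers sharing the
# same basis expansion reuse the precomputed variances instead of recomputing.
nn.init.generalized_he_init(conv.weights.data, conv.basisexpansion, cache=True)
```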

Alternatively, you could also try the delta-orthogonal initialization, which I think is a bit faster.
As before, you'll have to disable the automatic initialization within the conv layers with initialize=False and then call this initialization method manually, as sketched below.
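A sketch of the delta-orthogonal variant, reusing the field types from the snippet above (deltaorthonormal_init is the corresponding initializer in escnn; it may place extra constraints on the choice of representations):

```python
# Again skip the automatic initialization, then apply the
# delta-orthogonal initializer to the layer's weights.
conv = nn.R2Conv(feat_in, feat_out, kernel_size=3, initialize=False)
nn.init.deltaorthonormal_init(conv.weights.data, conv.basisexpansion)
```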

Let me know if these solutions work for you!

Best,
Gabriele

@jacksonloper

jacksonloper commented Dec 3, 2023

So the He initialization is why things are slow if I have a convolution with lots of input channels and output channels? And I suppose that would be true even if I have many duplicates of the same "kind" of channels (i.e. 128 irrep(5) channels as input and 128 irrep(5) channels as output)?

Or is there some other kind of "width" that would explain the slowness? Like maybe by width do you mean kernel size? Basically my question is: what do you mean by "wide" model?
