
Can we use Adam or another optimizer instead of SGD to train the network? #6

Closed
ShadowLau opened this issue Aug 29, 2018 · 3 comments

Comments

@ShadowLau

Hi, I recently used SWA to train my network for a Re-ID task, but I see no obvious improvement (the results are almost the same with and without SWA) when training with Adam.

So, can we use Adam or another optimizer instead of SGD to train the networks, if we want to improve them with SWA?

@izmailovpavel
Collaborator

Hi, sorry for the delayed response. In my experience SWA works best with SGD. Adam sets the learning rates adaptively, which is not ideal for SWA. However, we did see some improvement with other optimizers as well. I recommend tuning the learning rate schedule (try increasing the learning rate during the SWA stage), or perhaps switching to SGD for the SWA stage.
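For reference, a minimal PyTorch sketch of the second suggestion (train with Adam first, then switch to SGD with a relatively high constant learning rate for the SWA stage). It uses `torch.optim.swa_utils`, which is available in recent PyTorch releases rather than the code in this repo; the toy model, data, learning rates, and epoch budgets are placeholder assumptions, not a prescription.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.swa_utils import AveragedModel, SWALR, update_bn

# Toy model and data so the sketch is self-contained; swap in your own.
model = nn.Sequential(nn.Linear(16, 32), nn.BatchNorm1d(32), nn.ReLU(), nn.Linear(32, 4))
data = TensorDataset(torch.randn(256, 16), torch.randint(0, 4, (256,)))
loader = DataLoader(data, batch_size=32, shuffle=True)
loss_fn = nn.CrossEntropyLoss()

# Stage 1: regular training with Adam.
adam = torch.optim.Adam(model.parameters(), lr=1e-3)
for epoch in range(10):                      # hypothetical budget
    for x, y in loader:
        adam.zero_grad()
        loss_fn(model(x), y).backward()
        adam.step()

# Stage 2: switch to SGD with a higher constant LR and start averaging weights.
sgd = torch.optim.SGD(model.parameters(), lr=0.05, momentum=0.9)
swa_model = AveragedModel(model)
swa_scheduler = SWALR(sgd, swa_lr=0.05)      # holds the SWA learning rate
for epoch in range(5):                       # hypothetical SWA budget
    for x, y in loader:
        sgd.zero_grad()
        loss_fn(model(x), y).backward()
        sgd.step()
    swa_model.update_parameters(model)       # running average of the weights
    swa_scheduler.step()

# Recompute BatchNorm statistics for the averaged weights before evaluation.
update_bn(loader, swa_model)
```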

@mrgloom

mrgloom commented Oct 30, 2020

As far as I can see, Adam in TensorFlow has trainable parameters, so the question is: should we exclude these parameters from the averaging? The same question applies to the BN trainable parameters.

@izmailovpavel
Collaborator

Hey @mrgloom. The Adam parameters and BN statistics are not trainable parameters of the network. In fact, the former are tensors stored in the optimizer state, and the latter are buffers of the model. They should not be averaged. However, you do need to recompute the batch norm statistics for the SWA model at the end of training (https://pytorch.org/blog/stochastic-weight-averaging-in-pytorch/#batch-normalization).
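To make the distinction concrete, here is a small PyTorch sketch (the TensorFlow situation is analogous): Adam's moment estimates live in the optimizer state, the BN running statistics live in the model's buffers, and only the actual parameters enter the average; the BN buffers are then refreshed with `update_bn` from the linked post. The toy model and data are assumptions for illustration only.

```python
import torch
import torch.nn as nn
from torch.utils.data import DataLoader, TensorDataset
from torch.optim.swa_utils import AveragedModel, update_bn

model = nn.Sequential(nn.Linear(8, 16), nn.BatchNorm1d(16), nn.ReLU(), nn.Linear(16, 2))
opt = torch.optim.Adam(model.parameters(), lr=1e-3)

# Trainable parameters: these are what SWA averages.
print([name for name, _ in model.named_parameters()])
# BN running_mean / running_var are buffers, not parameters.
print([name for name, _ in model.named_buffers()])

# Adam's moment estimates live in the optimizer state, not in the model.
x, y = torch.randn(32, 8), torch.randint(0, 2, (32,))
nn.functional.cross_entropy(model(x), y).backward()
opt.step()
print(list(opt.state[next(model.parameters())].keys()))  # includes 'exp_avg', 'exp_avg_sq'

# AveragedModel averages only the parameters; the BN buffers are not averaged,
# so the running statistics are recomputed with one pass over the data.
swa_model = AveragedModel(model)
loader = DataLoader(TensorDataset(torch.randn(256, 8), torch.randint(0, 2, (256,))), batch_size=32)
update_bn(loader, swa_model)
```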
