Running stats with gradient checkpointing #1035

Open
vovaf709 opened this issue Jul 20, 2022 · 8 comments
@vovaf709

vovaf709 commented Jul 20, 2022

According to the patch_batchnorm source code, if a layer that collects running stats (e.g. BatchNorm) is checkpointed, it accumulates statistics only when grad is enabled (i.e. during the recomputation on the backward pass). This introduces an inconsistency:

import torch
from torch import nn
from fairscale.nn.checkpoint import checkpoint_wrapper

torch.manual_seed(1337)
seq = nn.Sequential(nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4))
torch.manual_seed(1337)
seq_checkpointed = checkpoint_wrapper(nn.Sequential(nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4)))

inp = torch.randn(2, 4, 16, 16)

# forward pass only, no backward
seq(inp)
seq_checkpointed(inp)

# the plain BatchNorm has updated its running stats, the checkpointed one has not
print(seq[1].running_mean == seq_checkpointed[1].running_mean)
# tensor([False, False, False, False])

I think this behaviour should be changed so that statistics are accumulated on the first forward pass, or at least mentioned in the docs.
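
To make the current behaviour concrete, here is a rough sketch of the kind of patching described above. This is just an illustration, not the actual patch_batchnorm implementation, and the helper name is made up: the BatchNorm forward is wrapped so that the running-stat update is skipped whenever grad is disabled (the checkpoint's first, no-grad forward) and only happens during the grad-enabled recomputation on the backward pass.

import functools
import torch
from torch import nn

def patch_bn_skip_stats_when_no_grad(bn):
    # Illustration only: wrap a _BatchNorm forward so running stats are
    # not updated while autograd is disabled (the checkpoint's outer,
    # no-grad forward). The stats then only change during the grad-enabled
    # recomputation that happens on the backward pass.
    assert isinstance(bn, nn.modules.batchnorm._BatchNorm)
    orig_forward = bn.forward

    @functools.wraps(orig_forward)
    def new_forward(x):
        if not torch.is_grad_enabled():
            saved = bn.track_running_stats
            bn.track_running_stats = False
            try:
                return orig_forward(x)
            finally:
                bn.track_running_stats = saved
        return orig_forward(x)

    bn.forward = new_forward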

vovaf709 changed the title from "Runnings stats with gradient checkpointing" to "Running stats with gradient checkpointing" on Jul 20, 2022
@min-xu-ai
Contributor

Thanks for reporting. I slightly modified your code to demonstrate how it works:

import torch
from torch import nn
from fairscale.nn.checkpoint import checkpoint_wrapper

torch.manual_seed(1337)
seq = nn.Sequential(nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4))
torch.manual_seed(1337)
seq_checkpointed = checkpoint_wrapper(nn.Sequential(nn.Conv2d(4, 4, 3), nn.BatchNorm2d(4)))

inp = torch.randn(2, 4, 16, 16).requires_grad_(True)

out = seq(inp)
out_ck = seq_checkpointed(inp)
torch.testing.assert_close(out, out_ck)

out.sum().backward()
out_ck.sum().backward()


print(seq[1].running_mean)
print(seq_checkpointed[1].running_mean)
torch.testing.assert_close(seq[1].running_mean, seq_checkpointed[1].running_mean)
# passes: the running stats match once backward has run

As you can see, you need to run the backward pass to make running_mean match; the forward pass alone is not enough. checkpoint_wrapper is meant for training, so only doing the forward pass does not make much sense IMHO. With the backward pass, the stats match correctly.

@vovaf709
Author

There is one Kaggle trick: you can run multiple forward passes on the test set to adapt the running stats to it. It's a weird case, but still :)

@min-xu-ai
Contributor

Oh I see. That’s interesting! Do you have example code or pseudo code for it?

min-xu-ai reopened this Jul 21, 2022
@vovaf709
Author

Code for this trick? If so, it is as simple as

# keep BatchNorm in train mode so it keeps updating its running stats,
# then run forward-only passes over (part of) the test set
model.train()
for (X, y), _ in zip(test_loader, range(n_iter)):
    model(X)

@min-xu-ai
Contributor

I see. Then after this loop, do you proceed with normal training for 1 epoch, or with the whole training for N epochs?

@vovaf709
Author

No, I run this loop on the test set (on which I want to get the highest target metric in the competition) after the whole training. The idea is to adapt the BN statistics to the test set, which can have a slightly different distribution.
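
If it helps, here is a slightly fuller sketch of the trick. The function name and the optional reset of the running stats are my own additions for illustration, not from any library:

from torch import nn

def adapt_bn_to_test_set(model, test_loader, n_iter):
    # optional: forget the training-set statistics before re-estimating them
    for m in model.modules():
        if isinstance(m, nn.modules.batchnorm._BatchNorm):
            m.reset_running_stats()

    model.train()  # BatchNorm only updates its running stats in train mode
    for (X, y), _ in zip(test_loader, range(n_iter)):
        model(X)   # forward only, no backward
    model.eval()

Note that with the current checkpoint_wrapper behaviour described above, this forward-only loop would not update the running stats at all, which is exactly the inconsistency I reported.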

@vovaf709
Author

I think I can come up with a solution next week, ok?

@nyngwang

@vovaf709 ok
