
How to exclude some operations in checkpoint wrapper? #1014

Open
kenmbkr opened this issue Jun 22, 2022 · 1 comment

kenmbkr commented Jun 22, 2022

I would like to checkpoint my module that takes the result of a checkpointed module (cond in the example below) as input.

import torch.nn as nn

class Test(nn.Module):
  def __init__(self):
    super(Test, self).__init__()

  def forward(self, x, cond=None):
    result = x
    if cond is not None:
      # do something with the conditioning tensor (placeholder)
      result = result + cond

    return result

The above module works fine when checkpointed as long as cond is None. However, when cond is not None, I get the following error:

RuntimeError: Trying to backward through the graph a second time (or directly access saved variables after they have already been freed).
Saved intermediate values of the graph are freed when you call .backward() or autograd.grad().
Specify retain_graph=True if you need to backward through the graph a second time or if you need to access saved variables after calling backward.
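
For context, the setup looks roughly like this (a simplified, hypothetical sketch; my actual model is much larger, and cond_net / test are stand-in names):

import torch
import torch.nn as nn
from fairscale.nn import checkpoint_wrapper

# Both modules are wrapped for activation checkpointing; Test receives the
# output of the first checkpointed module as its cond argument.
cond_net = checkpoint_wrapper(nn.Linear(16, 16))
test = checkpoint_wrapper(Test())

x = torch.randn(4, 16, requires_grad=True)
cond = cond_net(x)
out = test(x, cond=cond)
out.sum().backward()  # in my real model, backward() is where the error above is raised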

I have tried the following context managers, but neither of them works. This is a large module, and I would prefer to wrap the whole thing while excluding the statements that cannot be checkpointed. Please kindly advise on any workarounds. Thank you.

from fairscale.nn.checkpoint.checkpoint_activations import disable_checkpointing

with disable_checkpointing():
  if cond is not None:
    # do something
    ...

from fairscale.nn.checkpoint.checkpoint_activations import enable_recomputing

with enable_recomputing():
  if cond is not None:
    # do something
    ...
min-xu-ai (Contributor) commented

Thanks for reporting this.

Do you have a complete test script that demonstrates this issue?
It is possible that some tensor needs to be detach()'ed in this case. It is not something we have explicitly tested.
Have you also tried PyTorch's native checkpointing module, or raised an issue with the PyTorch folks?
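
For example (an untested sketch, reusing the hypothetical cond_net / test names from the sketch above), detaching cond before it enters the second checkpointed module may avoid backwarding through the first checkpointed graph a second time, at the cost of not propagating gradients into the module that produced cond:

# Untested sketch: detach the conditioning tensor so that recomputation inside
# the second checkpointed module does not backward through cond_net's graph again.
# Note that gradients will no longer flow into cond_net.
cond = cond_net(x)
out = test(x, cond=cond.detach())
out.sum().backward()

# For comparison, the same flow with PyTorch's native activation checkpointing
# (arguments are passed positionally to the wrapped callable):
from torch.utils.checkpoint import checkpoint
cond = checkpoint(nn.Linear(16, 16), x)
out = checkpoint(Test(), x, cond)
out.sum().backward()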
