Better Handling of Nested Loop with Remat #20877

Open
LeoXinhaoLee opened this issue Apr 23, 2024 · 0 comments
Labels
enhancement New feature or request

Comments

@LeoXinhaoLee

Hi, the forward pass of our training step runs through a two-level nested loop (an outer loop and an inner loop).

The input X, of shape [T, F], is first reshaped to [outer_loop_num, inner_loop_num, F] (so T = outer_loop_num * inner_loop_num). The outer loop then iterates over the first dimension and the inner loop over the second.
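For readers without access to the Colab, here is a minimal sketch of the loop structure described above. It assumes both loops are written with jax.lax.scan. The sizes T, F, outer_loop_num, inner_loop_num and the tanh recurrence in the hypothetical inner_step are illustrative assumptions, not the actual model from the Colab.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; the issue does not specify them.
T, F = 12, 4
outer_loop_num, inner_loop_num = 3, 4  # T must equal outer_loop_num * inner_loop_num

def inner_step(h, x_t, W):
    # Hypothetical inner-loop step: a toy tanh recurrence.
    h = jnp.tanh(h @ W + x_t)
    return h, h  # (new carry, per-step output)

def forward(W, X):
    # Reshape [T, F] -> [outer_loop_num, inner_loop_num, F].
    X = X.reshape(outer_loop_num, inner_loop_num, F)

    def outer_step(h, x_chunk):
        # Inner loop: scan over the second dimension (x_chunk is [inner_loop_num, F]).
        h, ys = jax.lax.scan(lambda c, x: inner_step(c, x, W), h, x_chunk)
        return h, ys

    h0 = jnp.zeros((F,))
    # Outer loop: scan over the first dimension of X.
    h_final, ys = jax.lax.scan(outer_step, h0, X)
    # ys has shape [outer_loop_num, inner_loop_num, F]; flatten back to [T, F].
    return h_final, ys.reshape(T, F)
```

The last per-step output equals the final carry, since the carry and the output are the same value at every step.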

Our goal is to apply jax.checkpoint to each inner loop, so that none of the intermediate values computed inside it are saved for the backward pass; they would be recomputed during backpropagation instead.
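A minimal sketch of checkpointing each inner loop, under the same assumptions as before: hypothetical sizes, a toy tanh recurrence, and both loops written with jax.lax.scan. The function that runs one entire inner scan is wrapped in jax.checkpoint. With the default policy, the backward pass then stores roughly only the inputs to each inner loop and recomputes everything inside it. This is one possible placement of the checkpoint. It is not necessarily either of the two approaches in the Colab.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; the issue does not specify them.
T, F = 12, 4
outer_loop_num, inner_loop_num = 3, 4

def inner_loop(h, x_chunk, W):
    # Runs the whole inner loop over one chunk of shape [inner_loop_num, F].
    def step(c, x):
        c = jnp.tanh(c @ W + x)  # hypothetical toy recurrence
        return c, None
    h, _ = jax.lax.scan(step, h, x_chunk)
    return h

# Remat the entire inner loop: its internal intermediates are not stored
# for the backward pass and are recomputed from (h, x_chunk, W) instead.
inner_loop_remat = jax.checkpoint(inner_loop)

def loss(W, X, use_remat=True):
    X = X.reshape(outer_loop_num, inner_loop_num, F)
    body = inner_loop_remat if use_remat else inner_loop

    def outer_step(h, x_chunk):
        return body(h, x_chunk, W), None

    # Outer loop over the first dimension; each iteration runs one inner loop.
    h, _ = jax.lax.scan(outer_step, jnp.zeros((F,)), X)
    return jnp.sum(h ** 2)
```

Gradients with and without remat should agree. Only peak memory and the amount of recomputation change.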

So far we have found two ways to do this, but each has potential problems. Please take a look at this example Colab: https://colab.research.google.com/drive/1TLLQbzIdSX1SYSujmGmdrPO28aa3ZyFb#scrollTo=m3aI--XBMINi

Thank you very much for your time and help!

LeoXinhaoLee added the enhancement (New feature or request) label on Apr 23, 2024.