Hi, the forward pass of our training step runs through a two-level nested loop (an outer loop and an inner loop).
The input data X of shape [T, F] is first reshaped into [outer_loop_num, inner_loop_num, F], and we then loop over its first (outer) and second (inner) dimensions.
Our goal is to apply jax.checkpoint to each inner loop, so that none of the intermediate values computed inside it are saved for the backward pass (they are recomputed instead).
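A minimal sketch of this setup, under stated assumptions: the per-step function `inner_step`, the [F]-shaped carry, and the sizes `T`, `F`, `OUTER`, `INNER` are all hypothetical placeholders, and this may differ from the approaches in the linked Colab. The outer loop is a `jax.lax.scan` over the first dimension, and each inner loop is its own `jax.lax.scan` wrapped in `jax.checkpoint`, so that by default only the inputs of each inner loop are kept as residuals and its intermediates are recomputed during backprop. A non-checkpointed version is included for comparison, since rematerialization should not change the gradients.

```python
import jax
import jax.numpy as jnp

# Hypothetical sizes; OUTER * INNER must equal T.
T, F = 12, 4
OUTER, INNER = 3, 4

def inner_step(carry, x_t):
    # Hypothetical per-timestep update: carry is [F], x_t is [F].
    new_carry = jnp.tanh(carry * x_t + x_t)
    return new_carry, new_carry

@jax.checkpoint
def inner_loop(carry, x_chunk):
    # x_chunk is [INNER, F]. Wrapping the whole inner scan in jax.checkpoint
    # means its per-step intermediates are not stored for the backward pass;
    # the scan is re-run from its inputs when gradients are computed.
    return jax.lax.scan(inner_step, carry, x_chunk)

def forward(params, X):
    X = X.reshape(OUTER, INNER, F)                   # [T, F] -> [outer, inner, F]
    _, ys = jax.lax.scan(inner_loop, params, X)      # outer loop over dim 0
    return jnp.sum(ys ** 2)                          # ys: [OUTER, INNER, F]

def forward_no_remat(params, X):
    # Same computation without jax.checkpoint, for comparison only.
    X = X.reshape(OUTER, INNER, F)
    _, ys = jax.lax.scan(
        lambda c, xc: jax.lax.scan(inner_step, c, xc), params, X)
    return jnp.sum(ys ** 2)

params = jnp.full((F,), 0.1, dtype=jnp.float32)
X = jnp.arange(T * F, dtype=jnp.float32).reshape(T, F) / (T * F)

loss, grads = jax.value_and_grad(forward)(params, X)
ref_grads = jax.grad(forward_no_remat)(params, X)
```

Checkpointing trades memory for compute: the gradients match the non-checkpointed version, but each inner loop is effectively executed twice during `jax.grad`. If your JAX version provides `jax.ad_checkpoint.print_saved_residuals`, it can be used to inspect which values are actually kept as residuals.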
We have identified two ways to do this so far, each with some potential problems. Please see this example Colab: https://colab.research.google.com/drive/1TLLQbzIdSX1SYSujmGmdrPO28aa3ZyFb#scrollTo=m3aI--XBMINi
Thank you very much for your time and help!