-
Thanks for the question! While #17576 added the JAX-side plumbing, I believe the XLA compiler doesn't actually support this feature yet. @yashk2810 is that right?
-
Yes, that's correct. We are a couple of PRs away from enabling it for TPUs, so it should land soon.
-
Awesome!! This is my first time trying offloading; what are some ways to tell it's working? I've implemented it and then looked at HBM usage, but I didn't see a difference (I don't think I'm using it correctly just yet). Are there other ways to check?
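For context, here's roughly how I've been checking beyond just watching HBM. This is only a minimal sketch under my own assumptions: `f` and `x` are toy placeholders, `memory_analysis()` can return `None` or expose different fields depending on the backend, and the HLO dump is only meaningful once the compiler actually supports offloading.

```python
import jax
import jax.numpy as jnp

# Toy placeholder for the real rematted computation.
def f(x):
    return jnp.sin(jnp.sin(x)).sum()

x = jnp.ones((1024, 1024))

# 1) Ask the compiler for its own memory estimate. If offloading kicks in,
#    the temp (activation) footprint should drop versus the
#    non-offloading policy.
compiled = jax.jit(jax.grad(f)).lower(x).compile()
stats = compiled.memory_analysis()  # may be None on some backends
if stats is not None:
    print("temp bytes:", stats.temp_size_in_bytes)

# 2) Dump a device-memory profile to inspect with pprof.
jax.profiler.save_device_memory_profile("memory.prof")

# 3) Look at the compiled HLO text: offloaded values should show up as
#    copies to/from host memory space.
print(compiled.as_text()[:2000])
```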
-
Hello! I am trying to optimize the memory usage of a neural network. I read #17576, which describes the activation offloading feature in remat. I tried changing the `dot_with_no_batch_dims_saveable` policy to the corresponding offloading policy when doing

`remat_call = jax.checkpoint(partial(model.__call__, train=True), policy=policy)`

I got a very similar OOM error with identical peak TPU HBM usage reported by the compiler. This seems counter-intuitive to me; offloading activations should reduce peak HBM usage. Am I missing something here? Does the compiler account for the memory saved by offloading, or is there something else going on in the compiler? Thanks!
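For reference, this is roughly the shape of what I'm trying, written as a minimal self-contained sketch rather than my actual model code: `model_call`, the parameter shapes, and the `"hidden"` name are placeholders, and I'm using `save_and_offload_only_these_names` from `jax.checkpoint_policies` as the offloading policy. The exact policy name and arguments may differ across JAX versions, and offloading to `pinned_host` presumably only helps once the compiler support discussed above has landed.

```python
import functools
import jax
import jax.numpy as jnp
from jax.ad_checkpoint import checkpoint_name

# Toy stand-in for model.__call__ (the real model is a larger network).
def model_call(params, x, train=True):
    h = jnp.dot(x, params["w1"])
    # Tag an intermediate so a name-based offloading policy can target it.
    h = checkpoint_name(jnp.tanh(h), name="hidden")
    return jnp.dot(h, params["w2"]).sum()

# Offloading policy: keep nothing in HBM, but move values tagged "hidden"
# to pinned host memory instead of recomputing them in the backward pass.
policy = jax.checkpoint_policies.save_and_offload_only_these_names(
    names_which_can_be_saved=[],
    names_which_can_be_offloaded=["hidden"],
    offload_src="device",
    offload_dst="pinned_host",
)

remat_call = jax.checkpoint(functools.partial(model_call, train=True), policy=policy)

params = {"w1": jnp.ones((512, 512)), "w2": jnp.ones((512, 512))}
x = jnp.ones((8, 512))
loss, grads = jax.jit(jax.value_and_grad(remat_call))(params, x)
```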