-
I am curious about how to do this in JAX efficiently. Assume that I have an A100 80GB GPU.
I ran into an OOM error. Is there a way to avoid it? For more context, I tried to update the parameters and optimizer state of LLaMA3 8B. Assuming bf16, the parameters take 16 GB, the Adam optimizer state takes 32 GB, and the gradients take 16 GB, which is 64 GB in total. Updating the optimizer state then allocates an additional 32 GB for the new state, which brings the total to 96 GB and causes the OOM error. One might expect jax.jit to turn the update into an in-place operation after transformation, but that does not happen here, because the assignment to the updated state is outside the jitted function.
Replies: 2 comments
-
I think donate_argnums would be a solution. I will try it.
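For reference, here is a minimal sketch of what buffer donation could look like, assuming a toy momentum update rather than the real Adam step for LLaMA3 8B. The function name `momentum_step`, the tiny parameter dict, and the hyperparameters are all hypothetical; only `jax.jit`'s `donate_argnums` argument is taken from the suggestion above. Donating `params` and `opt_state` tells XLA it may reuse their buffers for the outputs, so the old and new state need not both be resident at once.

```python
from functools import partial

import jax
import jax.numpy as jnp


# Donate arguments 0 and 1 (params and opt_state) so XLA may reuse their
# buffers for the returned values. grads is not donated.
@partial(jax.jit, donate_argnums=(0, 1))
def momentum_step(params, opt_state, grads):
    # Toy momentum update (hypothetical stand-in for Adam).
    new_opt_state = jax.tree_util.tree_map(
        lambda m, g: 0.9 * m + g, opt_state, grads
    )
    new_params = jax.tree_util.tree_map(
        lambda p, m: p - 1e-3 * m, params, new_opt_state
    )
    return new_params, new_opt_state


# Tiny stand-in pytrees; a real model would hold far larger arrays.
params = {"w": jnp.ones((4, 4)), "b": jnp.zeros((4,))}
opt_state = jax.tree_util.tree_map(jnp.zeros_like, params)
grads = jax.tree_util.tree_map(jnp.ones_like, params)

# Rebind the names to the outputs: the donated inputs must not be used
# again after the call, since their buffers may have been overwritten.
params, opt_state = momentum_step(params, opt_state, grads)
```

Whether donation is honored depends on the backend; where it is not supported, JAX may emit a warning that the donated buffers were not usable and fall back to a normal copy.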
-
JAX arrays are immutable by design, so there is no way to modify an array in-place outside a JIT context. That said, if you perform your operation entirely within a JIT-compiled function (i.e., you never materialize the updated array), the compiler will perform updates in-place when the semantics of the original program allow it.
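A small sketch of the distinction, using a hypothetical function `scale_and_clear`: outside of `jit`, an indexed update always returns a new array and leaves the original untouched, while inside `jit` the update is applied to an intermediate value that never reaches Python, which the compiler is free to (but not guaranteed to) update in-place.

```python
import jax
import jax.numpy as jnp


@jax.jit
def scale_and_clear(x):
    # y is an intermediate that is never returned to Python before the
    # update, so the compiler may reuse its buffer for the result.
    y = x * 2.0
    return y.at[0].set(0.0)


x = jnp.ones(3)

# Outside jit: a functional update produces a new array; x is unchanged.
x_updated = x.at[0].set(5.0)

# Inside jit: same functional semantics, but no extra copy of y is required.
result = scale_and_clear(x)
```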
#19125