nt.linearize induces a CUDA_ERROR_OUT_OF_MEMORY error #144
Comments
@romanngg, you were really helpful previously - any thoughts here? Thanks in advance :)
@romanngg thanks for getting back to me so soon! I'll check the max memory consumption, but I don't think that's the reason, because we could successfully "manually" perform a forward and backward pass of
I suspect that there might be an odd interaction between
Edit: Ignore that last observation. That problem vanished when we reduced the per-GPU batch size.
Here's a self-contained Colab that reproduces the issue: https://colab.research.google.com/drive/184moQLq3tjo-wEpc8gD7fXCFguAVDBOm#scrollTo=k4CjYqp5qLvj We suspect that
Another insight: GPU on Colab breaks, but TPU on Colab is fine.
Looks like someone else has a similar problem while using neural tangents, also potentially arising from pmap.
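For reference on the per-GPU batch size point above: pmap maps the leading axis of its inputs onto the local devices, so what has to fit in each GPU's memory is the per-device batch (the global batch divided by the device count), not the global batch. A toy sketch - the function and shapes here are hypothetical, not the ViT code:

```python
import jax
import jax.numpy as jnp

n_dev = jax.local_device_count()

# pmap runs one copy of the function per device; the leading axis of
# every input must equal the number of local devices.
@jax.pmap
def forward(x):
    return jnp.tanh(x).sum(axis=-1)

global_batch = jnp.ones((8, 3))                       # total batch of 8
sharded = global_batch.reshape(n_dev, 8 // n_dev, 3)  # add the device axis
out = forward(sharded)                                # shape (n_dev, 8 // n_dev)
```

Halving the per-device batch (while keeping the device axis) is what "reducing the per-GPU batch size" amounts to here.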
We're trying to fine-tune a linearized Vision Transformer by adapting code from https://github.com/google-research/vision_transformer/blob/main/vit_jax.ipynb.
We're running into a really puzzling problem: we can load a model and train it, and after linearizing we can still train the original, pre-linearization model. But when we try to train the linearized model itself, we get a CUDA_ERROR_OUT_OF_MEMORY error:
This error emerges regardless of whether we use one GPU or several, and regardless of whether the batch size is large (512) or small (1).
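Since the OOM appears even at batch size 1, it may be worth ruling out JAX's default GPU memory preallocation, which reserves a large fraction of device memory up front and can make the reported OOM misleading. A minimal sketch using JAX's standard XLA environment flags - note they only take effect if set before JAX is first imported:

```python
import os

# Standard JAX/XLA GPU memory flags; they are ignored unless set
# before JAX is imported for the first time.
os.environ["XLA_PYTHON_CLIENT_PREALLOCATE"] = "false"  # allocate on demand
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"   # cap at 60% of GPU memory

# ... import jax / neural_tangents only after this point ...
```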
We manually verified that neither a forward pass nor a backward pass raises an error on its own. We suspect that the error might arise from the following code (although we could be wrong!):
Their code:
That function is then called via:
The training loop where the memory error arises:
The above code is all copied from the ViT repo. This is how we linearize the ViT model:
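For context, nt.linearize returns the first-order Taylor expansion of the network in its parameters. A self-contained sketch of the same idea written directly with jax.jvp - apply_fn and the shapes here are toy placeholders, not the actual ViT code:

```python
import jax
import jax.numpy as jnp

# Placeholder network; stands in for the ViT apply function f(params, x).
def apply_fn(params, x):
    w, b = params
    return jnp.tanh(x @ w) + b

def linearize(f, params_0):
    # First-order Taylor expansion of f in the parameters around params_0,
    # conceptually what neural_tangents.linearize computes.
    def f_lin(params, x):
        dparams = jax.tree_util.tree_map(lambda p, p0: p - p0, params, params_0)
        y, dy = jax.jvp(lambda p: f(p, x), (params_0,), (dparams,))
        return y + dy
    return f_lin

key = jax.random.PRNGKey(0)
params_0 = (jax.random.normal(key, (3, 2)), jnp.zeros((2,)))
f_lin = linearize(apply_fn, params_0)

x = jnp.ones((4, 3))
# At the linearization point both models agree exactly.
assert jnp.allclose(f_lin(params_0, x), apply_fn(params_0, x))
# Gradients of a scalar loss through the linearized model also work:
loss = lambda p: jnp.sum(f_lin(p, x) ** 2)
grads = jax.grad(loss)(params_0)
```

One thing this makes visible: training f_lin keeps both the primal and the JVP computation alive, which roughly doubles activation memory relative to the original model - that may be why a batch that fits before linearization no longer does.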