
Extremely large gradient and vanishing images when using jax.precision=float32 #91

Open
ManfredStoiber opened this issue Sep 21, 2023 · 3 comments


@ManfredStoiber

First of all, thank you very much for this impressive work!

When using the configuration jax.precision=float32 for training on images, I always get an extremely large gradient (model_opt_grad_norm at ~6e+8). I assume that because of this, the openl image predictions become completely white.
When training the dmc_walker_walk task with the dmc_vision configuration via the train.py script, the image_loss_mean is at about 7e+7. With other environments, the image_loss_mean starts at about 2000-5000, but the model_opt_grad_norm still stays at ~6e+8.
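
For reference, the run is launched with the train.py script roughly along these lines (a sketch only: the logdir is a placeholder, and the exact script path and dotted-override syntax may differ slightly between repo versions):

```sh
# Rough sketch of the invocation (logdir is a placeholder):
python dreamerv3/train.py \
  --logdir ~/logdir/dmc_walker_walk_fp32 \
  --configs dmc_vision \
  --task dmc_walker_walk \
  --jax.precision float32
```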

I'm using float32 because I sporadically get NaNs during training on images when using float16.

I have already tried changing the learning rate and gradient clipping values, as well as the image loss scale, but without success.
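
For illustration, overrides of that kind would look roughly like the following (the key names are inferred from the metric names above and the defaults in configs.yaml, so they may differ between versions; the values are just examples, not the defaults):

```sh
# Same run as above with a lower world-model learning rate, tighter gradient
# clipping, and a reduced image loss scale (illustrative values only):
python dreamerv3/train.py \
  --logdir ~/logdir/dmc_walker_walk_fp32_tuned \
  --configs dmc_vision \
  --task dmc_walker_walk \
  --jax.precision float32 \
  --model_opt.lr 3e-5 \
  --model_opt.clip 100 \
  --loss_scales.image 0.1
```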

Am I maybe missing any other configurations I have to change when using float32?

Thank you for your help!
Best regards

[attached screenshot of training metrics]

@return-sleep

May I ask whether you have successfully trained the DreamerV3 agent? I'm curious what the final loss of each component looks like, such as image, reward, or cont. When the reconstructed images are already very similar, I find that the reward prediction is not as good as it should be. I'm not sure whether this also affects the subsequent policy training. Thanks for sharing your thoughts.

@ManfredStoiber
Author

Unfortunately not, at least not when training on images in the walker environment.

@danijar
Owner

danijar commented Apr 19, 2024

Walker has always worked for me from images, regardless of precision. I've just updated the paper and code, which now includes a better optimizer. Curious if this is still an issue on your end.
