
Avoid inlining large arrays in JaxInMemoryRandomSampleIterator #285

Open
ethanluoyc opened this issue Jan 25, 2023 · 0 comments

@ethanluoyc
Contributor

`JaxInMemoryRandomSampleIterator` currently inlines the in-memory dataset as a constant in its jitted sampling function. See:
https://github.com/deepmind/acme/blob/master/acme/datasets/tfds.py#L199-L200
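For readers unfamiliar with the problem, here is a minimal sketch of the pattern. It is illustrative only and is not the Acme code at the linked lines. The dataset array, `batch_size`, and `sample_inlined` are hypothetical names. Because `data` is captured by closure instead of passed as an argument, `jax.jit` treats it as a compile-time constant and embeds it in the XLA computation, so the whole dataset becomes part of the compiled program.

```python
import jax
import jax.numpy as jnp
import numpy as np

# Hypothetical in-memory dataset (1000 rows of 4 features).
data = np.arange(1000 * 4, dtype=np.float32).reshape(1000, 4)
batch_size = 8

@jax.jit
def sample_inlined(key):
    # `data` is closed over, not passed in: jit traces it as a constant
    # and inlines the full array into the compiled XLA program.
    idx = jax.random.randint(key, (batch_size,), 0, data.shape[0])
    return jnp.asarray(data)[idx]

batch = sample_inlined(jax.random.PRNGKey(0))
```

This is harmless for small arrays, but the cost grows with the dataset because the constant is part of what XLA has to compile.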

This causes out-of-memory (OOM) errors because of an underlying XLA issue, and when running on GPU the process may also hang. I have filed a more detailed issue in the JAX repository (google/jax#14080), and the JAX developers recommend not inlining the array. I can open a PR if the maintainers would like this fixed.
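One way to avoid inlining is to pass the dataset into the jitted function as an argument on each call, so it stays a device buffer fed to the compiled program instead of a baked-in constant. The sketch below assumes the dataset is a pytree of arrays sharing a leading dimension. `InMemoryRandomSampleIteratorSketch` is a hypothetical class for illustration, not Acme's API or the eventual fix.

```python
import jax
import jax.numpy as jnp
import numpy as np

class InMemoryRandomSampleIteratorSketch:
    """Hypothetical iterator that samples random batches from an in-memory
    pytree without closing over the data inside jit."""

    def __init__(self, dataset, key, batch_size):
        # Move each leaf to the device once; the buffers are passed to the
        # compiled function on every call rather than embedded in it.
        self._dataset = jax.tree_util.tree_map(jnp.asarray, dataset)
        size = jax.tree_util.tree_leaves(self._dataset)[0].shape[0]
        self._key = key

        def sample(data, key):
            # Only Python ints (size, batch_size) are closed over here.
            key, subkey = jax.random.split(key)
            indices = jax.random.randint(subkey, (batch_size,), 0, size)
            batch = jax.tree_util.tree_map(
                lambda x: jnp.take(x, indices, axis=0), data)
            return batch, key

        self._sample = jax.jit(sample)

    def __iter__(self):
        return self

    def __next__(self):
        batch, self._key = self._sample(self._dataset, self._key)
        return batch
```

Usage: `it = InMemoryRandomSampleIteratorSketch({"obs": obs, "reward": reward}, jax.random.PRNGKey(0), batch_size=32)`, then `next(it)` returns a dict of batched arrays. The same indices are applied to every leaf, so rows stay aligned across fields.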
