limit the # of threads for jax #127

Open
h3jia opened this issue Jun 1, 2022 · 2 comments

Comments

@h3jia

h3jia commented Jun 1, 2022

Hi, I know this is probably more of a JAX issue and has been discussed there (e.g. google/jax#743, google/jax#1539 and google/jax#6790), but I'm still wondering if you know how to limit the number of threads JAX uses. The simple snippet below shows that JAX currently does not respect the limit set by threadpool_limits.

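```python
import jax.numpy as jnp
from threadpoolctl import threadpool_limits

ja = jnp.ones((1000, 1000))
with threadpool_limits(5):
    for _ in range(100):
        # JAX dispatches work asynchronously; block_until_ready() makes each
        # matmul finish inside the `with` block. JAX still ignores the limit.
        foo = (ja @ ja).block_until_ready()
```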
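Since threadpoolctl cannot reach JAX's threadpool, a workaround has to go through XLA itself. The sketch below assumes that the XLA flag `--xla_cpu_multi_thread_eigen=false` (a flag that comes up in the linked JAX issues) makes the CPU backend run its Eigen work single-threaded; whether your jaxlib version still honors it is an assumption to verify. `disable_xla_eigen_threads` is a hypothetical helper, not part of JAX or threadpoolctl, and it can at best approximate `threadpool_limits(1)`, not an arbitrary limit like 5.

```python
import os
from collections.abc import MutableMapping

# Assumption: XLA's CPU backend reads this flag from XLA_FLAGS when it is
# initialized and then stops multi-threading its Eigen-based kernels.
SINGLE_THREAD_EIGEN_FLAG = "--xla_cpu_multi_thread_eigen=false"


def disable_xla_eigen_threads(env: MutableMapping = os.environ) -> str:
    """Append the single-threaded Eigen flag to XLA_FLAGS unless already present.

    Hypothetical helper. It only edits the environment, so it must run before
    jax is imported; flags set afterwards are not seen by the backend.
    """
    flags = env.get("XLA_FLAGS", "").split()
    if SINGLE_THREAD_EIGEN_FLAG not in flags:
        flags.append(SINGLE_THREAD_EIGEN_FLAG)
    env["XLA_FLAGS"] = " ".join(flags)
    return env["XLA_FLAGS"]
```

Call it at the very top of the script, before `import jax`; any flags already present in `XLA_FLAGS` are preserved.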
@jeremiedbb
Collaborator

Hi @HerculesJack, according to this comment google/jax#743 (comment), the threading mechanism jax relies on (an Eigen threadpool) is not one of those that threadpoolctl supports. It could be interesting to check whether Eigen threadpools expose some symbols that allow controlling the number of threads.
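Eigen is a header-only C++ library, so there is no standalone Eigen shared library to inspect; any thread-control entry points would have to be exported by whatever binary instantiates it (for JAX, presumably a jaxlib extension module), and C++ functions would be visible only under their mangled names. Checking whether a shared object exports given symbols can be done with `ctypes`, which threadpoolctl also uses to call into the libraries it controls. A minimal sketch, where `exported_symbols` is a hypothetical helper and `hypothetical_eigen_set_num_threads` is a made-up name, not a real Eigen symbol:

```python
import ctypes
import ctypes.util
from typing import Dict, Iterable, Optional


def exported_symbols(lib_path: Optional[str], names: Iterable[str]) -> Dict[str, bool]:
    """Report which of `names` are exported by the shared library at `lib_path`.

    Passing None loads the symbols already linked into the running process
    (POSIX dlopen(NULL) semantics), which is a fallback when find_library fails.
    """
    lib = ctypes.CDLL(lib_path)
    # ctypes raises AttributeError for a missing symbol; hasattr maps that to False.
    return {name: hasattr(lib, name) for name in names}


if __name__ == "__main__":
    libc_path = ctypes.util.find_library("c")  # e.g. "libc.so.6" on Linux
    print(exported_symbols(libc_path, ["malloc", "hypothetical_eigen_set_num_threads"]))
```

Pointing `lib_path` at the relevant jaxlib shared object and passing candidate (mangled) names would show whether any usable getter/setter pair is exported.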

@ogrisel
Contributor

ogrisel commented Jul 11, 2023

Note: if Eigen exposes some well-defined symbols to inspect and control the number of threads in its threadpool, then the mechanism implemented in #137 should make it possible to add support for tensorflow and jax to threadpoolctl.

I tried to check the Eigen documentation to see if that's the case, but it seems to be down at the moment: https://www.tuxfamily.org/en/news/2023070900.
