Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Probability density/mass functions from jax.scipy.stats not supported on Metal #20835

Open
logan-blaine opened this issue Apr 19, 2024 · 1 comment
Assignees
Labels
Apple GPU (Metal) plugin bug Something isn't working

Comments

@logan-blaine
Copy link

Description

I am attempting to use JAX on Metal (on a M1 Pro chip) to model discrete (count) data. I've installed the latest version jax-metal 0.0.6 using pip as detailed here.

The installation seems to have worked overall as I can perform basic Jax array operations on GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables which are defined in terms of the (log-)Gamma function I get errors like the one below which seems to indicate that the lax.lgamma function is not supported under the hood on M1 metal.

I would categorize this as essential functionality for a wide class of probabilistic machine learning models so I am filing this as a bug. Note that following functions (among others) are broken as a result:

  • jax.scipy.stats.binom.logpmf
  • jax.scipy.stats.nbinom.logpmf
  • jax.scipy.stats.poisson.logpmf
  • jax.scipy.stats.dirichlet.logpdf
  • jax.scipy.stats.beta.logpdf
  • jax.scipy.stats.gamma.logpdf
    ...
>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)

jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.

The above exception was the direct cause of the following exception:

Traceback (most recent call last):
  File "<stdin>", line 1, in <module>
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
    gammaln(n + 1),
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
    return lax.lgamma(x)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
    return lgamma_p.bind(x)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
    return self.bind_with_trace(find_top_trace(args), args, params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
    out = trace.process_primitive(self, map(trace.full_raise, args), params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
    return primitive.impl(*tracers, **params)
  File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
    outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
<stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>

System info (python version, jaxlib version, accelerator, etc.)

jax:    0.4.26
jaxlib: 0.4.23
numpy:  1.26.4
python: 3.10.6 | packaged by conda-forge | (main, Aug 22 2022, 20:38:29) [Clang 13.0.1 ]
jax.devices (1 total, 1 local): [METAL(id=0)]
process_count: 1
platform: uname_result(system='Darwin', node='PHS027794', release='23.4.0', version='Darwin Kernel Version 23.4.0: Fri Mar 15 00:10:42 PDT 2024; root:xnu-10063.101.17~1/RELEASE_ARM64_T6000', machine='arm64')
@logan-blaine logan-blaine added the bug Something isn't working label Apr 19, 2024
@shuhand0
Copy link
Collaborator

chlo.lgmma is not supported. We will look into add the support.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Apple GPU (Metal) plugin bug Something isn't working
Projects
None yet
Development

No branches or pull requests

4 participants