You signed in with another tab or window. Reload to refresh your session.You signed out in another tab or window. Reload to refresh your session.You switched accounts on another tab or window. Reload to refresh your session.Dismiss alert
I am attempting to use JAX on Metal (on a M1 Pro chip) to model discrete (count) data. I've installed the latest version jax-metal 0.0.6 using pip as detailed here.
The installation seems to have worked overall as I can perform basic Jax array operations on GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables which are defined in terms of the (log-)Gamma function I get errors like the one below which seems to indicate that the lax.lgamma function is not supported under the hood on M1 metal.
I would categorize this as essential functionality for a wide class of probabilistic machine learning models so I am filing this as a bug. Note that following functions (among others) are broken as a result:
jax.scipy.stats.binom.logpmf
jax.scipy.stats.nbinom.logpmf
jax.scipy.stats.poisson.logpmf
jax.scipy.stats.dirichlet.logpdf
jax.scipy.stats.beta.logpdf
jax.scipy.stats.gamma.logpdf
...
>>> jax.scipy.stats.binom.logpmf(1, n=2, p=0.5)
jax.errors.SimplifiedTraceback: For simplicity, JAX has removed its internal frames from the traceback of the following exception. Set JAX_TRACEBACK_FILTERING=off to include these.
The above exception was the direct cause of the following exception:
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/stats/binom.py", line 31, in logpmf
gammaln(n + 1),
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/scipy/special.py", line 44, in gammaln
return lax.lgamma(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/lax/special.py", line 46, in lgamma
return lgamma_p.bind(x)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 422, in bind
return self.bind_with_trace(find_top_trace(args), args, params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 425, in bind_with_trace
out = trace.process_primitive(self, map(trace.full_raise, args), params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/core.py", line 913, in process_primitive
return primitive.impl(*tracers, **params)
File "/Users/ljb80/.virtualenvs/jax-metal/lib/python3.10/site-packages/jax/_src/dispatch.py", line 87, in apply_primitive
outs = fun(*args)
jaxlib.xla_extension.XlaRuntimeError: UNKNOWN: <stdin>:1:0: error: failed to legalize operation 'chlo.lgamma'
<stdin>:1:0: note: see current operation: %0 = "chlo.lgamma"(%arg0) : (tensor<f32>) -> tensor<f32>
System info (python version, jaxlib version, accelerator, etc.)
Description
I am attempting to use JAX on Metal (on a M1 Pro chip) to model discrete (count) data. I've installed the latest version
jax-metal 0.0.6
using pip as detailed here.The installation seems to have worked overall as I can perform basic Jax array operations on GPU. However, when I try to compute the (log-)PMFs/PDFs of random variables which are defined in terms of the (log-)Gamma function I get errors like the one below which seems to indicate that the
lax.lgamma
function is not supported under the hood on M1 metal.I would categorize this as essential functionality for a wide class of probabilistic machine learning models so I am filing this as a bug. Note that following functions (among others) are broken as a result:
...
System info (python version, jaxlib version, accelerator, etc.)
The text was updated successfully, but these errors were encountered: