Custom behaviors plus jax leading to lookup in wrong spot #2603
Comments
I made some progress understanding what causes this to happen. Here is a significantly simplified reproducer:

```python
import awkward as ak
import numba
import numpy as np

behavior = {}

ak.jax.register_and_check()

USE_JAX = True  # set to False to run this successfully
input_arr = ak.Array([1.0], backend=("jax" if USE_JAX else "cpu"))

@numba.vectorize(
    [
        numba.float32(numba.float32, numba.float32),
        numba.float64(numba.float64, numba.float64),
    ]
)
def _some_kernel(x, y):
    return x * x + y * y

@ak.mixin_class(behavior)
class SomeClass:
    @property
    def some_kernel(self):
        return _some_kernel(self.x, self.y)

ak.behavior.update(behavior)

arr = ak.zip({"x": input_arr, "y": input_arr}, with_name="SomeClass")
arr.some_kernel  # crashes with Jax
```

This results in:

```
AttributeError: module 'jax.numpy' has no attribute '_some_kernel'

This error occurred while calling

    numpy._some_kernel.__call__(
        <Array [1.0] type='1 * float32'>
        <Array [1.0] type='1 * float32'>
    )
```

The code runs successfully with the CPU backend (`USE_JAX = False`).
Right - at the moment, users can't override ufuncs for JAX, so Numba ufuncs throw exceptions. Numba functions wouldn't be differentiable via JAX anyway; we'd need to substitute a JAX implementation.
@Saransh-cpp, this is another one that you should self-assign (anything with label …).
The coffea issue will be solved once their vector module is removed and scikit-hep/vector is recommended to users - see CoffeaTeam/coffea#874 (comment). For the issue on the awkward end, I am a bit confused about what we want the ideal behavior to look like -
Thanks!
So with JAX's JIT-compilation off the table, the alternative of compiling with Numba remains, but Numba does not propagate derivatives through its compiled code. From January 2022 until at least January 2023 (as far as I followed it), @ludgerpaehler was trying to compile through Numba using Enzyme, an autograd tool for LLVM code. I don't know the current state of that project, but it would allow us to connect JAX's non-JITted autograd with Numba's JITted code. Users already have to switch programming models between non-JIT and JIT, but in principle, it's possible to preserve derivatives across that boundary.
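To illustrate the boundary described above, here is a minimal sketch (plain JAX, no Awkward or Numba involved): the same kernel written with ordinary arithmetic is traceable and differentiable by `jax.grad`, whereas a Numba-compiled ufunc is opaque to JAX's tracer:

```python
import jax

def some_kernel(x, y):
    # Written with plain arithmetic, so JAX can trace and differentiate it.
    # A @numba.vectorize version of this same function could not be traced.
    return x * x + y * y

# d/dx (x^2 + y^2) = 2x
grad_x = jax.grad(some_kernel, argnums=0)
print(grad_x(3.0, 4.0))  # 6.0
```

Substituting a JAX implementation like this for the Numba ufunc is what the comment above means by preserving derivatives across the non-JIT/JIT boundary.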
Version of Awkward Array
ce63bf2
Description and code to reproduce
This is a partner issue to CoffeaTeam/coffea#874, as perhaps this is more on the side of awkward than coffea. I am trying to combine custom behaviors (defined by coffea) with the jax backend of awkward. The reproducer below results in:
Reproducer:
Using the `"Momentum4D"` behavior from `vector` (after `vector.register_awkward()`) works. Skipping the backend conversion to jax also makes this work.

Full trace