Currently, we have a custom implementation of eigh for Nx.BinaryBackend and, in EXLA, we call the XLA implementation.
However, the XLA implementation seems to suffer from issues similar to the SVD ones: it ends up slower and with different accuracy than the implementation Jax uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).
Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional + custom_grad (like Nx.LinAlg.svd); a rough sketch follows below.
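For reference, a minimal sketch of the defn shape this could take. The forward pass here just delegates to the existing Nx.LinAlg.eigh (a real port would swap in a QDWH-based spectral divide-and-conquer, as Jax does); the point is only to illustrate attaching the standard symmetric-eigendecomposition pullback via custom_grad. The custom_grad/3 call shape, the tuple-cotangent handling, and the module/function names are assumptions, not the final API, and this handles the unbatched case only:

```elixir
defmodule EighSketch do
  import Nx.Defn

  # Sketch only: forward pass delegates to the existing eigh; a real port
  # would replace it with a QDWH-based spectral divide-and-conquer.
  defn eigh(a) do
    {w, v} = Nx.LinAlg.eigh(a)

    # Assumption: custom_grad/3 accepts a tuple expression and passes the
    # matching tuple of cotangents {dw, dv} to the callback.
    custom_grad({w, v}, [a], fn {dw, dv} ->
      # Standard pullback for A = V diag(w) Vt with A symmetric:
      #   dA = V (diag(dw) + F * (Vt dV)) Vt,
      # where F[i][j] = 1 / (w[j] - w[i]) off-diagonal and 0 on the diagonal.
      diff = Nx.new_axis(w, 0) - Nx.new_axis(w, 1)
      eye = Nx.eye(Nx.shape(diff))
      # (1 - eye) zeroes the diagonal; adding eye to diff avoids division by zero there
      f = (1 - eye) / (diff + eye)

      inner = Nx.make_diagonal(dw) + f * Nx.dot(Nx.transpose(v), dv)
      da = v |> Nx.dot(inner) |> Nx.dot(Nx.transpose(v))

      # Symmetrize, since the input is constrained to be symmetric
      [0.5 * (da + Nx.transpose(da))]
    end)
  end
end
```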
@christianjgreen in short, you can just compare the execution time of jax.linalg.eigh vs. Nx.LinAlg.eigh on a 200x200 f32 tensor, with EXLA as the default compiler and backend. You'll notice that Nx barely handles it -- it takes 42s on my machine -- while jax handles it just fine, at around 43ms on my machine.
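Something along these lines on the Nx side (a repro sketch, assuming EXLA is installed; the seed and the symmetrization step are illustrative):

```elixir
# Make EXLA the default backend and defn compiler for this process
Nx.default_backend(EXLA.Backend)
Nx.Defn.default_options(compiler: EXLA)

# Build a symmetric 200x200 f32 input, since eigh expects a Hermitian matrix
key = Nx.Random.key(42)
{t, _key} = Nx.Random.uniform(key, shape: {200, 200}, type: :f32)
a = Nx.add(t, Nx.transpose(t))

{micros, _result} = :timer.tc(fn -> Nx.LinAlg.eigh(a) end)
IO.puts("Nx.LinAlg.eigh: #{micros / 1000} ms")
```

Timing the equivalent jax call on the same shape and dtype gives the comparison point.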
SVD will consequently suffer from the same performance drop, because SVD uses eigh under the hood.