Re-implement Nx.LinAlg.eigh as defn #1027

Open
polvalente opened this issue Dec 16, 2022 · 3 comments
Labels
area:exla (Applies to EXLA), area:nx (Applies to nx)

Comments

@polvalente
Contributor

Currently, we have a custom eigh implementation for Nx.BinaryBackend, while EXLA calls the XLA implementation.
However, the XLA implementation seems to suffer from the same issues as the SVD one: it ends up slower than, and with different accuracy from, the implementation JAX uses (https://github.com/google/jax/blob/main/jax/_src/lax/eigh.py).

Especially since we already have QDWH implemented in Nx.LinAlg.SVD.qdwh, it seems like a good idea to also reimplement eigh as a defn with optional + custom_grad (like Nx.LinAlg.svd).
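For reference, the core idea behind the QDWH-based approach in the JAX file linked above is spectral divide and conquer: if U is the unitary polar factor of H - σI for a symmetric H, then (I + U)/2 projects onto the eigenvectors of H with eigenvalues above σ, so H splits into two smaller symmetric problems. The NumPy sketch below illustrates that idea under stated assumptions; it is not JAX's code and not the Nx implementation. The names qdwh_polar, split_spectrum and eigh_dc are hypothetical, and the median-of-diagonal split point, the randomized range finder, the fixed initial bound l = 1e-15 and the base-case size of 16 are illustrative choices. It also assumes σ is not (nearly) an eigenvalue of H.

```python
import numpy as np


def qdwh_polar(a, max_iters=20, tol=1e-12):
    """Unitary polar factor of a square matrix via QDWH (illustrative sketch)."""
    n = a.shape[0]
    eye = np.eye(n)
    x = a / np.linalg.norm(a, "fro")  # Frobenius norm bounds the 2-norm, so sigma_max(x) <= 1
    l = 1e-15  # assumed lower bound on sigma_min(x)
    for _ in range(max_iters):
        # Dynamically weighted Halley parameters a_k, b_k, c_k from the current bound l
        l2 = l * l
        dd = np.cbrt(4.0 * (1.0 / l2 - 1.0) / l2)
        sqd = np.sqrt(1.0 + dd)
        a_k = sqd + 0.5 * np.sqrt(8.0 - 4.0 * dd + 8.0 * (2.0 - l2) / (l2 * sqd))
        b_k = (a_k - 1.0) ** 2 / 4.0
        c_k = a_k + b_k - 1.0
        # QR-based form of X (a I + b X^T X)(I + c X^T X)^{-1}, avoiding an explicit inverse
        q, _ = np.linalg.qr(np.vstack([np.sqrt(c_k) * x, eye]))
        q1, q2 = q[:n], q[n:]
        x_new = (b_k / c_k) * x + (a_k - b_k / c_k) / np.sqrt(c_k) * (q1 @ q2.T)
        l = min(1.0, l * (a_k + b_k * l2) / (1.0 + c_k * l2))
        done = np.linalg.norm(x_new - x, "fro") < tol
        x = x_new
        if done:
            break
    return x


def split_spectrum(h, sigma, rng):
    """Orthonormal bases for the eigenspaces of h below and above sigma, or None."""
    n = h.shape[0]
    eye = np.eye(n)
    u = qdwh_polar(h - sigma * eye)  # for symmetric h this is V sign(Lambda - sigma) V^T
    u = 0.5 * (u + u.T)  # remove rounding asymmetry
    p = 0.5 * (u + eye)  # projector onto eigenvectors with eigenvalues > sigma
    k = int(round(np.trace(p)))  # rank of p = number of eigenvalues above sigma
    if k == 0 or k == n:
        return None
    # Randomized range finder plus one refinement step (subspace iteration)
    q_hi, _ = np.linalg.qr(p @ rng.standard_normal((n, k)))
    q_hi, _ = np.linalg.qr(p @ q_hi)
    q_lo, _ = np.linalg.qr((eye - p) @ rng.standard_normal((n, n - k)))
    q_lo, _ = np.linalg.qr((eye - p) @ q_lo)
    return q_lo, q_hi


def eigh_dc(h, base_size=16, rng=None):
    """Eigenvalues (ascending) and eigenvectors of symmetric h by spectral divide and conquer."""
    rng = np.random.default_rng(0) if rng is None else rng
    n = h.shape[0]
    if n <= base_size:
        return np.linalg.eigh(h)  # small blocks: fall back to a dense solver
    sigma = np.median(np.diag(h))  # diagonal entries lie in [lambda_min, lambda_max]
    split = split_spectrum(h, sigma, rng)
    if split is None:
        return np.linalg.eigh(h)
    w_parts, v_parts = [], []
    for q in split:  # lower half first, so eigenvalues stay in ascending order
        sub = q.T @ h @ q
        w, v = eigh_dc(0.5 * (sub + sub.T), base_size, rng)
        w_parts.append(w)
        v_parts.append(q @ v)
    return np.concatenate(w_parts), np.hstack(v_parts)
```

Apart from the small dense base case, the work here is QR factorizations and matrix products, which are operations accelerators execute efficiently.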

@polvalente
Contributor Author

Although #1424 did move eigh to defn, it's still worth looking into a new implementation for speed and accuracy.

@christianjgreen

> Although #1424 did move eigh to defn, it's still worth looking into using a new implementation for speed and accuracy

Do you have a basic test I could use to compare them?

@polvalente
Contributor Author

@christianjgreen in short, you can compare the execution time of jax.linalg.eigh against Nx.LinAlg.eigh on a 200x200 f32 tensor, with EXLA as the default compiler and backend. Nx barely handles it (42 s on my machine), while JAX handles it easily (around 43 ms on my machine).
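A minimal timing script for the JAX side of that comparison might look like the following. It assumes a recent JAX release and uses jax.numpy.linalg.eigh; the random symmetric input, the PRNG seed and the single timed run are illustrative choices. A warm-up call keeps compilation time out of the measurement, and block_until_ready() waits for JAX's asynchronous dispatch to finish before the clock stops.

```python
import time

import jax
import jax.numpy as jnp

# Random symmetric 200x200 float32 matrix (illustrative input)
key = jax.random.PRNGKey(0)
a = jax.random.normal(key, (200, 200), dtype=jnp.float32)
h = (a + a.T) / 2

eigh = jax.jit(jnp.linalg.eigh)

# Warm-up: the first call compiles; exclude it from the timing
w, v = eigh(h)
w.block_until_ready()

start = time.perf_counter()
w, v = eigh(h)
w.block_until_ready()
v.block_until_ready()
elapsed_ms = (time.perf_counter() - start) * 1000
print(f"jax eigh on {h.shape} {h.dtype}: {elapsed_ms:.1f} ms")
```

The Nx side would time Nx.LinAlg.eigh on an equivalent f32 tensor with EXLA configured as the default backend and compiler, again excluding the first (compiling) call.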

SVD consequently suffers the same performance drop, because it uses eigh under the hood.
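To illustrate why SVD inherits the slowdown, here is a NumPy sketch of the general structure of polar-decomposition-based SVD: with A = U_p H, diagonalizing the symmetric factor H = V Σ V^T gives A = (U_p V) Σ V^T, so every SVD performs one eigh. It assumes Nx.LinAlg.svd is organized along these lines; to keep it short, it swaps QDWH for the simpler Newton iteration X ← (X + X^{-T})/2, which requires a nonsingular square A. The function names polar_newton and svd_via_polar_eigh are hypothetical.

```python
import numpy as np


def polar_newton(a, iters=60, tol=1e-12):
    """Unitary polar factor of a nonsingular square matrix (Newton stand-in for QDWH)."""
    x = a.copy()
    for _ in range(iters):
        x_new = 0.5 * (x + np.linalg.inv(x).T)
        if np.linalg.norm(x_new - x, "fro") <= tol * np.linalg.norm(x_new, "fro"):
            return x_new
        x = x_new
    return x


def svd_via_polar_eigh(a):
    """SVD of a square nonsingular matrix as polar decomposition followed by eigh."""
    u_p = polar_newton(a)  # A = U_p H
    h = u_p.T @ a  # symmetric positive definite factor H
    h = 0.5 * (h + h.T)  # remove rounding asymmetry
    s, v = np.linalg.eigh(h)  # H = V diag(s) V^T: this is the eigh call SVD depends on
    s, v = s[::-1], v[:, ::-1]  # descending order, as SVD conventionally returns
    u = u_p @ v  # A = (U_p V) diag(s) V^T
    return u, s, v.T
```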


2 participants