-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
use BCSR in sparse format of jaxoperators instead of BCOO #1804
Conversation
any speedup? |
Yes I agree. Although in By the way, in |
or isinstance(other, jnp.ndarray) | ||
or isinstance(other, JAXSparse) | ||
): | ||
if isinstance(other, JAXSparse): |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@wdphy16 the only questionable thing is that jax sparse matrices do not support multiplication by another jax CSR sparse matrix, so I had to ensure that when one tries to do jaxoperator @ sparsevector
we convert it to dense first.
This breaks current behaviour where, as we use BCOO, we get a sparse vector out.
But I think that's not important? Hopefully one never does this kind of multiplication...
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Ah it's a bit unexpected for me, but anyway JAX sparse matrices are experimental so we can make breaking changes
In real use cases, I guess we usually first do NetKet operator multiplications then convert it to a JAX sparse matrix
@wdphy16 how bad is compile time in that case? |
I don't remember exactly, but for large experiments it's much less than the actual computation time, and for small experiments I did not feel much more lag |
@wdphy16 as the improvement that you propose makes sense, would you care making a PR at some point? I think it's reasonable . |
BCOO should just be used to construct it, not to use it for matrix vector multiplications...