
use BCSR in sparse format of jaxoperators instead of BCOO #1804

Merged (3 commits) May 17, 2024

Conversation

PhilipVinc (Member):

BCOO should only be used to construct the operator, not for matrix-vector multiplications...
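The point above can be sketched with `jax.experimental.sparse` directly. This is a minimal illustration, not NetKet code: the matrix is built in the construction-friendly BCOO format and then converted to BCSR, which has faster matvec kernels.

```python
import jax.numpy as jnp
from jax.experimental import sparse

# Build the matrix in COO form, which is convenient for construction...
dense = jnp.array([[1.0, 0.0, 2.0],
                   [0.0, 3.0, 0.0],
                   [0.0, 0.0, 4.0]])
m_bcoo = sparse.BCOO.fromdense(dense)

# ...then convert to CSR, the format better suited to matvec.
m_bcsr = sparse.BCSR.from_bcoo(m_bcoo)

v = jnp.array([1.0, 1.0, 1.0])
result = m_bcsr @ v  # same result as dense @ v
```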

@PhilipVinc PhilipVinc requested a review from wdphy16 May 13, 2024 13:28
gcarleo (Member) commented May 13, 2024:

Any speedup?

wdphy16 (Collaborator) commented May 14, 2024:

Yes, I agree. Although with MCState we don't need to multiply a NetKet operator by a dense vector, when working with FullSumState it indeed speeds things up in my use case.

By the way, in FullSumState we can put the sparse-matrix/dense-vector multiplication inside JIT to speed it up further, at the cost of longer compile time (specializing on the sparsity structure can take a long time) and higher memory usage, as in my commit.
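A minimal sketch of that suggestion, under the assumption that the operator is available as a BCSR matrix: since JAX sparse arrays are registered pytrees, they can be passed straight into a jitted function. The `matvec` name is hypothetical, chosen for illustration.

```python
import jax
import jax.numpy as jnp
from jax.experimental import sparse


@jax.jit
def matvec(op, psi):
    # op is a BCSR sparse matrix, psi a dense state vector;
    # jit compiles the sparse matvec kernel once per shape.
    return op @ psi


op = sparse.BCSR.fromdense(2.0 * jnp.eye(4))
psi = jnp.ones(4)
out = matvec(op, psi)
```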

Excerpt from the diff under discussion:

        or isinstance(other, jnp.ndarray)
        or isinstance(other, JAXSparse)
    ):
        if isinstance(other, JAXSparse):
PhilipVinc (Member Author):

@wdphy16 the only questionable thing is that JAX sparse matrices do not support multiplication by another JAX CSR sparse matrix, so I had to ensure that when one tries to compute jaxoperator @ sparsevector, we convert the sparse vector to dense first.

This breaks the current behaviour where, since we used BCOO, we got a sparse vector out.

But I think that's not important? Hopefully one never does this kind of multiplication...
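The behaviour change can be sketched as follows. The `_matmul` helper is hypothetical, written only to mirror the `isinstance(other, JAXSparse)` check in the diff: a sparse operand is densified before the product, so the result is dense rather than the sparse vector the old BCOO path returned.

```python
import jax.numpy as jnp
from jax.experimental import sparse
from jax.experimental.sparse import JAXSparse


def _matmul(op, other):
    # Hypothetical helper mirroring the check in the PR:
    # BCSR cannot multiply another sparse array, so densify it first.
    if isinstance(other, JAXSparse):
        other = other.todense()
    return op @ other


op = sparse.BCSR.fromdense(jnp.eye(2))
v_sparse = sparse.BCOO.fromdense(jnp.array([1.0, 0.0]))
out = _matmul(op, v_sparse)  # dense result, unlike the old BCOO path
```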

wdphy16 (Collaborator):

Ah, it's a bit unexpected to me, but anyway JAX sparse matrices are experimental, so we can make breaking changes.

In real use cases, I guess we usually first do the NetKet operator multiplications and then convert the result to a JAX sparse matrix.

PhilipVinc (Member Author):

@wdphy16 how bad is compile time in that case?

wdphy16 (Collaborator) commented May 14, 2024:

> @wdphy16 how bad is compile time in that case?

I don't remember exactly, but for large experiments it's much less than the actual computation time, and for small experiments I didn't notice much extra lag.

PhilipVinc merged commit b841d6a into netket:master on May 17, 2024 (8 of 10 checks passed).
PhilipVinc (Member Author):

@wdphy16 since the improvement you propose makes sense, would you care to make a PR at some point? I think it's reasonable.

PhilipVinc deleted the pv/bscr branch on May 17, 2024 at 17:27.