Cache sparse matrix in the operator and experiment with dataclass operators #1752

Draft · wants to merge 3 commits into master
Conversation

@inailuig (Collaborator) commented Mar 7, 2024

I've recently written a lot of operators which are dataclasses (because writing boilerplate tree_flatten/unflatten is annoying).

They are incompatible with vanilla netket as the current sparse matrix caching needs a hash, thus requiring some changes.

This moves the caching into the operators, in a way that is easy to override.
As an example I transformed IsingJax into a Pytree, to showcase how to do the caching for dataclass operators.
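For reference, this is roughly the hand-written registration I mean; a minimal sketch with made-up names (ManualOperator, hilbert, J), not code from this PR:

import jax
import jax.numpy as jnp


@jax.tree_util.register_pytree_node_class
class ManualOperator:
    """Toy operator registered as a pytree by hand (illustrative only)."""

    def __init__(self, hilbert, J):
        self.hilbert = hilbert  # static metadata (must be hashable)
        self.J = J              # array data traced by jax

    # The boilerplate that a dataclass/Pytree base class generates for you:
    def tree_flatten(self):
        children = (self.J,)        # dynamic leaves
        aux_data = (self.hilbert,)  # static auxiliary data
        return children, aux_data

    @classmethod
    def tree_unflatten(cls, aux_data, children):
        return cls(*aux_data, *children)


leaves, treedef = jax.tree_util.tree_flatten(ManualOperator("spin-1/2", jnp.ones(4)))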

Comment on lines +279 to +290


class DiscreteJaxOperatorPytree(DiscreteJaxOperator, struct.Pytree):
    _hilbert: DiscreteHilbert = struct.field(pytree_node=False)

    # Cache the sparse representation; pytree_node=True keeps the cached
    # value as part of the pytree rather than as static metadata.
    @struct.property_cached(pytree_node=True)
    def _sparse(self) -> JAXSparse:
        return self._to_sparse()

    @wraps(DiscreteJaxOperator.to_sparse)
    def to_sparse(self) -> JAXSparse:
        return self._sparse
inailuig (Collaborator, Author) commented:


Having an extra class here is kind of annoying, but I don't think it can be avoided without requiring all DiscreteJaxOperators to be dataclasses, or copy-pasting this every time.

@inailuig requested a review from @PhilipVinc on March 7, 2024, 10:03

codecov bot commented Mar 7, 2024

Codecov Report

Attention: Patch coverage is 45.45455%, with 24 lines in your changes missing coverage. Please review.

Project coverage is 51.11%. Comparing base (eb2be3d) to head (489c05b).

File                                         Patch %   Missing lines
netket/operator/_discrete_operator_jax.py    50.00%    11 ⚠️
netket/operator/_ising/jax.py                10.00%     9 ⚠️
netket/operator/_discrete_operator.py         0.00%     2 ⚠️
netket/operator/__init__.py                   0.00%     1 ⚠️
netket/vqs/full_summ/expect.py               66.66%     1 ⚠️
Additional details and impacted files
@@             Coverage Diff             @@
##           master    #1752       +/-   ##
===========================================
- Coverage   82.98%   51.11%   -31.87%     
===========================================
  Files         303      303               
  Lines       18433    18445       +12     
  Branches     2718     2717        -1     
===========================================
- Hits        15296     9429     -5867     
- Misses       2465     8326     +5861     
- Partials      672      690       +18     


@PhilipVinc (Member) commented:

> They are incompatible with vanilla netket as the current sparse matrix caching needs a hash

I believe that our current operators play nice with lru_cache because they don't define a __hash__ method, so they are automatically hashable thanks to the default implementation (which is just their pointer address, but ok...)

The lowest-effort solution to support this use case would be to hand-write a __hash__ method.
Alternatively, you could change whatever struct.dataclass implementation you're using to have some way of not defining the __hash__ method.
This should be simple, just a matter of piping the arguments down (probably setting hash=False and eq=False or something in dataclasses.dataclass() as called from nk.utils.struct.dataclass).

This would be the simplest change that does not increase internal entropy of netket.
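Roughly along these lines; a sketch with made-up classes (MyOperator, MyOtherOperator), and nk.utils.struct.dataclass may forward the flags differently:

import dataclasses

import numpy as np


# Option 1: hand-write __hash__ on an otherwise normal dataclass, mirroring
# the identity-based default.  (An identity hash next to the generated
# value-based __eq__ is enough for identity-keyed caches like lru_cache.)
@dataclasses.dataclass
class MyOperator:
    coupling: np.ndarray

    def __hash__(self):
        return id(self)


# Option 2: forward eq=False so that dataclasses leaves the default
# identity-based __hash__ untouched (the "piping the arguments down" idea).
@dataclasses.dataclass(eq=False)
class MyOtherOperator:
    coupling: np.ndarray


a, b = MyOtherOperator(np.ones(3)), MyOtherOperator(np.ones(3))
assert hash(a) != hash(b)  # hashable, keyed on identity rather than value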

--

> (because writing boilerplate tree_flatten/unflatten is annoying).

I agree that it is annoying. If there were an easy way to support this, I would be happy.

However, I feel this solution is complicated and relies on several 'hidden assumptions', e.g. that the pytree/class attributes have the same names as the fields set in the base constructors, which would be hard to document...

Dunno...

@inailuig (Collaborator, Author) commented Mar 7, 2024

> I believe that our current operators play nice with lru_cache because they don't define a __hash__ method, so they are automatically hashable thanks to the default implementation (which is just their pointer address, but ok...)

Actually, there are issues with the current one too.
lru_cache erroneously returns the cached value for some in-place modified operators (__imul__ and the like) when the hash of the object does not change.
E.g. if you use one Hamiltonian for, say, ExactSampler, then modify it in place and use it again, it will use the old matrix. This is what I tried to fix in the last two commits, but without much success...
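A minimal stand-alone illustration of the problem (ToyOperator and to_sparse_cached are made-up stand-ins, not NetKet code):

from functools import lru_cache

import numpy as np
import scipy.sparse


class ToyOperator:
    """Stand-in operator relying on the default identity-based __hash__."""

    def __init__(self, diag):
        self.diag = diag

    def to_dense(self):
        return np.diag(self.diag)


@lru_cache
def to_sparse_cached(op):
    # keyed on the operator object itself, i.e. effectively on id(op)
    return scipy.sparse.csr_matrix(op.to_dense())


op = ToyOperator(np.array([1.0, 2.0]))
first = to_sparse_cached(op)

op.diag *= 3.0                 # in-place modification: hash(op) is unchanged
second = to_sparse_cached(op)  # cache hit, returns the stale matrix
assert (second.toarray() == first.toarray()).all()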

@PhilipVinc (Member) commented:

Right, you're indeed correct.
I think the best solution to fix the existing use cases would be to compute and cache a correct hash, and simply invalidate the cached value when needed.
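For example, something along these lines (again a made-up stand-in, not NetKet code):

import numpy as np
import scipy.sparse


class ToyOperator:
    """Cache the sparse matrix together with a fingerprint of the data it was
    built from, and rebuild it whenever the fingerprint changes."""

    def __init__(self, diag):
        self.diag = diag
        self._sparse_cache = None
        self._sparse_cache_key = None

    def _fingerprint(self):
        # cheap content hash of the mutable data (illustrative choice)
        return hash(self.diag.tobytes())

    def to_sparse(self):
        key = self._fingerprint()
        if self._sparse_cache is None or key != self._sparse_cache_key:
            self._sparse_cache = scipy.sparse.diags(self.diag).tocsr()
            self._sparse_cache_key = key
        return self._sparse_cache


op = ToyOperator(np.array([1.0, 2.0]))
a = op.to_sparse()
op.diag *= 3.0          # in-place modification changes the fingerprint
b = op.to_sparse()      # cache is rebuilt instead of returning the stale matrix
assert not np.array_equal(a.toarray(), b.toarray())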
