-
Notifications
You must be signed in to change notification settings - Fork 174
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Cache sparse matrix in the operator and experiment with dataclass operators #1752
base: master
Are you sure you want to change the base?
Conversation
|
||
|
||
class DiscreteJaxOperatorPytree(DiscreteJaxOperator, struct.Pytree): | ||
_hilbert: DiscreteHilbert = struct.field(pytree_node=False) | ||
|
||
@struct.property_cached(pytree_node=True) | ||
def _sparse(self) -> JAXSparse: | ||
return self._to_sparse() | ||
|
||
@wraps(DiscreteJaxOperator.to_sparse) | ||
def to_sparse(self) -> JAXSparse: | ||
return self._sparse |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Having an extra class here is kind of annoying, but I don't think it can be avoided without reqiring all discretejaxoperators to be dataclasses, or copypasting this every time
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## master #1752 +/- ##
===========================================
- Coverage 82.98% 51.11% -31.87%
===========================================
Files 303 303
Lines 18433 18445 +12
Branches 2718 2717 -1
===========================================
- Hits 15296 9429 -5867
- Misses 2465 8326 +5861
- Partials 672 690 +18 ☔ View full report in Codecov by Sentry. |
I believe that our current operators play nice with The lowest effort solution to support this use case would be to hand-write an This would be the simplest change that does not increase internal entropy of netket. --
I agree that it is annoying. If there was an easy way to support this I would be happy. However I feel this solution is complicated, and relies on several 'hidden assumptions' like that the pytree/class attributes have the same names as the fields that are set in the base constructors, which would be hard to document.... Dunno... |
Actually there are issues with the current one too. |
Right, you're indeed right. |
I've recently written a lot of operators which are dataclasses (because writing boilerplate tree_flatten/unflatten is annoying).
They are incompatible with vanilla netket as the current sparse matrix caching needs a hash, thus requiring some changes.
This moves the caching into the operators, in a way that is easy to override.
As an example I transformed IsingJax into a Pytree, to showcase how to do the caching for dataclass operators.