Tenalg einsum backend: cache contraction equation #459
Conversation
Move unfolding_dot_khatri_rao to tenalg, add einsum version
Codecov Report
@@ Coverage Diff @@
## main #459 +/- ##
==========================================
- Coverage 86.84% 86.39% -0.46%
==========================================
Files 118 119 +1
Lines 7313 7357 +44
==========================================
+ Hits 6351 6356 +5
- Misses 962 1001 +39
Thanks for adding this @JeanKossaifi, I will try to look at it asap but that's slightly outside my expertise.
    try:
        equation = cache[key]
    except KeyError:
        equation = fun(*args, **kwargs)
        cache[key] = equation
Is there a reason why you use try/except instead of checking if the key is in the cache beforehand? Is that more efficient?
Yes, fetching the item directly is a single O(1) dictionary lookup, whereas checking for the key first and then fetching it performs two lookups. The overhead of the try/except is minimal and only happens at the first call (the cache miss), so the overall cost is a single O(1) lookup per call.
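To make the trade-off concrete, here is a hypothetical illustration of the EAFP ("easier to ask forgiveness than permission") caching pattern being discussed; the names (`get_or_compute`, `cache`) are made up for the example and are not taken from the PR:

```python
# A minimal sketch of the try/except caching idiom: one dict lookup on a
# hit, and the compute function runs only once per key.
cache = {}

def get_or_compute(key, compute):
    """Return the cached value for key, computing and storing it on a miss."""
    try:
        return cache[key]          # single dict lookup on every hit
    except KeyError:
        value = compute()          # only runs on the first call per key
        cache[key] = value
        return value

calls = []
get_or_compute("ij,jk->ik", lambda: calls.append(1) or "equation")
result = get_or_compute("ij,jk->ik", lambda: calls.append(1) or "equation")
```

The alternative, `if key in cache: return cache[key]`, is also O(1) per lookup but touches the dictionary twice on every hit; the try/except version pays its (small) exception cost only on the miss.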
    @einsum_path_cached
    def inner_path(tensor1, tensor2, n_modes=None):
Will that not set key=tensor1? Will that give the correct behaviour for the cache?
No, the key is only used inside the wrapper to retrieve the cached version; it is not actually passed to the wrapped function. I need to document that clearly.
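One possible reading of this comment, sketched under my own assumptions (this is not TensorLy's actual implementation, and `make_equation` is a hypothetical stand-in for the real equation-building function): the wrapper consumes an explicit key argument for the cache and forwards only the remaining arguments to the wrapped function.

```python
from functools import wraps

def einsum_path_cached(fun):
    """Sketch of a caching decorator: the first positional argument is the
    cache key, which the wrapper consumes and does NOT pass on to fun."""
    cache = {}

    @wraps(fun)
    def wrapper(key, *args, **kwargs):
        try:
            equation = cache[key]
        except KeyError:
            equation = fun(*args, **kwargs)  # key is not forwarded
            cache[key] = equation
        return equation

    return wrapper

@einsum_path_cached
def make_equation(n_modes):
    # Stand-in for the real equation generation: join one letter per mode.
    return ",".join("xyzw"[:n_modes])

eq1 = make_equation("key-3", 3)   # miss: make_equation(3) runs
eq2 = make_equation("key-3", 3)   # hit: returned straight from the cache
```

Under this reading, the reviewer's concern (that `key` silently becomes `tensor1`) only arises if a caller forgets that the decorated function's signature gained an extra leading argument, which is exactly the kind of thing worth documenting.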
I did a quick pass on some of the files (nothing big, didn't download and run the code) :)
Thanks @yngvem, great points, as always! :)
Co-authored-by: Yngve Mardal Moe <yngve.m.moe@gmail.com>
So, update: I don't actually see any speedup from caching the einsum equation, so I'm not sure whether we want to merge this feature anyway. Any thoughts?
So if I understood the PR correctly, what you are doing is storing the contraction path in a global dictionary. While this could be interesting for further improvements, this PR alone will not lead to a significant speed-up: what is costly is not computing the einsum path, but actually computing the contractions along that path. Therefore what we would need to cache is rather the intermediate results of the einsum. But I do not believe we have easy access to these partial results, so maybe this PR is indeed not useful as such (although the code structure could be useful to start another PR where we cache contraction results along a path; then we would also need to store the path, so we would reuse this code).

No particular comment on the code itself, but maybe the docs are not so clear, since it took me some time to understand what was going on :p

edit: In fact this is useful to impose a specific path, as done in #462 with einsum-opt, so it could also lead to speed-ups that way, I guess. I would be curious to see your tests; maybe you do not see a speedup because the naive path is already close to optimal in your experiments?
@cohenjer I was talking about caching the contraction equation here, not the contraction path. The way I wrote the einsum tenalg backend, I first check the validity of the operation (e.g. mttkrp) and then programmatically generate the corresponding contraction equation. The idea of this PR is that it is redundant to perform these checks and regenerate the equation at each call. However, the gain seems to be pretty much non-existent, so this may be over-engineering. Caching the optimal contraction path, on the other hand (#462), always helps.
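To illustrate what "programmatically generating the contraction equation" means here (a sketch under my own assumptions; `inner_equation` is a hypothetical helper, not TensorLy's actual code), consider building the einsum string for an inner product that contracts the last `n_modes` of one tensor with the first `n_modes` of another:

```python
import numpy as np

def inner_equation(ndim1, ndim2, n_modes):
    """Build an einsum equation contracting the last n_modes of the first
    operand against the first n_modes of the second."""
    letters = "abcdefghijklmnopqrstuvwxyz"
    s1 = letters[:ndim1]
    shared = s1[ndim1 - n_modes:]                       # contracted modes
    s2 = shared + letters[ndim1:ndim1 + ndim2 - n_modes]
    out = s1[:ndim1 - n_modes] + s2[n_modes:]           # surviving modes
    return f"{s1},{s2}->{out}"

eq = inner_equation(3, 3, 2)        # "abc,bcd->ad"
t1 = np.random.rand(2, 3, 4)
t2 = np.random.rand(3, 4, 5)
res = np.einsum(eq, t1, t2)
```

Generating such a string is a handful of slices and joins, which is why caching it (as this PR does) saves almost nothing compared to the cost of the contraction itself.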
Adds a wrapper that caches the contraction equation and reuses it on subsequent calls.