New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
[torch] Add Canonicalize Pattern for embedding op #3277
base: main
Are you sure you want to change the base?
Conversation
Converts PrimConvertOp followed by Embedding -> Embedding followed by PrimConvertOp. We don't need to cast the entire matrix; just the output of the embedding op.
I can see benefit to this optimization however this is more avoiding the compilation issue we have been encountering rather than preventing the crash. |
Note this could also be a pessimization: if you have your embeddings as f32, gather them, and convert to f16 you really want the conversion to fold into the embeddings so you aren't shipping (and doing the memory transactions) on f32 if you don't need those bits. This may get taken care of later in the pipeline but it's important to note that there are some massive implications of things like this (it's always better to hoist narrowing operations and sink widening operations, almost never the opposite). |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
To avoid hurting performance we should only perform the swap during the widening case, otherwise we are potentially loading more data just to truncate back down whereas there is benefit to truncating overall.
There's a tradeoff between memory and compute; doing this might take more memory but is less compute-intensive, whereas the one suggested might be compute-intensive since we are not able to fuse both kernels at the backend. I will add the check to perform a swap only during the widening case. |
Converts PrimConvertOp followed by Embedding -> Embedding followed by PrimConvertOp. We don't need to cast the entire matrix; just the output of the embedding op.
Issue: iree-org/iree#17226 (comment)