[torch] Add Canonicalize Pattern for embedding op #3277

Open · pashu123 wants to merge 1 commit into main
Conversation

@pashu123 (Member) commented May 2, 2024

Converts PrimConvertOp followed by Embedding into Embedding followed by PrimConvertOp. We don't need to cast the entire embedding matrix, only the output of the embedding op.

Issue: iree-org/iree#17226 (comment)
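A minimal PyTorch sketch of what the rewrite means semantically (the actual change is an MLIR canonicalization pattern, not Python; the shapes and dtypes below are made up for illustration):

```python
import torch
import torch.nn.functional as F

weight = torch.randn(1000, 64, dtype=torch.float16)  # embedding table
ids = torch.randint(0, 1000, (8,))                    # lookup indices

# Before the canonicalization: convert the whole table, then gather.
before = F.embedding(ids, weight.to(torch.float32))

# After the canonicalization: gather first, convert only the gathered rows.
after = F.embedding(ids, weight).to(torch.float32)

# Widening after the gather produces the same values with far less casting work.
assert torch.equal(before, after)
```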

@rsuderman (Contributor) commented

I can see the benefit of this optimization; however, it is more about avoiding the compilation issue we have been encountering than preventing the crash.

@benvanik (Contributor) commented May 7, 2024

Note this could also be a pessimization: if you have your embeddings as f32, gather them, and convert to f16, you really want the conversion to fold into the embeddings so you aren't shipping (and paying the memory transactions for) f32 when you don't need those bits. This may get taken care of later in the pipeline, but it's important to note that things like this have massive implications (it's always better to hoist narrowing operations and sink widening operations, and almost never the opposite).
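Back-of-the-envelope illustration of the memory-traffic point for the narrowing case (the numbers are made up, not taken from the PR or any measurement):

```python
# f32 table narrowed to f16, then gathered for a batch of lookups.
batch, dim = 32, 4096
f16_bytes, f32_bytes = 2, 4

# Narrowing hoisted into the table: the gather reads f16 rows.
gathered_if_hoisted = batch * dim * f16_bytes  # 256 KiB
# Narrowing sunk below the embedding: the gather reads f32 rows, then truncates.
gathered_if_sunk = batch * dim * f32_bytes     # 512 KiB

print(gathered_if_sunk / gathered_if_hoisted)  # 2.0x the bytes moved per lookup
```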

@rsuderman self-requested a review on May 7, 2024

@rsuderman (Contributor) left a comment

To avoid hurting performance, we should only perform the swap in the widening case; otherwise we are potentially loading more data just to truncate it back down, whereas truncating up front is beneficial overall.
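A hedged sketch of the check being asked for here, written in Python for readability (the real guard would live in the canonicalization pattern itself; the helper names below are hypothetical):

```python
import torch

def _bits(dtype: torch.dtype) -> int:
    # Element bit width for float and integer dtypes.
    info = torch.finfo(dtype) if dtype.is_floating_point else torch.iinfo(dtype)
    return info.bits

def should_sink_convert(weight_dtype: torch.dtype, target_dtype: torch.dtype) -> bool:
    # Widening (e.g. f16 -> f32): swap, so the gather reads the narrow rows
    # and only the gathered output is widened.
    # Narrowing (e.g. f32 -> f16): don't swap, so the truncation stays on the
    # table and the gather reads fewer bytes.
    return _bits(target_dtype) > _bits(weight_dtype)

assert should_sink_convert(torch.float16, torch.float32)       # widening: swap
assert not should_sink_convert(torch.float32, torch.float16)   # narrowing: keep as-is
```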

@pashu123 (Member, Author) commented May 8, 2024

> To avoid hurting performance, we should only perform the swap in the widening case; otherwise we are potentially loading more data just to truncate it back down, whereas truncating up front is beneficial overall.

There's a tradeoff between memory and compute: doing this might use more memory but is less compute-intensive, whereas the suggested approach could be more compute-intensive since we are not able to fuse both kernels in the backend. I will add the check to perform the swap only in the widening case.
