-
Notifications
You must be signed in to change notification settings - Fork 957
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
rfcs: multi-head attention #1745
base: rfcs
Are you sure you want to change the base?
Conversation
39ef34a
to
6621ab5
Compare
6621ab5
to
9fd36db
Compare
Hi @igorsafo, since GC MHA patterns have a lot of variation, what would such an API look like? Thanks! |
Since there is no primitive API planned there is no need for an additional API. Frameworks will build graphs using oneDNN Graph API and it will match the SDPA patterns and will create GC or oneDNN primitive-based implementation as a partition. |
Thanks for the info, @igorsafo! PyTorch graph path has now moved to matching patterns on the framework side. So, this approach requires oneDNN Graph patterns to be hard-coded, and a PyTorch op can be called to compile/execute partitions. Looks like this RFC wouldn't change our approach. |
This approach might result in a lot of issues in the future, so I would highly recommend to discuss it further with the PyTorch maintainers before we put a lot of effort on Graph API integration. Since Graph patterns contain many more ops (vs primitive + post_op) any change within a pattern will result in a miss. Also if we add an optimization to a new pattern this will not be catched by the framework pattern matcher until it is updated explicitely in the framework codebase. The main benefit of Graph API is that an application/framework can throw a graph into it and oneDNN will return optimized partitions which is not possible with oneDNN primitives. In primitive API user has to know what primitives and post ops are supported. A pattern matcher on the framework side removes this benefit. |
Hi @igorsafo, I agree that such an approach requires more work on the framework side since oneDNN Graph partitions need to be hardcoded in the framework with such an implementation. Meta would like to follow the approach of letting the framework decide which pattern to offload to another library, which runs contrary to oneDNN Graph's ease-of-use principles that allow full graph to be passed to oneDNN Graph, which can then ascertain which patterns it can support. |
Description
A link to the rendered document: link
Fixes # (github issue)
Checklist
General
make test
andmake test_benchdnn_*
) pass locally for each commit?Performance improvements
New features
Bug fixes
RFC PR