linear+gelu fused operator is not supported in ACL #1083
Hi @snadampal, thanks for raising this. We will discuss the feature request with the team.
oneDNN shouldn't be falling back to the reference kernels. Also, do we know the relative importance of the different activations and data types? I haven't done any in-depth analysis, but for compute-bound activations like gelu or tanh, there may not be much benefit to fusing them over having a separate activation layer. For the simpler memory-bound activations, there should be a larger benefit. I think non-leaky relu (α = 0) is already fused into quite a few kernels, although as far as I know, leaky relu is not yet.
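For context on why gelu is comparatively compute-bound: unlike relu's single compare-and-select, every element requires an erf (or tanh-approximation) evaluation. A small illustration of the two common GELU formulations (this is my own sketch, not code from this thread):

```python
import math

def gelu_erf(x: float) -> float:
    # Exact GELU: 0.5 * x * (1 + erf(x / sqrt(2)))
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def gelu_tanh(x: float) -> float:
    # Common tanh approximation:
    # 0.5 * x * (1 + tanh(sqrt(2/pi) * (x + 0.044715 * x^3)))
    c = math.sqrt(2.0 / math.pi)
    return 0.5 * x * (1.0 + math.tanh(c * (x + 0.044715 * x ** 3)))

for x in (-2.0, -0.5, 0.0, 0.5, 2.0):
    print(f"{x:5.1f}  erf={gelu_erf(x):.6f}  tanh={gelu_tanh(x):.6f}")
```

The two variants agree to within a few 1e-4 over typical activation ranges, which is why backends often expose both a gelu_erf and a gelu_tanh eltwise algorithm.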
Hi @jondea, I'm not sure if the gap in ACL is only the fused kernel, or the individual kernels as well.

For fp32:

For bf16 fast math mode:
Great, thanks for the reproducer. It looks like ACL does in fact have a GELU implementation, at least for NEON FP32.
It should be straightforward to hook this up, and it will automatically get picked up. I have made an internal issue to take a look at this. Things are quite busy at the moment, so I'll need to get back to you on timescales.
@snadampal we now have a PR up for ACL GELU erf in oneDNN: oneapi-src/oneDNN#1843. This should enable ACL primitives (including inner product) to be used when there's a GELU erf post op. This isn't a fusion in the sense that the activation happens inside the GEMM kernel, but it does mean that you can make use of the ACL accelerated kernels when there are post ops in oneDNN. |
Thanks for the note, @jondea. I will take a look at it.
Hi @jondea, how about fusion support for the other primitive and post-op combinations? Could you please add support for matmul + post-ops like gelu/relu/erf/tanh as well?
At the oneDNN level, we should automatically support combining matmul/conv/inner product with any binary or eltwise post op supported by the equivalent standalone ACL primitive. So matmul/conv/inner product + gelu/relu/erf/tanh should be accelerated by ACL in oneDNN (GELU went into v3.5).
Output of 'strings libarm_compute.so | grep arm_compute_version':
arm_compute_version=v23.11 Build options: {'Werror': '0', 'debug': '0', 'neon': '1', 'opencl': '0', 'embed_kernels': '0', 'os': 'linux', 'arch': 'armv8a', 'build': 'native', 'multi_isa': '1', 'fixed_format_kernels': '1', 'openmp': '1', 'cppthreads': '0'} Git hash=b'add70ace1e57f65d1ae4d0cedaec6e4578cf87ff'
Platform:
AWS c7g.16xl
Operating System:
Ubuntu 22.04
Problem description:
PyTorch 2.0 introduced torch.compile() for neural network compilation. One of the important techniques graph compilation employs is operator fusion, and to execute the compiled graphs efficiently, the platform needs to support the fused operators. For example, for the BERT base model (and, I think, any transformer model), inner_product+relu and matmul+relu (or gelu or tanh) are commonly fused in the linear layer. The issue is that ACL 23.11 doesn't support the above-mentioned fused operators, hence we are not able to take full advantage of the PyTorch graph compilation optimizations on aarch64.
Steps to reproduce:
When you run the below script, you can see that the fused operators fall back to the oneDNN 'c' reference kernels because ACL doesn't support them.
pip3 install torch==2.1.1
export DNNL_VERBOSE=1
Note: on PyTorch main, I have disabled the operator fusion for aarch64 so that at least the other compilation optimizations can be used; here is the PR. So, please use PyTorch 2.1.1 to reproduce the issue.