
Enable A^T GEMM for BF16 #781

Open
itaraban opened this issue Jun 14, 2023 · 3 comments

Comments

@itaraban

Could you please enable this case (A-transposed, i.e., TN) in GEMM for BF16?
This is related to optimizations in the DGL project.

@hfp
Collaborator

hfp commented Jun 14, 2023

Is there a source file in DGL referring to what's needed or a model/case that runs into this gap?

@itaraban
Author

In the RGCN model, DGL uses the segment_mm operation, which is currently implemented via torch (so it is very slow on CPU right now).
I prepared a branch with a LIBXSMM implementation: itaraban/dgl@cc03905
It uses the same algorithm as the CUDA version: https://github.com/dmlc/dgl/blob/master/src/array/cuda/gather_mm.cu#L202

It works well for float and double; the performance gain is up to 3x for the full model.
But I cannot use it for BF16: I hit https://github.com/libxsmm/libxsmm/blob/main_stable/src/generator_gemm.c#L344.
In that case we still have to use torch for BF16, and the model is 2x slower than the float version.

@hfp
Collaborator

hfp commented Jun 14, 2023

I checked: we do not have a BF16 TN case, or at least it is not tested/exercised, i.e., this seems to be a valid issue (besides TN/A-transpose being an unfortunate case in the hot path ;-).
