Support for torch.mm with Sparse Half Tensors? "addmm_sparse_cuda" not implemented for Half #41069
Comments
Thanks for filing this issue, @sbonner0. You can perform the operation with a float32 tensor, of course, but short of that I think you'd have to write your own kernel or get one added to PyTorch, unfortunately.
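As an aside, the float32 workaround mentioned above follows a simple upcast-compute-downcast pattern. A minimal sketch with NumPy (used here only so the example runs without CUDA; in PyTorch the equivalent calls would be `.float()` and `.half()` on the tensors, and `b` would be sparse):

```python
import numpy as np

# Hypothetical half-precision operands (dense stand-ins for the sparse case).
a = np.random.randn(3, 2).astype(np.float16)
b = np.random.randn(2, 3).astype(np.float16)

# Workaround: upcast to float32, multiply, downcast the result to float16.
result = (b.astype(np.float32) @ a.astype(np.float32)).astype(np.float16)
print(result.dtype)  # float16
```

The downcast at the end keeps the output dtype consistent with the inputs, at the cost of the extra float32 copies.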
Hi @mruberry, thanks so much for your reply! Do you think it would be challenging to code a kernel to do this and submit it to PyTorch? I would be very happy to give it a go.
Excellent question! I actually looked at this a bit and cuSPARSE does support this operation in half (https://docs.nvidia.com/cuda/cusparse/index.html). So you'd need to edit the dispatch here:
to include half types, then get a system using the newer cuSPARSE:
update the checks:
and instantiate the appropriate c10::Half template:
Then make sure cuSPARSE is being called properly and write a test in Python verifying the change works. If you have a machine with CUDA 10.2+ and are familiar with building projects like PyTorch and with C++, it's definitely doable.
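For anyone attempting this, here is a rough sketch of the semantics such a Python test would verify: a COO sparse-dense product in float16, accumulated in float32 (as half-precision cuSPARSE kernels typically do internally). This uses plain NumPy triplets rather than torch tensors, so the names `rows`, `cols`, and `vals` are purely illustrative, not PyTorch API:

```python
import numpy as np

# COO representation of a 2x3 sparse matrix, values stored in float16.
rows = np.array([0, 1, 1])
cols = np.array([2, 0, 2])
vals = np.array([3, 4, 5], dtype=np.float16)

dense = np.random.randn(3, 2).astype(np.float16)

# Reference spmm: accumulate each output row in float32, then cast back.
out = np.zeros((2, 2), dtype=np.float32)
for r, c, v in zip(rows, cols, vals):
    out[r] += np.float32(v) * dense[c].astype(np.float32)
out = out.astype(np.float16)

# Cross-check against a dense matmul of the materialized matrix.
full = np.zeros((2, 3), dtype=np.float16)
full[rows, cols] = vals
expected = (full.astype(np.float32) @ dense.astype(np.float32)).astype(np.float16)
assert np.allclose(out, expected, atol=1e-2)
```

A real PyTorch test would instead compare `torch.sparse.mm` on half tensors against the float32 result, but the accuracy check is the same idea.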
Hi @mruberry, thank you so much for this very detailed reply! Although it seems you have largely done all the work already, I can set aside some time at the end of next week to try and give this a go.
Yep, it'd be great to get a PR implementing this!
A note for anyone working on this (or future self): I have been fiddling a bit with this issue, and I have run into some really weird behaviour when modifying the existing code to support float16. I have seen the bug manifest as noise in a single column (column 1 in my case) of the result matrix, while the remaining columns were correct. The noise would change depending on the results in (at least some of) the remaining columns. Unfortunately, I am not free to share the code.
CSR matrix - dense matrix multiplication is now supported for half precision:

In [1]: import torch

In [2]: a = torch.randn(3, 2).half().cuda()
   ...: i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
   ...: v = torch.FloatTensor([3, 4, 5])
   ...: b = torch.sparse.FloatTensor(i, v, torch.Size([2, 3])).half().cuda()

In [3]: b = b.to_sparse_csr()

In [4]: b @ a
Out[4]:
tensor([[-0.6729, -1.0430],
        [-1.8916,  0.8125]], device='cuda:0', dtype=torch.float16)
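For reference, the `to_sparse_csr()` call above compresses the COO triplets into a row-pointer layout. A sketch of that conversion for the exact matrix in the snippet, worked out with NumPy (`coo_to_csr` is just an illustrative helper, not a PyTorch function):

```python
import numpy as np

def coo_to_csr(rows, cols, vals, n_rows):
    """Convert COO triplets (assumed sorted by row) to CSR arrays."""
    crow = np.zeros(n_rows + 1, dtype=np.int64)
    for r in rows:
        crow[r + 1] += 1
    crow = np.cumsum(crow)  # row pointer: crow[i]..crow[i+1] spans row i
    return crow, np.asarray(cols), np.asarray(vals)

# Same 2x3 matrix as the snippet: entries (0,2)=3, (1,0)=4, (1,2)=5.
crow, col, val = coo_to_csr([0, 1, 1], [2, 0, 2], [3, 4, 5], n_rows=2)
print(crow)  # [0 1 3]
print(col)   # [2 0 2]
print(val)   # [3 4 5]
```

The row-pointer array is what lets cuSPARSE CSR routines walk each row's nonzeros without scanning the full index list.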
I used the approach described above, but it yielded the same error as the original post. Is there any further documentation that describes half-precision sparse MM?
@cddavis93 I replicated @IvanYashchuk's result with torch-1.11.0+cu113 |
@IvanYashchuk, is cuSPARSE used for the `b @ a` in
Where can I find the code for this part? |
Hi,
I am trying to perform sparse and dense matrix multiplication using half-precision tensors in PyTorch.
The following code:
will produce this error:
RuntimeError: "addmm_sparse_cuda" not implemented for 'Half'
Is there any way to solve this?
Environment
- PyTorch version: 1.5.0
- Is debug build: No
- CUDA used to build PyTorch: 10.2
- OS: Arch Linux
- GCC version: (GCC) 10.1.0
- CMake version: version 3.17.3
- Python version: 3.8
- Is CUDA available: Yes
- CUDA runtime version: 10.2.89
- GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
- Nvidia driver version: 440.100
- cuDNN version: /usr/lib/libcudnn.so.7.6.5
cc @vincentqb @aocsa