
Support for torch.mm with Sparse Half Tensors? "addmm_sparse_cuda" not implemented for Half #41069

Closed
sbonner0 opened this issue Jul 7, 2020 · 10 comments
Labels
feature (A request for a proper, new feature) · module: half (Related to float16 half-precision floats) · module: sparse (Related to torch.sparse) · triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

Comments


sbonner0 commented Jul 7, 2020

Hi,

I am trying to perform sparse-dense matrix multiplication using half-precision tensors in PyTorch.

The following code:

import torch

# Dense half-precision matrix on the GPU
a = torch.randn(3, 2).half().cuda()

# 2x3 sparse COO matrix, also cast to half on the GPU
i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
b = torch.sparse.FloatTensor(i, v, torch.Size([2, 3])).half().cuda()

# Sparse x dense matmul: this line raises the error below
c = torch.spmm(b, a)

will produce this error: RuntimeError: "addmm_sparse_cuda" not implemented for 'Half'

Is there any way to solve this?

Environment

- PyTorch version: 1.5.0
- Is debug build: No
- CUDA used to build PyTorch: 10.2

- OS: Arch Linux
- GCC version: (GCC) 10.1.0
- CMake version: version 3.17.3

- Python version: 3.8
- Is CUDA available: Yes
- CUDA runtime version: 10.2.89
- GPU models and configuration: GPU 0: GeForce GTX 1080 Ti
- Nvidia driver version: 440.100
- cuDNN version: /usr/lib/libcudnn.so.7.6.5

cc @vincentqb @aocsa

@mruberry added the module: half, module: sparse, feature, and triaged labels Jul 7, 2020
mruberry (Collaborator) commented Jul 8, 2020

Thanks for filing this issue, @sbonner0. You can perform the operation with a float32 tensor, of course, but short of that I think you'd have to write your own kernel or get one added to PyTorch, unfortunately.
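
A minimal sketch of that float32 workaround, reusing the tensors from the snippet above (upcast, multiply, then downcast; whether the round trip is acceptable depends on your use case):

import torch

a = torch.randn(3, 2).half().cuda()
i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
b = torch.sparse.FloatTensor(i, v, torch.Size([2, 3])).half().cuda()

# Work around the missing half kernel: compute in float32, cast back to half.
c = torch.spmm(b.float(), a.float()).half()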

sbonner0 (Author) commented Jul 8, 2020

Hi @mruberry, thanks so much for your reply! Do you think it would be challenging to write a kernel for this and submit it to PyTorch? I would be very happy to give it a go.

@mruberry (Collaborator)

> Hi @mruberry, thanks so much for your reply! Do you think it would be challenging to write a kernel for this and submit it to PyTorch? I would be very happy to give it a go.

Excellent question! I actually looked at this a bit and cuSPARSE does support this operation in half (https://docs.nvidia.com/cuda/cusparse/index.html). So you'd need to edit the dispatch here:

to include half types, then get a system using the newer cuSPARSE:

#if !defined(_MSC_VER) && defined(__CUDACC__) && CUSPARSE_VERSION >= 10301 // CUDA release >= 10.2 and not windows

update the checks:

static_assert(std::is_same<float, T>::value || std::is_same<double, T>::value, "csrmm2 only supports float and double value types");

and instantiate the appropriate c10::Half template:

template void csrmm2<float>(

Then make sure cuSPARSE is being called properly and write a test in Python verifying the change works.
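
As a rough sketch, such a Python test might compare the half result against a float32 reference (the function name and tolerances here are illustrative, not PyTorch's actual test suite):

import torch

def test_spmm_half_cuda():
    # Random sparse COO matrix and dense matrix, both in half on the GPU.
    i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
    v = torch.FloatTensor([3, 4, 5])
    b = torch.sparse.FloatTensor(i, v, torch.Size([2, 3])).half().cuda()
    a = torch.randn(3, 2).half().cuda()

    result = torch.spmm(b, a)

    # float32 reference; loose tolerances absorb half-precision rounding.
    expected = torch.spmm(b.float(), a.float())
    assert torch.allclose(result.float(), expected, atol=1e-2, rtol=1e-2)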

If you have a machine with CUDA 10.2+ and are familiar with building projects like PyTorch and C++, it's definitely doable.

@sbonner0 (Author)

Hi @mruberry, thank you so much for this very detailed reply! It seems you have largely done the work already, but I can set aside some time at the end of next week to give it a go.
Given that this seems to require CUDA 10.2, if I am able to get it working, should I submit it as a pull request?

@mruberry (Collaborator)

Yep, it'd be great to get a PR implementing this!

sorenrasmussenai (Contributor) commented May 28, 2021

A note for anyone working on this (or my future self):

I have been fiddling with this issue, and I ran into some really strange behaviour when modifying the existing code to support float16. I have not had time to isolate the problem, but I believe it is due to a bug in cuSPARSE, specifically in the algorithm CUSPARSE_SPMM_CSR_ALG1 with CUDA_R_16F. Switching to CUSPARSE_SPMM_CSR_ALG2 makes the problem go away; note, however, that CUSPARSE_SPMM_CSR_ALG2 is non-deterministic, which may be a deal-breaker.

I have seen the bug manifest as noise in a single column (column 1 in my case) of the result matrix, while the remaining columns were correct. The noise would change depending on the results in (at least some of) the remaining columns. Unfortunately, I am not free to share the code.
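
A hedged way to check for that symptom is to compare the half-precision result column-by-column against a float32 reference (this assumes a patched build where the half kernel runs at all; shapes are illustrative):

import torch

i = torch.LongTensor([[0, 1, 1], [2, 0, 2]])
v = torch.FloatTensor([3, 4, 5])
b = torch.sparse.FloatTensor(i, v, torch.Size([2, 3])).half().cuda()
a = torch.randn(3, 64).half().cuda()

result = torch.spmm(b, a).float()
reference = torch.spmm(b.float(), a.float())

# Per-column max absolute error: a single corrupted column stands out.
col_err = (result - reference).abs().max(dim=0).values
print(col_err)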

@pearu pearu added this to To do in Sparse tensors Aug 10, 2021
@IvanYashchuk (Collaborator)

CSR matrix - dense matrix multiplication is now supported for float16:

In [1]: import torch

In [2]: a = torch.randn(3,2).half().cuda()
   ...: i = torch.LongTensor([[0, 1, 1],  [2, 0, 2]])
   ...: v = torch.FloatTensor([3, 4, 5])
   ...: b = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).half().cuda()

In [3]: b = b.to_sparse_csr()

In [4]: b @ a
Out[4]:
tensor([[-0.6729, -1.0430],
        [-1.8916,  0.8125]], device='cuda:0', dtype=torch.float16)

Sparse tensors automation moved this from To do to Done Jan 6, 2022
@cddavis93

I used the approach described above, but it yielded the same error as the original post. I tried to_sparse_csr() as well as to_sparse(). My PyTorch version is 1.11.0.

Is there any further documentation describing half-precision sparse matrix multiplication?

@ducdauge

@cddavis93 I replicated @IvanYashchuk's result with torch-1.11.0+cu113.

@puddingfjz

@IvanYashchuk, is cuSPARSE used for the `b @ a` in the following?

> CSR matrix - dense matrix multiplication is now supported for float16:
>
> In [1]: import torch
>
> In [2]: a = torch.randn(3,2).half().cuda()
>    ...: i = torch.LongTensor([[0, 1, 1],  [2, 0, 2]])
>    ...: v = torch.FloatTensor([3, 4, 5])
>    ...: b = torch.sparse.FloatTensor(i, v, torch.Size([2,3])).half().cuda()
>
> In [3]: b = b.to_sparse_csr()
>
> In [4]: b @ a
> Out[4]:
> tensor([[-0.6729, -1.0430],
>         [-1.8916,  0.8125]], device='cuda:0', dtype=torch.float16)

Where can I find the code for this part?
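
One way to check which backend kernel actually runs is to profile the call (a sketch using torch.profiler; the kernel names in the output depend on your build, so treat this as a diagnostic rather than a definitive answer):

import torch
from torch.profiler import profile, ProfilerActivity

a = torch.randn(3, 2).half().cuda()
b = torch.randn(2, 3).half().cuda().to_sparse_csr()

# List the CUDA kernels invoked by the sparse-dense matmul;
# cuSPARSE-backed calls typically contain "cusparse" in the name.
with profile(activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA]) as prof:
    c = b @ a
print(prof.key_averages().table(sort_by="cuda_time_total"))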
