Can't use gaussian_blur if sigma is a tensor on gpu #8401

Open
Xact-sniper opened this issue May 1, 2024 · 2 comments

Xact-sniper commented May 1, 2024

🐛 Describe the bug

Admittedly perhaps an unconventional use, but I'm using gaussian_blur in my model to blur attention maps and I want to have the sigma be a parameter.

It would work, except for this function:

def _get_gaussian_kernel1d(kernel_size: int, sigma: float) -> Tensor:

x is not moved to the device that sigma is on.

I believe it is like this in all torchvision versions.

WORKS:

import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True)

blurred = gaussian_blur(torch.randn(1, 3, 256, 256), k, [s])
blurred.mean().backward()
print(s.grad)
>>> tensor(-4.6193e-05)

DOES NOT:

import torch
from torchvision.transforms.functional import gaussian_blur
k = 15
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')

blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
blurred.mean().backward()
print(s.grad)

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
D:\Temp\ipykernel_39000\3525683463.py in <module>
      4 s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad = True, device='cuda')
      5 
----> 6 blurred = gaussian_blur(torch.randn(1, 3, 256, 256, device='cuda'), k, [s])
      7 blurred.mean().backward()
      8 print(s.grad)

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\functional.py in gaussian_blur(img, kernel_size, sigma)
   1361         t_img = pil_to_tensor(img)
   1362 
-> 1363     output = F_t.gaussian_blur(t_img, kernel_size, sigma)
   1364 
   1365     if not isinstance(img, torch.Tensor):

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in gaussian_blur(img, kernel_size, sigma)
    749 
    750     dtype = img.dtype if torch.is_floating_point(img) else torch.float32
--> 751     kernel = _get_gaussian_kernel2d(kernel_size, sigma, dtype=dtype, device=img.device)
    752     kernel = kernel.expand(img.shape[-3], 1, kernel.shape[0], kernel.shape[1])
    753 

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel2d(kernel_size, sigma, dtype, device)
    736     kernel_size: List[int], sigma: List[float], dtype: torch.dtype, device: torch.device
    737 ) -> Tensor:
--> 738     kernel1d_x = _get_gaussian_kernel1d(kernel_size[0], sigma[0]).to(device, dtype=dtype)
    739     kernel1d_y = _get_gaussian_kernel1d(kernel_size[1], sigma[1]).to(device, dtype=dtype)
    740     kernel2d = torch.mm(kernel1d_y[:, None], kernel1d_x[None, :])

s:\anaconda3\envs\base_pip\lib\site-packages\torchvision\transforms\_functional_tensor.py in _get_gaussian_kernel1d(kernel_size, sigma)
    727 
    728     x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
--> 729     pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    730     kernel1d = pdf / pdf.sum()
    731 

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
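
Until the helper is changed, one way to sidestep the error from user code is to build the kernel yourself on the image's device and apply it with a grouped convolution. The following is only a rough sketch under stated assumptions: `blur_with_tensor_sigma` is a hypothetical name, it uses a single sigma for both axes, it expects a 4D `(N, C, H, W)` float tensor and an odd `kernel_size`, and it is not guaranteed to match torchvision's output exactly.

```python
import torch
import torch.nn.functional as F
from torch import Tensor


def blur_with_tensor_sigma(img: Tensor, kernel_size: int, sigma: Tensor) -> Tensor:
    """Hypothetical helper (not part of torchvision): Gaussian blur with a tensor sigma."""
    half = (kernel_size - 1) * 0.5
    # Create the sample grid directly on the image's device and dtype,
    # so it never ends up on the CPU while sigma lives on the GPU.
    x = torch.linspace(-half, half, steps=kernel_size, device=img.device, dtype=img.dtype)
    pdf = torch.exp(-0.5 * (x / sigma.to(img.device)).pow(2))
    k1d = pdf / pdf.sum()
    # Assumption: the same sigma is used horizontally and vertically.
    k2d = torch.outer(k1d, k1d)
    channels = img.shape[-3]
    # One copy of the kernel per channel for a depthwise (grouped) convolution.
    weight = k2d.expand(channels, 1, kernel_size, kernel_size)
    pad = kernel_size // 2
    padded = F.pad(img, [pad, pad, pad, pad], mode="reflect")
    return F.conv2d(padded, weight, groups=channels)


device = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(0.3 * ((5 - 1) * 0.5 - 1) + 0.8, requires_grad=True, device=device)
img = torch.randn(1, 3, 64, 64, device=device)
blurred = blur_with_tensor_sigma(img, 15, s)
blurred.pow(2).mean().backward()  # gradient flows back to sigma
```

Because every tensor is created on `img.device`, the same code runs on CPU and GPU, and `s.grad` is populated after the backward pass.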

I don't know about the convention, like whether device should be passed in, but the simplest fix I believe would just be to change:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
to:
728 x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size).to(sigma.device)

Actually, that won't work when sigma is just a float, so there would need to be a check for whether it's a float or a float tensor.
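
A minimal sketch of what such a check could look like is below. It is an illustration, not the actual torchvision code: `gaussian_kernel1d` is a hypothetical stand-in for the private `_get_gaussian_kernel1d`, and the `ksize_half` computation is assumed because it sits just above the lines shown in the traceback.

```python
from typing import Union

import torch
from torch import Tensor


def gaussian_kernel1d(kernel_size: int, sigma: Union[float, Tensor]) -> Tensor:
    """Hypothetical device-aware variant of the 1D kernel helper."""
    # Assumption: half-width of the sample grid, centred on zero.
    ksize_half = (kernel_size - 1) * 0.5
    if isinstance(sigma, Tensor):
        # Tensor sigma: build x on the same device so the division below works.
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size, device=sigma.device)
    else:
        # Plain float sigma: keep the existing behaviour (default device).
        x = torch.linspace(-ksize_half, ksize_half, steps=kernel_size)
    pdf = torch.exp(-0.5 * (x / sigma).pow(2))
    return pdf / pdf.sum()


device = "cuda" if torch.cuda.is_available() else "cpu"
s = torch.tensor(1.1, requires_grad=True, device=device)
kernel_from_tensor = gaussian_kernel1d(15, s)   # stays on sigma's device
kernel_from_float = gaussian_kernel1d(15, 1.1)  # unchanged float path
```

The alternative raised above, passing `device` into the helper explicitly, would avoid the `isinstance` check at the cost of changing the helper's signature.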

Versions

[pip3] efficientunet-pytorch==0.0.6
[pip3] ema-pytorch==0.4.5
[pip3] flake8==6.0.0
[pip3] mypy-extensions==0.4.3
[pip3] numpy==1.24.3
[pip3] numpydoc==1.4.0
[pip3] pytorch-msssim==1.0.0
[pip3] siren-pytorch==0.1.7
[pip3] torch==2.2.2+cu118
[pip3] torch-cluster==1.6.0+pt113cu116
[pip3] torch_geometric==2.4.0
[pip3] torch-scatter==2.1.0+pt113cu116
[pip3] torch-sparse==0.6.16+pt113cu116
[pip3] torch-spline-conv==1.2.1+pt113cu116
[pip3] torch-tools==0.1.5
[pip3] torchaudio==2.2.2+cu118
[pip3] torchbearer==0.5.3
[pip3] torchmeta==1.8.0
[pip3] torchvision==0.17.2+cu118
[pip3] uformer-pytorch==0.0.8
[pip3] vit-pytorch==1.5.0
[conda] blas 1.0 mkl
[conda] efficientunet-pytorch 0.0.6 pypi_0 pypi
[conda] ema-pytorch 0.4.5 pypi_0 pypi
[conda] mkl 2021.4.0 haa95532_640
[conda] mkl-service 2.4.0 py39h2bbff1b_0
[conda] mkl_fft 1.3.1 py39h277e83a_0
[conda] mkl_random 1.2.2 py39hf11a4ad_0
[conda] numpy 1.24.3 pypi_0 pypi
[conda] numpydoc 1.4.0 py39haa95532_0
[conda] pytorch-cuda 11.6 h867d48c_1 pytorch
[conda] pytorch-msssim 1.0.0 pypi_0 pypi
[conda] pytorch-mutex 1.0 cuda pytorch
[conda] siren-pytorch 0.1.7 pypi_0 pypi
[conda] torch 1.13.0 pypi_0 pypi
[conda] torch-cluster 1.6.0+pt113cu116 pypi_0 pypi
[conda] torch-geometric 2.4.0 pypi_0 pypi
[conda] torch-scatter 2.1.0+pt113cu116 pypi_0 pypi
[conda] torch-sparse 0.6.16+pt113cu116 pypi_0 pypi
[conda] torch-spline-conv 1.2.1+pt113cu116 pypi_0 pypi
[conda] torch-tools 0.1.5 pypi_0 pypi
[conda] torchaudio 0.9.1 pypi_0 pypi
[conda] torchbearer 0.5.3 pypi_0 pypi
[conda] torchmeta 1.8.0 pypi_0 pypi
[conda] torchvision 0.17.2+cu118 pypi_0 pypi
[conda] uformer-pytorch 0.0.8 pypi_0 pypi
[conda] vit-pytorch 1.5.0 pypi_0 pypi

@Bhavay-2001

Hi @pmeier, is this a good first issue? Would it be suitable for a beginner?

Bhavay-2001 commented May 9, 2024

Hi @Xact-sniper, I think a possible fix is to pass a torch.device to this function call.

Could you please send a reproducible code snippet?

@pmeier @NicolasHug, do you have any suggestions on this?
