UserWarning: Plan failed with a cudnnException #121834

Closed
bhack opened this issue Mar 13, 2024 · 51 comments
Labels
high priority
module: cudnn (Related to torch.backends.cudnn, and CuDNN support)
needs reproduction (Someone else needs to try reproducing the issue given the instructions. No action needed from user)
oncall: pt2
triage review
triaged (This issue has been looked at by a team member, and triaged and prioritized into an appropriate module)
Milestone
2.3.1

Comments

@bhack
Contributor

bhack commented Mar 13, 2024

🐛 Describe the bug

Compiling this forward pass https://github.com/yoxu515/aot-benchmark/blob/paot/networks/engines/aotv3_engine.py#L35-L110

I get this warning.

Error logs

UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1710229288018/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,

And after a few inputs I got:
#121504 (comment)
/cc @ezyang @gchanan @zou3519 @kadeng @csarofeen @ptrblck @xwang233 @msaroufim @bdhirsh @anijain2305 @chauhang @williamwen42

Attached is the generated inductor output: c5nhr6q2xpuk52rh5thx56utuj6tjvxobjgjbd3rsdjvwggjys3d.py.txt

Minified repro

No response

Versions

Latest official pytorch-nightly image.

@bhack
Contributor Author

bhack commented Mar 13, 2024

Lowering the input resolution a bit in another run, I no longer see #121504 (comment) (as documented in that ticket), but I do see these extra messages in the log:

/tmp/torchinductor_root/py/cpylrdfke46tta45o5xnxi77ex3ja2o5vdxsbjbcnp66kgd7vwqd.py:615: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1710229288018/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  buf3 = extern_kernels.convolution(buf0, buf1, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
/tmp/torchinductor_root/py/cpylrdfke46tta45o5xnxi77ex3ja2o5vdxsbjbcnp66kgd7vwqd.py:644: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1710229288018/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  buf13 = extern_kernels.convolution(buf10, buf11, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)
/tmp/torchinductor_root/rd/crdjdt7nq5zpiv2qjswdnkkjyqhawkfoeb5jm6t3lthc3dr3vmbq.py:615: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1710229288018/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  buf3 = extern_kernels.convolution(buf0, buf1, stride=(1, 1), padding=(1, 1), dilation=(1, 1), transposed=False, output_padding=(0, 0), groups=1, bias=None)

@jansel jansel added the "module: cudnn" label Mar 23, 2024
@jansel
Contributor

jansel commented Mar 23, 2024

@ezyang do you know who the right person to look at this is? The warning is coming from the eager convolution kernel.

@jansel jansel added the "triaged" label Mar 23, 2024
@ezyang
Contributor

ezyang commented Mar 24, 2024

It sounds like something is wrong with our size/stride meta. cc @eellison

@jansel jansel added the "triage review" label and removed the "triaged" label Mar 24, 2024
@eellison
Contributor

I think you are linking the wrong inductor code. There is no conv in that output. Can you please include a full repro?

@eellison eellison added the "needs reproduction" label Mar 25, 2024
@bhack
Contributor Author

bhack commented Mar 25, 2024

I get many inductor .py code files when compiling that forward pass in a clean env.
Which one do I need to share?

@bhack
Contributor Author

bhack commented Mar 25, 2024

If you want to reproduce it from the source code instead, @williamwen42 already has some instructions at
#121504 (comment)
since we are already reproducing many compiler issues on the same model/repo.

@eellison
Contributor

If you run with TORCH_COMPILE_DEBUG=1 and dump the full output that will be sufficient to repro here.

I get many inductor .py code files when compiling that forward pass in a clean env. Which one do I need to share?
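For anyone trying to produce such a dump, here is a minimal sketch of one way to do it; the module, shapes, and output location are placeholder assumptions, and it relies on TORCH_COMPILE_DEBUG being set before torch is imported:

import os
os.environ["TORCH_COMPILE_DEBUG"] = "1"  # assumption: set before importing torch so inductor picks it up

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, padding=1).cuda()   # placeholder model
x = torch.randn(1, 3, 64, 64, device="cuda")              # placeholder input

compiled = torch.compile(conv)
compiled(x)  # debug artifacts (including the generated inductor .py files) typically land under ./torch_compile_debug/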

@bhack
Contributor Author

bhack commented Mar 26, 2024

If you run with TORCH_COMPILE_DEBUG=1 and dump the full output that will be sufficient to repro here.

You can find it here:
compile_dump.txt

@eellison
Contributor

eellison commented Mar 26, 2024

Looking at the dump here - the warning happens prior to dynamo tracing so I'm not sure that it is a fake tensor (or pt2) issue.

/opt/conda/lib/python3.10/site-packages/torch/functional.py:512: UserWarning: torch.meshgrid: in an upcoming release, it will be required to pass the indexing argument. (Triggered internally at /opt/conda/conda-bld/pytorch_1710315938922/work/aten/src/ATen/native/TensorShape.cpp:3587.)
  return _VF.meshgrid(tensors, **kwargs)  # type: ignore[attr-defined]
/opt/conda/lib/python3.10/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1710315938922/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,
I0326 01:09:10.597000 132385496905536 torch/_dynamo/logging.py:55] [0/0] Step 1: torchdynamo start tracing forward /workspace/./networks/layers/attention.py:310

@bhack
Contributor Author

bhack commented Mar 26, 2024

Looking at the dump here - the warning happens prior to dynamo tracing so I'm not sure that it is a fake tensor (or pt2) issue.

But I can confirm that without decorating that forward with @torch.compile we don't get that warning.
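To illustrate the comparison being described, here is a hypothetical sketch (not the actual AOT model; the layer sizes, dtype, and shapes are made up, and whether the warning fires depends on the cuDNN build and GPU):

import torch
import torch.nn as nn

conv = nn.Conv2d(256, 256, kernel_size=3, padding=1).cuda().half()  # stand-in for a conv inside the forward
x = torch.randn(1, 256, 60, 60, device="cuda", dtype=torch.float16)

with torch.no_grad():
    conv(x)                 # plain eager call: no warning was observed in this scenario
    torch.compile(conv)(x)  # compiled call: the cudnnException UserWarning appeared in this scenario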

@bhack
Contributor Author

bhack commented Mar 26, 2024

P.s. I meant without decorating.

@jansel jansel added the "triaged" label and removed the "triage review" label Mar 26, 2024
@nagadomi

I got this warning after updating to PyTorch 2.3.0 today.
This warning only appears when nn.Conv2d(8, 1, 3) is called with fp16.

Reproduction code:

import torch
import torch.nn as nn

conv_no_warn = nn.Conv2d(8, 3, kernel_size=3, stride=1, padding=0).eval().cuda()
conv_warn = nn.Conv2d(8, 1, kernel_size=3, stride=1, padding=0).eval().cuda()
x = torch.rand((1, 8, 546, 392)).cuda()

with torch.inference_mode(), torch.autocast(device_type="cuda"):
    # No warning
    conv_no_warn(x)
    # UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED
    conv_warn(x)

@agunapal
Contributor

Seeing this consistently with max-autotune mode of torch.compile in PyTorch 2.3.0

@atalman atalman added this to the 2.3.1 milestone Apr 24, 2024
@atalman
Contributor

atalman commented Apr 24, 2024

This can be seen here:

Testing smoke_test_conv2d with cuda for torch.float16
/opt/conda/envs/conda-env-8819486240_pypi/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)

https://github.com/pytorch/builder/actions/runs/8819486240/job/24210862873#step:11:4098

This was happening in nightly on March 13:
https://github.com/pytorch/builder/actions/runs/8267337962/job/22618718246#step:11:4083

However, it was fixed on March 14:
https://github.com/pytorch/builder/actions/runs/8283397156/job/22666518941#step:11:4097
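For readers without access to the builder logs, the smoke test referenced above boils down to running a small fp16 convolution on CUDA; a rough stand-in (this is not the actual smoke_test_conv2d code, and the layer sizes are placeholders) would be:

import torch
import torch.nn as nn

conv = nn.Conv2d(3, 8, kernel_size=3, stride=1, padding=1).cuda().half()
x = torch.randn(1, 3, 224, 224, device="cuda", dtype=torch.float16)
with torch.no_grad():
    out = conv(x)  # on the affected builds this call emits the cudnnException UserWarning
print(out.shape)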

@nagadomi

nagadomi commented Apr 24, 2024

I had been using torch 2.3.0+cu118, so I tried the cu121 version, and the above reproduction code did not produce any warning.
Then I reverted to the cu118 version, and the warning no longer appears with the above reproduction code either.
However, my more complex application still produced warnings, so I ran it with the python -W error .. command and found that the module producing the warning had changed from before.


Edit:
Here is the code that can reproduce that warning with cu118.

import torch
import torch.nn as nn

conv_warn = nn.Conv2d(768, 96, kernel_size=(1, 1), stride=(1, 1)).eval().cuda()
x = torch.rand((1, 768, 39, 28)).cuda()

with torch.inference_mode(), torch.autocast(device_type="cuda"):
    # UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED
    conv_warn(x)

In my current environment, this code consistently produces warnings, but the previous code stopped producing warnings after I ran it with cu121. Maybe something is cached?

@mvpatel2000
Contributor

mvpatel2000 commented Apr 26, 2024

Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)

Also encountering this in eager mode in our unit tests for upgrading to torch 2.3:
mosaicml/composer#3213

@johnnynunez

Same problem with pytorch 2.3.0 and ultralytics yolov8.

@golemme

golemme commented Apr 26, 2024

With pytorch 2.3.0 and ultralytics yolov8 I'm getting UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ..\aten\src\ATen\native\cudnn\Conv_v8.cpp:919.)
return F.conv2d(input, weight, bias, self.stride,

@L1pp

L1pp commented May 8, 2024

I've encountered the same issue while using the latest official Docker image "pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime". Here is the warning message:

/opt/conda/lib/python3.10/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608935911/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
return Variable._execution_engine.run_backward( # Calls into the C++ engine to run the backward pass

System Info:
Docker Image: pytorch/pytorch:2.3.0-cuda12.1-cudnn8-runtime
GPU: NVIDIA GeForce RTX 3090
Operating System: Ubuntu 22.04.4 LTS

The specific training model is the official resnet18. It is worth noting that this warning did not affect the progress of training.
So, based on the above, does this warning affect the actual results at all?
I am just providing feedback that this warning still needs to be resolved in the official pytorch 2.3.0 release.

@eqy
Collaborator

eqy commented May 8, 2024

We are looking to get this resolved in 2.3.1, and yes, the warning alone should not affect the results of training. It is basically saying that the first selected cuDNN algorithm could not run the workload; in that case the next selected cuDNN algorithm will be tried.
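For anyone who wants to quiet the message until 2.3.1 lands, here are two possible workarounds as a sketch, not an official recommendation; filtering only hides the warning and does not change which cuDNN plan ends up running:

import warnings
import torch

# Option 1: suppress just this warning by matching the start of its message.
warnings.filterwarnings("ignore", message="Plan failed with a cudnnException")

# Option 2 (heavier hammer, likely slower): skip cuDNN for convolutions entirely.
# torch.backends.cudnn.enabled = False

Filtering is the less invasive choice, since the fallback algorithm described above still runs correctly.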

@pbagwell-phioptics

@weidehai

I encountered the same problem
torch 2.3.0 pypi_0 pypi
torchvision 0.18.0 pypi_0 pypi

nvcc: NVIDIA (R) Cuda compiler driver
Copyright (c) 2005-2022 NVIDIA Corporation
Built on Wed_Sep_21_10:33:58_PDT_2022
Cuda compilation tools, release 11.8, V11.8.89
Build cuda_11.8.r11.8/compiler.31833905_0

NVIDIA-SMI 550.67 Driver Version: 550.67 CUDA Version: 12.4

Ubuntu 22.04.1 LTS

UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)

I don't know whether it will affect the final results of the program.

@atalman
Contributor

atalman commented May 13, 2024

Closing this since the cherry-pick PR has been posted: #125790
Pending validation for 2.3.1 RC1.

@atalman atalman closed this as completed May 13, 2024
@williamwen42 williamwen42 removed their assignment May 13, 2024
@nmahammad

I had the same issue:
torch/nn/modules/conv.py:952: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv_transpose2d(

@stanleylcao

I have also just gotten this error. I don't believe I saw it yesterday; maybe it resurfaced?

@nmahammad

nmahammad commented May 17, 2024

In my case the torch version is not even 2.3; it is 2.2.1, and the issue started appearing yesterday. Could it be related to an internal problem in my GPU?

error: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv_transpose2d(

Why is this issue closed, since the error is still not solved?

@nmahammad

Update: now the error is a bit different:
venv/lib/python3.9/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass

@rafa-br34

rafa-br34 commented May 19, 2024

I get this, but it doesn't seem to interfere with training.

/home/rafa/Python3_11_6_VEnv/lib/python3.11/site-packages/torch/nn/modules/conv.py:456: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv2d(input, weight, bias, self.stride,
/home/rafa/Python3_11_6_VEnv/lib/python3.11/site-packages/torch/autograd/graph.py:744: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
/home/rafa/Python3_11_6_VEnv/lib/python3.11/site-packages/torch/nn/modules/conv.py:952: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv_transpose2d(

Driver Version: 550.54.15, CUDA Version: 12.4

@HoseinHashemi

I get the same UserWarning with PyTorch 2.3.0 and CUDA 11.8.

@GKaviani

GKaviani commented May 24, 2024

I have a similar problem (to some extent).

I got the warning while using resnet3d, which takes sequences of video frames with batch size 16 and sequence length 16, on the following versions of torch and cudnn:

pytorch                   2.3.0           py3.10_cuda12.1_cudnn8.9.2_0    pytorch
pytorch-cuda              12.1                 ha16c6d3_5    pytorch
pytorch-mutex             1.0                        cuda    pytorch
torchaudio                2.3.0               py310_cu121    pytorch
torchmetrics              1.4.0.post0        pyhd8ed1ab_0    conda-forge
torchtriton               2.3.0                     py310    pytorch
torchvision               0.18.0              py310_cu121    pytorch
cuda-cudart               12.1.105                      0    nvidia
cuda-cupti                12.1.105                      0    nvidia
cuda-libraries            12.1.0                        0    nvidia
cuda-nvrtc                12.1.105                      0    nvidia
cuda-nvtx                 12.1.105                      0    nvidia
cuda-opencl               12.4.127                      0    nvidia
cuda-runtime              12.1.0                        0    nvidia

user/anaconda3/envs/DARai/lib/python3.10/site-packages/torch/nn/modules/conv.py:605: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at /opt/conda/conda-bld/pytorch_1712608935911/work/aten/src/ATen/native/cudnn/Conv_v8.cpp:919.)
  return F.conv3d(

@nmahammad

The issue was solved when I downgraded torch to 2.2.2

@ClancyAyres

The issue was solved when I downgraded torch to 2.2.2
+1

@atalman
Contributor

atalman commented May 31, 2024

@Yanall-Boutros

I can confirm this is an issue with 2.3.0.

I have a nix flake.lock pinning torch and torchaudio to 2.3.0+cu121, and a separate one pinning them to 2.2.2+cu121:
torch230flake.lock.txt
torch_222flake.lock.txt

When I run tortoise/do_tts.py --voice daniel --text test from this repo: https://github.com/Yanall-Boutros/tortoise-tts-poetry2nix/

The error python3.11-torch-2.3.0/lib/python3.11/site-packages/torch/nn/modules/conv.py:303: UserWarning: Plan failed with a cudnnException: CUDNN_BACKEND_EXECUTION_PLAN_DESCRIPTOR: cudnnFinalize Descriptor Failed cudnn_status: CUDNN_STATUS_NOT_SUPPORTED (Triggered internally at ../aten/src/ATen/native/cudnn/Conv_v8.cpp:919.) eventually appears when transforming the auto-regressive outputs into audio.

@HoseinHashemi

Downgrading to PyTorch 2.2.2 solved the issue.

@leng-yue
Contributor

leng-yue commented Jun 5, 2024

Does this cause any performance degradation?

@eqy
Collaborator

eqy commented Jun 5, 2024

It should not cause performance degradation, as after the first iteration the failing config will be skipped.
Please note it is fixed in 2.3.1.

@HoseinHashemi

Does this cause any performance degradation?

I have not noticed any performance degradation yet. You only get the warning at the first iteration anyway.
