
nn.Conv2d with padding="same" supports a stride of 2, but it currently fails due to an error check #67551

Closed
ProGamerGov opened this issue Oct 29, 2021 · 4 comments
Labels
module: convolution Problems related to convolutions (THNN, THCUNN, CuDNN)
module: nn Related to torch.nn
triaged This issue has been looked at by a team member, and triaged and prioritized into an appropriate module

Comments


ProGamerGov commented Oct 29, 2021

🐛 Bug

The error message from _ConvNd indicates that nn.Conv2d does not currently support a stride of 2 when using same padding. However, upon disabling the error check, it seems to work correctly, which leads me to wonder whether the error checking needs to be adjusted.

This line here appears to be the root cause of the issue: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L93-L94

            if padding == 'same' and any(s != 1 for s in stride):
                raise ValueError("padding='same' is not supported for strided convolutions")

To Reproduce

import torch

# Fails despite being able to work correctly
conv_layer = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7),
    stride=(2, 2), dilation=(1, 1), groups=1, bias=True, padding="same")

Trying to specify stride=(2, 2) with padding="same" for an nn.Conv2d layer results in the following error message:

ValueError                                Traceback (most recent call last)

<ipython-input-2-4308db99eb4d> in <module>()
----> 1 conv_layer = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), dilation=(1, 1), groups=1, bias=True, padding="same")

1 frames

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias, padding_mode, device, dtype)
    430         super(Conv2d, self).__init__(
    431             in_channels, out_channels, kernel_size_, stride_, padding_, dilation_,
--> 432             False, _pair(0), groups, bias, padding_mode, **factory_kwargs)
    433 
    434     def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]):

/usr/local/lib/python3.7/dist-packages/torch/nn/modules/conv.py in __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation, transposed, output_padding, groups, bias, padding_mode, device, dtype)
     92                         padding, valid_padding_strings))
     93             if padding == 'same' and any(s != 1 for s in stride):
---> 94                 raise ValueError("padding='same' is not supported for strided convolutions")
     95 
     96         valid_padding_modes = {'zeros', 'reflect', 'replicate', 'circular'}

ValueError: padding='same' is not supported for strided convolutions

I pulled the actual code for the same padding from here: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L110-L123, and set up a test to show that it works.

import torch
import torch.nn.functional as F


kernel_size=(7, 7)
stride=(2, 2)
dilation=(1, 1)

# Conv2d layer with a stride of 2
conv_layer_s2 = torch.nn.Conv2d(in_channels=3, out_channels=64, kernel_size=kernel_size, stride=stride, dilation=dilation, groups=1, bias=True)

# PyTorch's same padding calculations taken from ConvNd code
_reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
for d, k, i in zip(dilation, kernel_size, range(len(kernel_size) - 1, -1, -1)):
    total_padding = d * (k - 1)
    left_pad = total_padding // 2
    _reversed_padding_repeated_twice[2 * i] = left_pad
    _reversed_padding_repeated_twice[2 * i + 1] = (total_padding - left_pad)
                    
# Create test input
input_test = torch.zeros(1, 3, 224, 224)

# Pad input like in ConvNd code
input_p = F.pad(input_test, _reversed_padding_repeated_twice)

out_conv = conv_layer_s2(input_p)
print(out_conv.shape)

# Output tensor is expected shape
#>> torch.Size([1, 64, 112, 112])
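
As a quick sanity check of the arithmetic (a sketch of my own, using the same parameters as above: kernel 7, stride 2, dilation 1, input 224):

import math

k, s, d, i = 7, 2, 1, 224
total_padding = d * (k - 1)  # 6, i.e. 3 on each side
# Standard convolution output-size formula applied to the padded input
out_size = (i + total_padding - d * (k - 1) - 1) // s + 1
assert out_size == math.ceil(i / s) == 112  # matches the printed shape above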

Expected behavior

Creating an nn.Conv2d layer with padding="same" and stride=(2, 2) should work without issue.

Environment

Collecting environment information...
PyTorch version: 1.9.0+cu111
Is debug build: False
CUDA used to build PyTorch: 11.1
ROCM used to build PyTorch: N/A

OS: Ubuntu 18.04.5 LTS (x86_64)
GCC version: (Ubuntu 7.5.0-3ubuntu1~18.04) 7.5.0
Clang version: 6.0.0-1ubuntu2 (tags/RELEASE_600/final)
CMake version: version 3.12.0
Libc version: glibc-2.26

Python version: 3.7.12 (default, Sep 10 2021, 00:21:48)  [GCC 7.5.0] (64-bit runtime)
Python platform: Linux-5.4.104+-x86_64-with-Ubuntu-18.04-bionic
Is CUDA available: False
CUDA runtime version: 11.1.105
GPU models and configuration: Could not collect
Nvidia driver version: Could not collect
cuDNN version: Probably one of the following:
/usr/lib/x86_64-linux-gnu/libcudnn.so.7.6.5
/usr/lib/x86_64-linux-gnu/libcudnn.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_adv_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_cnn_train.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_infer.so.8.0.5
/usr/lib/x86_64-linux-gnu/libcudnn_ops_train.so.8.0.5
HIP runtime version: N/A
MIOpen runtime version: N/A

Versions of relevant libraries:
[pip3] numpy==1.19.5
[pip3] torch==1.9.0+cu111
[pip3] torchsummary==1.5.1
[pip3] torchtext==0.10.0
[pip3] torchvision==0.10.0+cu111
[conda] Could not collect

The environment was a Colab instance, but the issue should occur with all PyTorch versions that support nn.Conv2d with padding="same".

Additional context

It's not clear why this working configuration is currently blocked by an error check.

cc @albanD @mruberry @jbschlosser @walterddr

@jbschlosser

@ProGamerGov padding='same' is disallowed for stride > 1 because of ambiguity about how it should behave in that case, as discussed in #3867.

This was done to move forward with partial support while avoiding a violation of the expectations of users who want exactly the behavior TensorFlow has when stride > 1.

TensorFlow's behavior is controversial but may need to be provided as an option. I previously suggested implementing both TensorFlow's behavior and the current input-size-independent approach as two separate modes.

Out of curiosity, do you need the exact same behavior as TensorFlow for your use case, or is the current approach, which conceptually accomplishes the same thing, acceptable for you?
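
For concreteness, here is a small illustration of the difference between the two conventions (my own sketch, not from the thread), using the configuration from the report: kernel 7, stride 2, dilation 1, input size 224:

import math

k, s, d, i = 7, 2, 1, 224

# Input-size-independent 'same' padding (the approach PyTorch implements)
pad_pt = d * (k - 1)                                               # 6, split 3 / 3

# TensorFlow-style 'SAME' padding, which depends on the input size
pad_tf = max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)  # 5, split 2 / 3

print(pad_pt, pad_tf)  # both yield a 112x112 output here, but the border alignment differs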

@jbschlosser jbschlosser added the module: convolution, module: nn, and triaged labels Oct 29, 2021

ProGamerGov commented Oct 29, 2021

@jbschlosser Thank you for the quick reply! I'm currently using a custom wrapper over the nn.Conv2d layer to support models that utilize TensorFlow's same padding. The reason I'm using these models is for an upcoming addition to pytorch/captum.

import math

import torch
import torch.nn.functional as F

class Conv2dSame(torch.nn.Conv2d):
    """Conv2d with TensorFlow-style 'SAME' padding (input-size-dependent)."""

    def calc_same_pad(self, i: int, k: int, s: int, d: int) -> int:
        # Total padding along one dimension so that the output size is ceil(i / s)
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        ih, iw = x.size()[-2:]

        pad_h = self.calc_same_pad(i=ih, k=self.kernel_size[0], s=self.stride[0], d=self.dilation[0])
        pad_w = self.calc_same_pad(i=iw, k=self.kernel_size[1], s=self.stride[1], d=self.dilation[1])

        if pad_h > 0 or pad_w > 0:
            x = F.pad(
                x, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2]
            )
        return F.conv2d(
            x,
            self.weight,
            self.bias,
            self.stride,
            self.padding,
            self.dilation,
            self.groups,
        )

conv_layer_s2_same = Conv2dSame(in_channels=3, out_channels=64, kernel_size=(7, 7), stride=(2, 2), groups=1, bias=True)
out = conv_layer_s2_same(torch.zeros(1, 3, 224, 224))

The solution was written by a previous researcher before same padding was added to nn.Conv2d. I think the equation it uses requires the exact same behavior as TensorFlow. This solution also doesn't work with things like torch.fx.
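
For what it's worth, one way to sidestep the torch.fx issue is to precompute the padding for a fixed expected input size, so forward() has no data-dependent control flow. A minimal sketch under that assumption (the class name Conv2dSameExport and the input_hw argument are my own, not an existing API):

import math

import torch
import torch.nn.functional as F

class Conv2dSameExport(torch.nn.Conv2d):
    """TF-style 'SAME' padding computed once for a fixed input size (fx-friendlier)."""

    def __init__(self, *args, input_hw=(224, 224), **kwargs):
        super().__init__(*args, **kwargs)
        pad_h = self._same_pad(input_hw[0], self.kernel_size[0], self.stride[0], self.dilation[0])
        pad_w = self._same_pad(input_hw[1], self.kernel_size[1], self.stride[1], self.dilation[1])
        # F.pad ordering: (left, right, top, bottom)
        self._static_pad = (pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2)

    @staticmethod
    def _same_pad(i: int, k: int, s: int, d: int) -> int:
        return max((math.ceil(i / s) - 1) * s + (k - 1) * d + 1 - i, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # No dynamic shape access here, so torch.fx symbolic tracing should not choke
        x = F.pad(x, self._static_pad)
        return F.conv2d(x, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)

conv = Conv2dSameExport(3, 64, kernel_size=(7, 7), stride=(2, 2), input_hw=(224, 224))
print(conv(torch.zeros(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 112, 112])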

@jbschlosser

Great - thanks for the additional info on your use case! Sounds like it's important that we provide an option for TF-exact same padding.

FYI I'm going to close this issue to consolidate discussion on the original issue #3867.

YHJYH added a commit to YHJYH/Machine_Learning_Dictionary that referenced this issue Feb 25, 2023
bug fix: ValueError: padding='same' is not supported for strided convolutions.
See: pytorch/pytorch#67551

huangmozhilv commented Dec 4, 2023

(quoting @ProGamerGov's earlier comment with the Conv2dSame wrapper)

However, be cautious with this solution. It incurs extra GPU memory for the padded copy of x, which can cause a CUDA out-of-memory error if x is very large. It is therefore best used only when stride > 1; if you still run out of memory, replace Conv2dSame(stride=2) with max pooling.
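
Following that advice, a small helper along these lines could keep the wrapper only where it is actually needed (a sketch assuming the Conv2dSame class from the earlier comment is in scope; the helper name conv2d_same is my own):

import torch

def conv2d_same(in_channels, out_channels, kernel_size, stride=1, **kwargs):
    # Built-in padding="same" is accepted when every stride is 1, so only
    # fall back to the padding wrapper for strided convolutions.
    strides = stride if isinstance(stride, tuple) else (stride, stride)
    if any(s != 1 for s in strides):
        return Conv2dSame(in_channels, out_channels, kernel_size, stride=stride, **kwargs)
    return torch.nn.Conv2d(in_channels, out_channels, kernel_size,
                           stride=stride, padding="same", **kwargs)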
