
[Feature Request] Implement "same" padding for convolution operations? #3867

Open · qbx2 opened this issue Nov 25, 2017 · 85 comments

Labels: enhancement (not as big of a feature, but technically not a bug; should be easy to fix) · high priority · module: convolution (problems related to convolutions: THNN, THCUNN, CuDNN) · module: nn (related to torch.nn) · needs design · triaged (this issue has been looked at by a team member, and triaged and prioritized into an appropriate module)

@qbx2 (Contributor) commented Nov 25, 2017

The implementation would be easy, but it could help the many people suffering the headache of calculating how much padding they need.

cc @ezyang @gchanan @zou3519 @bdhirsh @jbschlosser @albanD @mruberry @walterddr

@qbx2 qbx2 changed the title Implement same padding for convolution operations? [Feature Request] Implement same padding for convolution operations? Nov 26, 2017
@qbx2 qbx2 changed the title [Feature Request] Implement same padding for convolution operations? [Feature Request] Implement "same" padding for convolution operations? Nov 26, 2017
@soumith (Member) commented Dec 1, 2017

This seems worth doing. What is the interface you are proposing? Something like nn.Conv2d(..., padding="same")?

@soumith soumith added this to neural-nets in Issue Categories Dec 1, 2017
@fmassa (Member) commented Dec 2, 2017

Note that if you are looking for the same behavior as TensorFlow, the implementation will not be that straightforward, because the number of pixels to add depends on the input size. See https://github.com/caffe2/caffe2/blob/master/caffe2/proto/caffe2_legacy.proto for reference.
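
To illustrate the input-size dependence, here is a minimal sketch of the TF-style calculation (a hypothetical helper following the formula in the linked sources, not an existing torch API):

def tf_same_total_padding(input_size, kernel_size, stride=1, dilation=1):
    effective_kernel = (kernel_size - 1) * dilation + 1
    out_size = (input_size + stride - 1) // stride  # ceil(input_size / stride)
    return max(0, (out_size - 1) * stride + effective_kernel - input_size)

# with kernel 3 and stride 2, different input sizes need different amounts of padding
print(tf_same_total_padding(10, 3, stride=2))  # 1
print(tf_same_total_padding(11, 3, stride=2))  # 2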

@qbx2 (Contributor, Author) commented Dec 4, 2017

Thank you for pointing out the issue and the reference.
To resolve the issue stated by @fmassa, I propose two interfaces.
First, as @soumith mentioned, the first interface would be like nn.Conv*d(..., padding="same"), calculating the padding on every forward() call.
However, this would be inefficient when the input shape is known in the initialization phase. Therefore, I suggest an interface like nn.CalcPadConv*d(<almost the same parameters as Conv*d>). Using it, a user can calculate the padding from the known width and height at initialization, and pass the output (the shape of the padding) to the padding parameter of nn.Conv2d(...).
I'm not sure whether the second proposal is premature optimization.
What do you think about these? Is there a better name?

@fmassa (Member) commented Dec 4, 2017

I think the biggest source of inefficiency will come from the fact that we will need to add an F.pad layer before every convolution that requires the padding=same case (because the amount of padding might not be the same on the left and right sides); see for example how TensorFlow has to handle this in the cudnn case. So nn.CalcPadConv*d would normally be as expensive as nn.Conv*d(..., padding="same").

This could be made more efficient if we supported different paddings for each side of the convolution (like in Caffe2, so left, right, top, bottom), but cudnn still doesn't support that so we would require the extra padding in those cases.
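
For example, a minimal sketch of that workaround, assuming a 2x2 kernel whose total "same" padding of 1 cannot be split evenly:

import torch
import torch.nn.functional as F

x = torch.randn(1, 1, 5, 5)
w = torch.randn(1, 1, 2, 2)
x = F.pad(x, [0, 1, 0, 1])   # asymmetric: pad right and bottom only (left, right, top, bottom)
print(F.conv2d(x, w).shape)  # torch.Size([1, 1, 5, 5]) -- same spatial size as the input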

Also, I think if we add the padding="same" to nn.Conv*d, we should probably do the same for nn.*Pool*d, right?

I think what bothers me a bit is that users might expect the behavior of padding=same to be equivalent to TF, but they might not be expecting a performance drop.

What do you think?

@apaszke (Contributor) commented Dec 4, 2017

Why would that be inefficient? Couldn't we just compute the padding at every forward step? The cost should be tiny, so there's no need to optimize that. Maybe I don't fully understand the semantics, but I can't see why F.pad would be needed.

@soumith (Member) commented Dec 4, 2017

Making padding dependent on input size is quite bad. We just had an internal discussion about this, with @Yangqing outlining why this is a bad idea for a variety of serialization and efficiency reasons.

@qbx2 (Contributor, Author) commented Dec 4, 2017

@fmassa, what I intended was to calculate a "constant" padding shape in __init__() using nn.CalcPadConv*d(). As you said, this alone won't work when the calculated padding is odd. Therefore, an F.pad layer would need to be added, or support in F.conv*d for odd (asymmetric) paddings would help.

EDIT: Then what I suggest should be a function, placed in, say, torch.nn.utils or torch.utils.

@qbx2 (Contributor, Author) commented Dec 5, 2017

In summary, what I suggest is a simple utility function, like (pseudocode):

def calc_pad_conv1d(width, padding='same', check_symmetric=True, ... <params that conv1d has>):
    shape = <calculate padding>

    assert not check_symmetric or <shape is symmetric>, \
        'Calculated padding shape is asymmetric, which is not supported by conv1d. ' \ 
        'If you just want to get the value, consider using check_symmetric=False.'

    return shape


width = 100  # for example
padding = calc_pad_conv1d(width, ...)
m = nn.Conv1d(..., padding=padding)

Also, the function could be combined with F.pad at the user's discretion.
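
A concrete version of that sketch might look like the following (the name and the ceil(width / stride) output-size convention are assumptions for illustration, not an existing torch API):

import torch.nn as nn

def calc_pad_conv1d(width, kernel_size, stride=1, dilation=1, check_symmetric=True):
    out_width = (width + stride - 1) // stride  # 'same' target: ceil(width / stride)
    total = max(0, (out_width - 1) * stride + dilation * (kernel_size - 1) + 1 - width)
    if check_symmetric and total % 2 != 0:
        raise ValueError(
            'Calculated padding is asymmetric, which is not supported by conv1d. '
            'If you just want the value (e.g. for F.pad), use check_symmetric=False.')
    return total // 2

width = 100  # for example
padding = calc_pad_conv1d(width, kernel_size=3)
m = nn.Conv1d(16, 32, 3, padding=padding)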

@fmassa (Member) commented Dec 5, 2017

@qbx2 maybe I don't fully understand your proposal, but if we want to replicate TensorFlow's behavior I don't think this is enough.

Here is a snippet of what I think mimics TensorFlow's SAME padding (I'm writing it against the functional interface, so that nn.Conv2d can just call into F.conv2d_same_padding):

def conv2d_same_padding(input, weight, bias=None, stride=1, dilation=1, groups=1):
  input_rows = input.size(2)
  filter_rows = weight.size(2)
  effective_filter_size_rows = (filter_rows - 1) * dilation[0] + 1
  out_rows = (input_rows + stride[0] - 1) // stride[0]
  padding_rows = max(0, (out_rows - 1) * stride[0] +
                     effective_filter_size_rows - input_rows)
  rows_odd = (padding_rows % 2 != 0)
  # same computation for padding_cols / cols_odd, using input.size(3),
  # weight.size(3), stride[1] and dilation[1]

  if rows_odd or cols_odd:
    input = F.pad(input, [0, int(cols_odd), 0, int(rows_odd)])

  return F.conv2d(input, weight, bias, stride,
                  padding=(padding_rows // 2, padding_cols // 2),
                  dilation=dilation, groups=groups)

It was mostly copy-pasted from TensorFlow code in here and here.

As you can see, there is a lot of hidden machinery going on there, and that's why I think it might not be worth adding a padding='same'. Then again, not replicating TensorFlow's SAME behavior isn't ideal either.

Thoughts?

@qbx2 (Contributor, Author) commented Dec 5, 2017

@fmassa Yes, you're right. It may be inefficient to calculate the padding on every forward().

However, my proposal is NOT to calculate the padding on every forward() call. A researcher (developer) may know the sizes of the images fed to nn.Conv2d before runtime. And if he/she wants 'same' padding, he/she can use the function to calculate the padding required to mimic 'SAME'.

For example, consider the case where a researcher has images of 200x200, 300x300, and 400x400. Then he/she can calculate the paddings for the three cases in the initialization phase and just pass the images to F.pad() with the corresponding padding. Or he/she can just change the padding field of nn.Conv2d before the forward() call. Refer to this:

>>> import torch
>>> import torch.nn as nn
>>> from torch.autograd import Variable
>>> m = nn.Conv2d(1,1,1)
>>> m(Variable(torch.randn(1,1,2,2))).shape
torch.Size([1, 1, 2, 2])
>>> m.padding = (1, 1)
>>> m(Variable(torch.randn(1,1,2,2))).shape
torch.Size([1, 1, 4, 4])

Yes, I just want to add the "padding-calculating utility function" to pytorch core.

When the researcher wants padding that depends on each input image's size, he/she can combine the function with F.pad() before passing the image to nn.Conv2d. I want to let the code writer decide whether or not to pad the inputs on every forward() call.

@imgyuri commented Jan 29, 2018

Is there any plan to implement a similar API in PyTorch in the near future? People coming from a TensorFlow / Keras background would certainly appreciate it.

@fmassa (Member) commented Jan 30, 2018

So, a basic padding calculation strategy (which does not give the same results as TensorFlow, but produces similar shapes) is to have

def _get_padding(padding_type, kernel_size):
    assert padding_type in ['SAME', 'VALID']
    if padding_type == 'SAME':
        return tuple((k - 1) // 2 for k in kernel_size)
    return tuple(0 for _ in kernel_size)

Is that what you have in mind, @im9uri?

@imgyuri commented Jan 31, 2018

It's similar to what I had in mind, but as you mentioned previously the calculation gets complicated with stride and dilation.

Also, having such an API for other convolution operations such as ConvTranspose2d would be great.

@janluke commented May 9, 2018

I think that "sliding-window operators" should all support asymmetric padding.

About the "same" argument...
@soumith Can you explain why making padding depending on the input size is bad, please?
If that's a problem, anyway, a pragmatic solution could be to require stride == 1 when using "same". For stride == 1, the padding doesn't depend on the input size and can be computed a single time. The constructor should raise a ValueError if the user attempts to use padding='same' with stride > 1.

I know, it's not the cleanest solution but the constraint sounds reasonable enough to me given that:

  1. the original semantic of the label "same" was introduced for not strided convolutions and was: the output has the same size of the input; of course, this is not true in tensorflow for stride > 1 and that makes the use of the word "same" a bit misleading IMO;
  2. it would cover 99% of the cases one wants to use "same"; I can barely imagine a case when someone really needs the behavior of tensorflow for stride > 1, while if we give to "same" its original semantic, well, of course it doesn't make any sense to use a strided convolution if you want the output has the same size of the input.

@teucer commented Jun 25, 2018

The conv2d documentation gives explicit formulas for the output sizes. Equating e.g. Hout with Hin, one can solve for the padding:

def _get_padding(size, kernel_size, stride, dilation):
    padding = ((size - 1) * (stride - 1) + dilation * (kernel_size - 1)) //2
    return padding
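
A quick numeric check of this formula (a sketch; the sizes are illustrative, and for stride 1 the output size matches the input exactly):

import torch
import torch.nn as nn

size, k, s, d = 32, 3, 1, 2
p = _get_padding(size, k, s, d)  # ((32 - 1)*(1 - 1) + 2*(3 - 1)) // 2 == 2
m = nn.Conv2d(1, 1, k, stride=s, dilation=d, padding=p)
print(m(torch.randn(1, 1, size, size)).shape)  # torch.Size([1, 1, 32, 32])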

@sidr97 commented Jun 26, 2018

Since same padding means padding = (kernel_size - stride) // 2, what if padding = "same" were introduced such that, when specified, it automatically reads the kernel size and stride (both already given to nn.Conv2d) and applies the padding accordingly?

@kylemcdonald commented Jul 25, 2018

Here is a very simple Conv2d layer with same padding for reference. It only supports square kernels and stride=1, dilation=1, groups=1.

class Conv2dSame(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size, bias=True, padding_layer=torch.nn.ReflectionPad2d):
        super().__init__()
        ka = kernel_size // 2
        kb = ka - 1 if kernel_size % 2 == 0 else ka
        self.net = torch.nn.Sequential(
            padding_layer((ka,kb,ka,kb)),
            torch.nn.Conv2d(in_channels, out_channels, kernel_size, bias=bias)
        )
    def forward(self, x):
        return self.net(x)
    
c = Conv2dSame(1,3,5)
print(c(torch.rand((16,1,10,10))).shape)

# torch.Size([16, 3, 10, 10])

@traviskaufman commented Dec 14, 2018

If this is still being evaluated for being added to PyTorch, then regarding the tradeoffs between complexity / inefficiency vs. ease-of-use for developers:

In the road to 1.0 blog post, it states:

PyTorch’s central goal is to provide a great platform for research and hackability. So, while we add all these [production-use] optimizations, we’ve been working with a hard design constraint to never trade these off against usability.

Anecdotally, I come from a background of using Keras as well as the original tf.layers / estimator APIs. All have support for same padding. I'm currently reimplementing a convnet I had originally written in TF with PyTorch, and the fact that I've had to build in the arithmetic for zero-padding myself has cost me about a half-day of time.

If the "central goal" really is focused on usability, than I'd argue that even if there's an efficiency hit to computing zero-padding on every forward pass (as mentioned above), the time saved in terms of developer efficiency and maintainability (e.g. not having to write custom code to compute zero padding) may be worth the tradeoff. Thoughts?

@bionicles commented

I would use this feature

@tremblerz commented

It doesn't make sense to me why an optional padding=SAME API can't be offered. If someone is willing to incur the additional cost of padding, then let them do so. For many researchers, quick prototyping is a requirement.

@BoPengGit commented

Yes, if someone can please add and approve this, it would be great.

@leijurv commented Jan 24, 2019

Definitely add this, conner wants it.

@lucasjinreal commented

Does pytorch support it now? Can it use the same approach as the first group in VGG, setting padding = (kernel_size-1)/2?
The VGG network keeps the output size unchanged in its first group; then you can use stride to resize the feature map. Does that sound OK?
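
For reference, a minimal check of that VGG-style rule, padding = (kernel_size - 1) // 2, for stride 1 and an odd kernel (standard torch API; the sizes are just illustrative):

import torch
import torch.nn as nn

m = nn.Conv2d(3, 64, kernel_size=3, padding=(3 - 1) // 2)
print(m(torch.randn(1, 3, 224, 224)).shape)  # torch.Size([1, 64, 224, 224])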

@lucasjinreal commented Jan 29, 2019

Here is one example of calling same-padding conv2d, from deepfakes:

# modify conv2d function to use same padding
# code referred to @fmassa in 'https://github.com/pytorch/pytorch/issues/3867'
# and the tensorflow source code

import torch.utils.data
from torch.nn import functional as F

import math
import torch
from torch.nn.parameter import Parameter
from torch.nn.functional import pad
from torch.nn.modules import Module
from torch.nn.modules.utils import _single, _pair, _triple


class _ConvNd(Module):

    def __init__(self, in_channels, out_channels, kernel_size, stride,
                 padding, dilation, transposed, output_padding, groups, bias):
        super(_ConvNd, self).__init__()
        if in_channels % groups != 0:
            raise ValueError('in_channels must be divisible by groups')
        if out_channels % groups != 0:
            raise ValueError('out_channels must be divisible by groups')
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.transposed = transposed
        self.output_padding = output_padding
        self.groups = groups
        if transposed:
            self.weight = Parameter(torch.Tensor(
                in_channels, out_channels // groups, *kernel_size))
        else:
            self.weight = Parameter(torch.Tensor(
                out_channels, in_channels // groups, *kernel_size))
        if bias:
            self.bias = Parameter(torch.Tensor(out_channels))
        else:
            self.register_parameter('bias', None)
        self.reset_parameters()

    def reset_parameters(self):
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        if self.bias is not None:
            self.bias.data.uniform_(-stdv, stdv)

    def __repr__(self):
        s = ('{name}({in_channels}, {out_channels}, kernel_size={kernel_size}'
             ', stride={stride}')
        if self.padding != (0,) * len(self.padding):
            s += ', padding={padding}'
        if self.dilation != (1,) * len(self.dilation):
            s += ', dilation={dilation}'
        if self.output_padding != (0,) * len(self.output_padding):
            s += ', output_padding={output_padding}'
        if self.groups != 1:
            s += ', groups={groups}'
        if self.bias is None:
            s += ', bias=False'
        s += ')'
        return s.format(name=self.__class__.__name__, **self.__dict__)


class Conv2d(_ConvNd):

    def __init__(self, in_channels, out_channels, kernel_size, stride=1,
                 padding=0, dilation=1, groups=1, bias=True):
        kernel_size = _pair(kernel_size)
        stride = _pair(stride)
        padding = _pair(padding)
        dilation = _pair(dilation)
        super(Conv2d, self).__init__(
            in_channels, out_channels, kernel_size, stride, padding, dilation,
            False, _pair(0), groups, bias)

    def forward(self, input):
        return conv2d_same_padding(input, self.weight, self.bias, self.stride,
                        self.padding, self.dilation, self.groups)


# custom conv2d, because pytorch doesn't have a "padding='same'" option.
def conv2d_same_padding(input, weight, bias=None, stride=1, padding=1, dilation=1, groups=1):
    # the `padding` argument is ignored; "same" padding is computed from the input size
    input_rows = input.size(2)
    filter_rows = weight.size(2)
    out_rows = (input_rows + stride[0] - 1) // stride[0]
    padding_rows = max(0, (out_rows - 1) * stride[0] +
                       (filter_rows - 1) * dilation[0] + 1 - input_rows)
    rows_odd = (padding_rows % 2 != 0)

    input_cols = input.size(3)
    filter_cols = weight.size(3)
    out_cols = (input_cols + stride[1] - 1) // stride[1]
    padding_cols = max(0, (out_cols - 1) * stride[1] +
                       (filter_cols - 1) * dilation[1] + 1 - input_cols)
    cols_odd = (padding_cols % 2 != 0)

    if rows_odd or cols_odd:
        input = pad(input, [0, int(cols_odd), 0, int(rows_odd)])

    return F.conv2d(input, weight, bias, stride,
                    padding=(padding_rows // 2, padding_cols // 2),
                    dilation=dilation, groups=groups)

facebook-github-bot pushed a commit that referenced this issue Mar 18, 2021
Summary:
Pull Request resolved: #45667

First part of #3867 (Pooling operators still to do)

This adds a `padding='same'` mode to the interface of `conv{n}d` and `nn.Conv{n}d`. This should match the behaviour of `tensorflow`. I couldn't find it explicitly documented, but through experimentation I found that `tensorflow` returns the shape `ceil(len/stride)` and always adds any extra asymmetric padding onto the right side of the input.

Since the `native_functions.yaml` schema doesn't seem to support strings or enums, I've moved the function interface into python and it now dispatches between the numerically padded `conv{n}d` and the `_conv{n}d_same` variant. Underscores because I couldn't see any way to avoid exporting a function into the `torch` namespace.

A note on asymmetric padding. The total padding required can be odd if the kernel length is even and the dilation is odd. mkldnn has native support for asymmetric padding, so there is no overhead there, but for other backends I resort to padding the input tensor by 1 on the right-hand side to make the remaining padding symmetrical. In these cases, I use `TORCH_WARN_ONCE` to notify the user of the performance implications.

Test Plan: Imported from OSS

Reviewed By: ejguan

Differential Revision: D27170744

Pulled By: jbschlosser

fbshipit-source-id: b3d8a0380e0787ae781f2e5d8ee365a7bfd49f22
@MeteorsHub commented Apr 19, 2021

It's frustrating that every time I use a conv I have to calculate the padding size myself; it wastes almost half of my code-writing time. You can write a whole paragraph saying SAME is appropriate only some of the time, but common sense says the user should be given the option.

@peterbell10 (Collaborator) commented

Same padding for un-strided convolutions will be in the next release, PyTorch 1.9. This works today on pytorch-nightly:

>>> import torch
>>> module = torch.nn.Conv2d(1, 1, 5, padding='same')
>>> x = torch.randn(1, 1, 20, 20)
>>> module(x).shape
torch.Size([1, 1, 20, 20])

@jbschlosser (Contributor) commented

To get the conversation going, I propose the following variations for strided convolutions (i.e. stride > 1):

  • padding='same'
    • Non-input-size dependent approach
    • total_padding = dilation * (kernelSize - 1)
  • padding='same_minimal' (with doc warnings explaining the downsides)
    • TensorFlow's input-size-dependent approach that minimizes the total padding
    • total_padding = max(0, dilation * (kernel_size - 1) - (input_size - 1) % stride)
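
To make the comparison concrete, a minimal sketch using the two formulas above (plain helper functions for illustration, not a torch API):

def total_padding_same(kernel_size, dilation):
    return dilation * (kernel_size - 1)

def total_padding_same_minimal(kernel_size, dilation, input_size, stride):
    return max(0, dilation * (kernel_size - 1) - (input_size - 1) % stride)

# kernel 3, dilation 1, stride 2: the input-size-independent scheme always pads 2,
# while the TF-style scheme pads 1 or 2 depending on the input size
print(total_padding_same(3, 1))                 # 2
print(total_padding_same_minimal(3, 1, 10, 2))  # 1
print(total_padding_same_minimal(3, 1, 11, 2))  # 2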

@gchanan gchanan modified the milestones: 1.9.0, 1.10.0 May 10, 2021
@jbschlosser (Contributor) commented

Removing high-pri since support is now in for un-strided convolutions. We can bump the priority back up if enough people want support for strided convolutions.

@vadimkantorov (Contributor) commented

It could be useful to look at some recently released tensorflow generative models and check whether they use strided same-convs. I think they do...

@digital-idiot commented

[quoting @jbschlosser's proposal above for padding='same' and padding='same_minimal']

A slightly different formula is explained here with respect to TensorFlow. The proposed padding='same_minimal' seems interesting. Since it is input-size dependent, the padding would have to be applied inside forward()? I would like to know what the downsides are in this case.

@malfet malfet removed this from the 1.10.0 milestone Oct 5, 2021
@malfet (Contributor) commented Oct 5, 2021

Removing milestone tag for now

@ProGamerGov (Contributor) commented Oct 23, 2021

SAME padding support was added to nn.Conv2d in the latest version of PyTorch! It doesn't yet support strides other than 1, though, so for example my layer with a stride of 2 won't work.

Hopefully other stride sizes will eventually be supported?

@ProGamerGov (Contributor) commented

So, I pulled out the same padding code here: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L110-L123, tested it with a stride of 2 for an nn.Conv2d layer, and it appears to work properly.

            self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size)
            if padding == 'same':
                for d, k, i in zip(dilation, kernel_size,
                                   range(len(kernel_size) - 1, -1, -1)):
                    total_padding = d * (k - 1)
                    left_pad = total_padding // 2
                    self._reversed_padding_repeated_twice[2 * i] = left_pad
                    self._reversed_padding_repeated_twice[2 * i + 1] = (
                        total_padding - left_pad)

This leads me to wonder why this error message exists if it's not accurate: https://github.com/pytorch/pytorch/blob/master/torch/nn/modules/conv.py#L93-L94

            if padding == 'same' and any(s != 1 for s in stride):
                raise ValueError("padding='same' is not supported for strided convolutions")

Why is there an error message for same padding with a stride of 2, if it works correctly when the error message is removed?

@jbschlosser (Contributor) commented

Reinstating high-pri due to activity.

@JakobHavtorn commented

[quoting @ProGamerGov above] Why is there an error message for same padding with a stride of 2, if it works correctly when the error message is removed?

Doesn't it seem that the transposed convolution with padding="same" does indeed not work properly when stride > 1?

@ProGamerGov (Contributor) commented

I only tested with the nn.Conv2d layer, and at the time I was not aware of why the error message was so broad despite the functionality appearing to exist. The issue stems from a difference between how TensorFlow implemented same padding and what would be more logical, and that's why the error message exists in its current form.

@addisonklinke commented Nov 19, 2021

@JakobHavtorn I don't think the current implementation works for stride > 1. I made a manual Conv2dSame class so you can bypass the error in PyTorch 1.10.0, but you can see the spatial dimensions are not retained:

import collections
from itertools import repeat
import torch
from torch import nn
import torch.nn.functional as F


def _ntuple(n):
    """Copied from PyTorch since it's not importable as an internal function

    https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/utils.py#L6
    """
    def parse(x):
        if isinstance(x, collections.abc.Iterable):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


_pair = _ntuple(2)


class Conv2dSame(nn.Module):
    """Manual convolution with same padding

    Although PyTorch >= 1.10.0 supports ``padding='same'`` as a keyword argument,
    this does not export to CoreML as of coremltools 5.1.0, so we need to
    implement the internal torch logic manually. Currently the ``RuntimeError`` is

    "PyTorch convert function for op '_convolution_mode' not implemented"

    Also same padding is not supported for strided convolutions at the moment
    https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/conv.py#L93
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride=1, dilation=1, **kwargs):
        """Wrap base convolution layer

        See official PyTorch documentation for parameter details
        https://pytorch.org/docs/stable/generated/torch.nn.Conv2d.html
        """
        super().__init__()
        self.conv = nn.Conv2d(
            in_channels=in_channels,
            out_channels=out_channels,
            kernel_size=kernel_size,
            stride=stride,
            dilation=dilation,
            **kwargs)

        # Setup internal representations
        kernel_size_ = _pair(kernel_size)
        dilation_ = _pair(dilation)
        self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size_)

        # Follow the logic from ``nn._ConvNd``
        # https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/conv.py#L116
        for d, k, i in zip(dilation_, kernel_size_, range(len(kernel_size_) - 1, -1, -1)):
            total_padding = d * (k - 1)
            left_pad = total_padding // 2
            self._reversed_padding_repeated_twice[2 * i] = left_pad
            self._reversed_padding_repeated_twice[2 * i + 1] = (
                    total_padding - left_pad)

    def forward(self, imgs):
        """Setup padding so same spatial dimensions are returned

        All shapes (input/output) are ``(N, C, W, H)`` convention

        :param torch.Tensor imgs:
        :return torch.Tensor:
        """
        padded = F.pad(imgs, self._reversed_padding_repeated_twice)
        return self.conv(padded)


conv_same = Conv2dSame(3, 8, 3)
conv_same_stride = Conv2dSame(3, 8, 3, stride=2)
imgs = torch.randn(1, 3, 28, 28)
print(conv_same(imgs).shape)         # Correct [1, 8, 28, 28]
print(conv_same_stride(imgs).shape)  # Wrong   [1, 8, 14, 14]

@JakobHavtorn commented Nov 19, 2021

@addisonklinke Ah yes I get your point. I believe this comes down to what you expect.

Same padding is defined to keep the input and output shapes equal for unit strides. I often find it useful to be able to downsample an input some integer number of times in a "same padded" manner.

By this I mean that the length of the output along the convolved dimension(s) is exactly s times shorter than the input. In general you could write the output length o as

o = ceil(i / s)

where i is the input length and s is the stride.

So in your example, I actually find that the output has exactly the shape I would expect from a same-padded convolution with non-unit stride.
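
A quick check of that expectation, reusing the Conv2dSame class from the previous comment (a sketch; the sizes are illustrative):

import math
import torch

conv = Conv2dSame(3, 8, 3, stride=2)  # Conv2dSame as defined in the comment above
for i in (27, 28, 29):
    out = conv(torch.randn(1, 3, i, i))
    print(out.shape[-1], math.ceil(i / 2))  # both columns print 14, 14, 15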

@addisonklinke commented

@JakobHavtorn Ah interesting, I hadn't thought about it like that!

@lauriebyrum commented

@addisonklinke: We just ran into this coremltools issue too, so I filed it at apple/coremltools#1363. Meanwhile, I'll try your workaround, which makes sense. Thanks for sharing it!

@dav-ell commented Apr 14, 2022

@addisonklinke's code worked for me, and does it without using math.ceil or math.floor, which makes it compatible with the current TensorRT release. I was getting all sorts of errors with other implementations that used ceil and floor, where the padding wasn't computed the same way in TensorRT (or Torch-TensorRT) as in Torch/TorchScript. Awesome work!

I also implemented this for MaxPool2d, which was mostly a copy/paste, and it seems to work:

class MaxPool2dStaticSamePadding(nn.Module):
    """
    The real keras/tensorflow MaxPool2d with same padding
    """

    def __init__(self, *args, **kwargs):
        super().__init__()
        self.pool = nn.MaxPool2d(*args, **kwargs)
        self.stride = self.pool.stride
        self.kernel_size = self.pool.kernel_size

        if isinstance(self.stride, int):
            self.stride = [self.stride] * 2
        elif len(self.stride) == 1:
            self.stride = [self.stride[0]] * 2

        if isinstance(self.kernel_size, int):
            self.kernel_size = [self.kernel_size] * 2
        elif len(self.kernel_size) == 1:
            self.kernel_size = [self.kernel_size[0]] * 2

        # Setup internal representations
        kernel_size_ = _pair(self.kernel_size)
        dilation_ = _pair(1)
        self._reversed_padding_repeated_twice = [0, 0] * len(kernel_size_)

        # Follow the logic from ``nn._ConvNd``
        # https://github.com/pytorch/pytorch/blob/v1.10.0/torch/nn/modules/conv.py#L116
        for d, k, i in zip(dilation_, kernel_size_, range(len(kernel_size_) - 1, -1, -1)):
            total_padding = d * (k - 1)
            left_pad = total_padding // 2
            self._reversed_padding_repeated_twice[2 * i] = left_pad
            self._reversed_padding_repeated_twice[2 * i + 1] = (
                    total_padding - left_pad)

    def forward(self, x):
        """Setup padding so same spatial dimensions are returned

        All shapes (input/output) are ``(N, C, W, H)`` convention

        :param torch.Tensor x:
        :return torch.Tensor:
        """
        padded = F.pad(x, self._reversed_padding_repeated_twice)
        x = self.pool(padded)
        return x

@RichieHakim commented Feb 3, 2024

As of now, PyTorch's implementation of 'same' is different from SciPy/NumPy!
Specifically, PyTorch pads more at the end; SciPy pads more at the front.

I would like to point out that the convention used by torch (and tensorflow) when the kernel ('weight') has even length is the opposite of what is used in scipy and numpy. Looking through the thread above, I don't see any discussion of why this might be. Here are some code examples demonstrating this:

import torch
import torch.nn.functional as F

import numpy as np
import scipy.signal

# Example input signal and kernel
input_signal = np.ones(6).astype(np.float32)  # signal of six ones
kernel = np.ones(4).astype(np.float32)  # kernel of four ones

# Apply conv1d with 'same' padding
output_scipy = scipy.signal.convolve(input_signal, kernel, mode='same')

print(output_scipy)

# Convert to torch tensors
input_signal = torch.as_tensor(input_signal)
kernel = torch.as_tensor(kernel)

# Apply conv1d with 'same' padding
output_torch = F.conv1d(input_signal[None, None, :], kernel[None, None, :], padding='same').squeeze().numpy()

print(output_torch)

Result:

SciPy:   [2. 3. 4. 4. 4. 3.]
PyTorch: [3. 4. 4. 4. 3. 2.]
