Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Add padding='same' mode to conv{1,2,3}d (#45667)
Summary: Pull Request resolved: #45667 First part of #3867 (Pooling operators still to do) This adds a `padding='same'` mode to the interface of `conv{n}d`and `nn.Conv{n}d`. This should match the behaviour of `tensorflow`. I couldn't find it explicitly documented but through experimentation I found `tensorflow` returns the shape `ceil(len/stride)` and always adds any extra asymmetric padding onto the right side of the input. Since the `native_functions.yaml` schema doesn't seem to support strings or enums, I've moved the function interface into python and it now dispatches between the numerically padded `conv{n}d` and the `_conv{n}d_same` variant. Underscores because I couldn't see any way to avoid exporting a function into the `torch` namespace. A note on asymmetric padding. The total padding required can be odd if both the kernel-length is even and the dilation is odd. mkldnn has native support for asymmetric padding, so there is no overhead there, but for other backends I resort to padding the input tensor by 1 on the right hand side to make the remaining padding symmetrical. In these cases, I use `TORCH_WARN_ONCE` to notify the user of the performance implications. Test Plan: Imported from OSS Reviewed By: ejguan Differential Revision: D27170744 Pulled By: jbschlosser fbshipit-source-id: b3d8a0380e0787ae781f2e5d8ee365a7bfd49f22
- Loading branch information
1 parent
a8a1090
commit 04e0cbf
Showing
18 changed files
with
892 additions
and
65 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,30 @@ | ||
#pragma once | ||
|
||
namespace c10 { | ||
namespace detail { | ||
|
||
template<class...Ts> | ||
struct overloaded_t {}; | ||
|
||
template<class T0> | ||
struct overloaded_t<T0>:T0 { | ||
using T0::operator(); | ||
overloaded_t(T0 t0):T0(std::move(t0)) {} | ||
}; | ||
template<class T0, class...Ts> | ||
struct overloaded_t<T0, Ts...>:T0, overloaded_t<Ts...> { | ||
using T0::operator(); | ||
using overloaded_t<Ts...>::operator(); | ||
overloaded_t(T0 t0, Ts... ts): | ||
T0(std::move(t0)), | ||
overloaded_t<Ts...>(std::move(ts)...) | ||
{} | ||
}; | ||
|
||
} // namespace detail | ||
|
||
// Construct an overloaded callable combining multiple callables, e.g. lambdas | ||
template<class...Ts> | ||
detail::overloaded_t<Ts...> overloaded(Ts...ts){ return {std::move(ts)...}; } | ||
|
||
} // namespace c10 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.