Commit 302ee7b
[release/1.10] Fix adaptive_max_pool2d for channels-last on CUDA (#67697) (#69618)

Co-authored-by: Xiao Wang <24860335+xwang233@users.noreply.github.com>
seemethere and xwang233 committed Dec 9, 2021
1 parent 0c91a70 commit 302ee7b
Showing 2 changed files with 27 additions and 12 deletions.
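
In outline, the diff below stages the pooling results through contiguous temporaries: the CUDA kernels read and write output, indices, and gradInput as flat contiguous buffers, so when the caller supplies channels-last (NHWC) tensors the results are computed into contiguous scratch tensors and copied back at the end. As a repro-style sketch of the behavior this targets (assuming a CUDA build of PyTorch; the shapes and the comparison against a contiguous reference are illustrative, not taken from the commit):

import torch
import torch.nn.functional as F

# Channels-last (NHWC) CUDA input; the matching output/indices tensors can be
# non-contiguous, which is the case this commit routes through contiguous buffers.
x = torch.randn(2, 3, 8, 8, device="cuda").to(memory_format=torch.channels_last)
out, idx = F.adaptive_max_pool2d(x, (4, 4), return_indices=True)

# Reference computed from a contiguous (NCHW) copy of the same values.
out_ref, idx_ref = F.adaptive_max_pool2d(x.contiguous(), (4, 4), return_indices=True)

torch.testing.assert_close(out, out_ref)
assert torch.equal(idx, idx_ref)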
38 changes: 27 additions & 11 deletions aten/src/ATen/native/cuda/AdaptiveMaxPooling2d.cu
@@ -211,6 +211,9 @@ const Tensor& indices) {
   int64_t osizeH = output_size[0];
   int64_t osizeW = output_size[1];
 
+  const at::Tensor output_c = output.is_contiguous() ? output : at::empty(output.sizes(), output.options());
+  const at::Tensor indices_c = indices.is_contiguous() ? indices : at::empty(indices.sizes(), indices.options());
+
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
     int64_t isizeH = input.size(1);
@@ -223,8 +226,8 @@ const Tensor& indices) {
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf, kBFloat16, input.scalar_type(), "adaptive_max_pool2d_cuda", [&] {
           scalar_t* input_data = input.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -268,8 +271,8 @@ const Tensor& indices) {
         "adaptive_max_pool2d_cuda",
         [&] {
           scalar_t* input_data = input_.data_ptr<scalar_t>();
-          scalar_t* output_data = output.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          scalar_t* output_data = output_c.data_ptr<scalar_t>();
+          int64_t* indices_data = indices_c.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -296,6 +299,13 @@ const Tensor& indices) {
           C10_CUDA_KERNEL_LAUNCH_CHECK();
         });
   }
+
+  if (!output.is_contiguous()) {
+    output.copy_(output_c);
+  }
+  if (!indices.is_contiguous()) {
+    indices.copy_(indices_c);
+  }
 }
 
 TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
@@ -322,7 +332,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
   bool atomic =
       true; // suboptimal, but without atomic it doesn't pass the tests
-  Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor gradOutput_ = gradOutput.contiguous();
+  const at::Tensor indices_ = indices.contiguous();
+  const at::Tensor gradInput_c = gradInput.is_contiguous() ? gradInput : at::empty(gradInput.sizes(), gradInput.options());
 
   if (input.ndimension() == 3) {
     int64_t sizeD = input.size(0);
@@ -334,17 +346,17 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     AT_DISPATCH_FLOATING_TYPES_AND2(
         kHalf,
         kBFloat16,
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -393,7 +405,7 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
     int64_t osizeH = gradOutput_.size(2);
     int64_t osizeW = gradOutput_.size(3);
 
-    gradInput.zero_();
+    gradInput_c.zero_();
 
     // bool atomic = (isizeH%osizeH != 0) || (isizeW%osizeW != 0);
@@ -403,9 +415,9 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
         input.scalar_type(),
         "adaptive_max_pool2d_backward_cuda",
         [&] {
-          scalar_t* gradInput_data = gradInput.data_ptr<scalar_t>();
+          scalar_t* gradInput_data = gradInput_c.data_ptr<scalar_t>();
           scalar_t* gradOutput_data = gradOutput_.data_ptr<scalar_t>();
-          int64_t* indices_data = indices.data_ptr<int64_t>();
+          int64_t* indices_data = indices_.data_ptr<int64_t>();
 
           // cuda blocks & threads:
           int blocksH = (int)(16L / sizeD);
@@ -446,6 +458,10 @@ TORCH_IMPL_FUNC(adaptive_max_pool2d_backward_out_cuda)
           }
         });
   }
+
+  if (!gradInput.is_contiguous()) {
+    gradInput.copy_(gradInput_c);
+  }
 }
 } // at::native
 } // at
1 change: 0 additions & 1 deletion test/test_nn.py
@@ -14622,7 +14622,6 @@ def test_upsamplingBilinear2d(self, device):
 
         self.assertEqual(a_cuda.grad, a_cpu.grad)
 
-    @onlyCPU
     @dtypes(torch.float, torch.double)
     def test_adaptive_pooling_max_nhwc(self, device, dtype):
         def helper(n, c, h, w, output_height, output_width, contig):
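
With @onlyCPU removed, test_adaptive_pooling_max_nhwc also runs on CUDA devices and so exercises the code path changed above. The helper it calls lies outside this hunk; the following is only a hedged approximation of that style of channels-last check (the function name, shapes, and assertions are assumptions, not the test's actual code):

import torch

def check_adaptive_max_nhwc(device, dtype, out_size=(4, 4)):
    # Channels-last input plus a contiguous reference built from the same values.
    x = torch.randn(2, 3, 8, 8, device=device, dtype=dtype)
    x = x.to(memory_format=torch.channels_last).requires_grad_()
    ref_x = x.detach().clone().contiguous().requires_grad_()

    pool = torch.nn.AdaptiveMaxPool2d(out_size, return_indices=True)
    out, ind = pool(x)
    ref_out, ref_ind = pool(ref_x)

    # Forward values, returned indices, and backward gradients should agree
    # regardless of memory format.
    out.sum().backward()
    ref_out.sum().backward()
    torch.testing.assert_close(out, ref_out)
    assert torch.equal(ind, ref_ind)
    torch.testing.assert_close(x.grad, ref_x.grad)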
