Skip to content

Commit

Permalink
add a inner loop for index_select_grad_init() in index_select op when…
Browse files Browse the repository at this point in the history
… dealing with large-shape data (#41563) (#41669)
  • Loading branch information
FlyingQianMM committed Apr 13, 2022
1 parent aec47f8 commit 5d4980c
Show file tree
Hide file tree
Showing 4 changed files with 17 additions and 32 deletions.
15 changes: 5 additions & 10 deletions paddle/phi/kernels/funcs/gather.cu.h
Expand Up @@ -17,6 +17,7 @@ limitations under the License. */
#include <vector>
#include "paddle/fluid/memory/memcpy.h"
// TODO(paddle-dev): move gpu_primitives.h to phi
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/backends/gpu/gpu_launch_config.h"
#include "paddle/phi/common/place.h"
Expand Down Expand Up @@ -110,11 +111,8 @@ void GPUGather(const phi::GPUContext& ctx,

int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
if (grid > maxGridDimX) {
grid = maxGridDimX;
}
dim3 grid = dim3((n + block - 1) / block);
paddle::platform::LimitGridDim(ctx, &grid);

GatherCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_src, p_index, p_output, index_size, slice_size);
Expand Down Expand Up @@ -155,11 +153,8 @@ void GPUGatherNd(const phi::GPUContext& ctx,

int block = 512;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
if (grid > maxGridDimX) {
grid = maxGridDimX;
}
dim3 grid = dim3((n + block - 1) / block);
paddle::platform::LimitGridDim(ctx, &grid);

GatherNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(p_input,
g_input_dims,
Expand Down
16 changes: 7 additions & 9 deletions paddle/phi/kernels/funcs/scatter.cu.h
Expand Up @@ -15,6 +15,7 @@ limitations under the License. */
#pragma once
#include <unordered_set>
#include <vector>
#include "paddle/fluid/platform/device/gpu/gpu_launch_config.h"
#include "paddle/fluid/platform/device/gpu/gpu_primitives.h"
#include "paddle/phi/common/place.h"
#include "paddle/phi/core/dense_tensor.h"
Expand Down Expand Up @@ -155,9 +156,8 @@ void GPUScatterAssign(const phi::GPUContext& ctx,
// set block and grid num
int block = 512;
int64_t n = slice_size * index_size;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
grid = grid > maxGridDimX ? maxGridDimX : grid;
dim3 grid = dim3((n + block - 1) / block);
paddle::platform::LimitGridDim(ctx, &grid);

// if not overwrite mode, init data
if (!overwrite) {
Expand Down Expand Up @@ -188,9 +188,8 @@ void GPUScatterGradForX(const phi::GPUContext& ctx,
int64_t block = 512;
int64_t n = slice_size * index_size;
int64_t height = (n + block - 1) / block;

int64_t max_grid_dimx = ctx.GetCUDAMaxGridDimSize()[0];
int64_t grid = height < max_grid_dimx ? height : max_grid_dimx;
dim3 grid = dim3((n + block - 1) / block);
paddle::platform::LimitGridDim(ctx, &grid);

ScatterInitCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_index, p_output, index_size, slice_size);
Expand Down Expand Up @@ -230,9 +229,8 @@ void GPUScatterNdAdd(const phi::GPUContext& ctx,

int block = 512;
int64_t n = slice_size * remain_numel;
int64_t grid = (n + block - 1) / block;
unsigned int maxGridDimX = ctx.GetCUDAMaxGridDimSize()[0];
grid = grid > maxGridDimX ? maxGridDimX : grid;
dim3 grid = dim3((n + block - 1) / block);
paddle::platform::LimitGridDim(ctx, &grid);

ScatterNdCUDAKernel<T, IndexT><<<grid, block, 0, ctx.stream()>>>(
p_update,
Expand Down
16 changes: 4 additions & 12 deletions paddle/phi/kernels/gpu/index_select_grad_kernel.cu
Expand Up @@ -19,6 +19,7 @@
#include "paddle/phi/backends/gpu/gpu_info.h"
#include "paddle/phi/core/kernel_registry.h"
#include "paddle/phi/core/utils/data_type.h"
#include "paddle/phi/kernels/funcs/math_function.h"

DECLARE_bool(cudnn_deterministic);

Expand All @@ -35,7 +36,7 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
int64_t stride,
int64_t size,
int64_t delta) {
CUDA_KERNEL_LOOP(idx, N) {
CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
Expand All @@ -45,15 +46,6 @@ __global__ void index_select_grad_cuda_kernel(const T* output_grad,
}
}

template <typename T>
__global__ void index_select_grad_init(T* input_grad, int64_t N) {
int64_t idx = blockIdx.x * blockDim.x + threadIdx.x;
if (idx >= N) {
return;
}
input_grad[idx] = 0.0;
}

template <typename T, typename Context>
void IndexSelectGradKernel(const Context& ctx,
const DenseTensor& x,
Expand Down Expand Up @@ -97,8 +89,8 @@ void IndexSelectGradKernel(const Context& ctx,
dim3 grid_dim = dim3((numel + block_dim - 1) / block_dim);
paddle::platform::LimitGridDim(ctx, &grid_dim);

index_select_grad_init<T><<<grid_dim, block_dim, 0, stream>>>(in_grad_data,
numel);
phi::funcs::SetConstant<phi::GPUContext, T> index_select_grad_init;
index_select_grad_init(ctx, x_grad, static_cast<T>(0));

if (FLAGS_cudnn_deterministic) {
VLOG(2) << "Run grad kernel of index_select with single thread.";
Expand Down
2 changes: 1 addition & 1 deletion paddle/phi/kernels/gpu/index_select_kernel.cu
Expand Up @@ -32,7 +32,7 @@ __global__ void index_select_cuda_kernel(const T* input,
int64_t stride,
int64_t size,
int64_t delta) {
CUDA_KERNEL_LOOP(idx, N) {
CUDA_KERNEL_LOOP_TYPE(idx, N, int64_t) {
int64_t pre_idx = idx / (stride * size);
int64_t dim_idx = idx % (stride * size) / stride;
IndexT src_dim_idx = index[dim_idx];
Expand Down

0 comments on commit 5d4980c

Please sign in to comment.