Skip to content

Commit

Permalink
add cuda support based on following references
Browse files Browse the repository at this point in the history
  • Loading branch information
lim0606 committed Mar 24, 2017
1 parent 7772e41 commit 3c18086
Show file tree
Hide file tree
Showing 6 changed files with 678 additions and 2 deletions.
16 changes: 14 additions & 2 deletions script/build.py
Expand Up @@ -2,21 +2,33 @@
import torch
from torch.utils.ffi import create_extension

this_file = os.path.dirname(__file__)
#this_file = os.path.dirname(__file__)

sources = ['src/my_lib.c']
headers = ['src/my_lib.h']
defines = []
with_cuda = False

if torch.cuda.is_available():
print('Including CUDA code.')
sources += ['src/my_lib_cuda.c']
headers += ['src/my_lib_cuda.h']
defines += [('WITH_CUDA', None)]
with_cuda = True

this_file = os.path.dirname(os.path.realpath(__file__))
print(this_file)
extra_objects = ['src/my_lib_cuda_kernel.cu.o']
extra_objects = [os.path.join(this_file, fname) for fname in extra_objects]

ffi = create_extension(
'_ext.my_lib',
headers=headers,
sources=sources,
define_macros=defines,
relative_to=__file__,
with_cuda=with_cuda
with_cuda=with_cuda,
extra_objects=extra_objects
)

if __name__ == '__main__':
Expand Down
10 changes: 10 additions & 0 deletions script/make.sh
@@ -0,0 +1,10 @@
#!/usr/bin/env bash

CUDA_PATH=/usr/local/cuda/

cd src
echo "Compiling my_lib kernels by nvcc..."
nvcc -c -o my_lib_cuda_kernel.cu.o my_lib_cuda_kernel.cu -x cu -Xcompiler -fPIC -arch=sm_52

cd ../
python build.py
151 changes: 151 additions & 0 deletions script/src/my_lib_cuda.c
@@ -0,0 +1,151 @@
#include <THC/THC.h>
#include <stdbool.h>
#include <stdio.h>
#include "my_lib_cuda_kernel.h"

#define real float

// this symbol will be resolved automatically from PyTorch libs
extern THCState *state;

// Bilinear sampling is done in BHWD (coalescing is not obvious in BDHW)
// we assume BHWD format in inputImages
// we assume BHW(YX) format on grids

int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
// THCudaTensor *output = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor");

int success = 0;
success = BilinearSamplerBHWD_updateOutput_cuda_kernel(output->size[2],
output->size[1],
output->size[0],
THCudaTensor_size(state, inputImages, 3),
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, output, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, output),
THCudaTensor_stride(state, output, 0),
THCudaTensor_stride(state, output, 3),
THCudaTensor_stride(state, output, 1),
THCudaTensor_stride(state, output, 2),
THCState_getCurrentStream(state));

//check for errors
if (!success) {
THError("aborting");
}
return 1;
}

int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
// THCudaTensor *gradInputImages = (THCudaTensor *)luaT_checkudata(L, 4, "torch.CudaTensor");
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");

int success = 0;
success = BilinearSamplerBHWD_updateGradInput_cuda_kernel(gradOutput->size[2],
gradOutput->size[1],
gradOutput->size[0],
THCudaTensor_size(state, inputImages, 3),
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, gradOutput, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, gradInputImages),
THCudaTensor_stride(state, gradInputImages, 0),
THCudaTensor_stride(state, gradInputImages, 3),
THCudaTensor_stride(state, gradInputImages, 1),
THCudaTensor_stride(state, gradInputImages, 2),
THCudaTensor_data(state, gradGrids),
THCudaTensor_stride(state, gradGrids, 0),
THCudaTensor_stride(state, gradGrids, 3),
THCudaTensor_stride(state, gradGrids, 1),
THCudaTensor_stride(state, gradGrids, 2),
THCudaTensor_data(state, gradOutput),
THCudaTensor_stride(state, gradOutput, 0),
THCudaTensor_stride(state, gradOutput, 3),
THCudaTensor_stride(state, gradOutput, 1),
THCudaTensor_stride(state, gradOutput, 2),
THCState_getCurrentStream(state));

//check for errors
if (!success) {
THError("aborting");
}
return 1;
}

int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
THCudaTensor *gradGrids, THCudaTensor *gradOutput)
{
// THCState *state = getCutorchState(L);
// THCudaTensor *inputImages = (THCudaTensor *)luaT_checkudata(L, 2, "torch.CudaTensor");
// THCudaTensor *grids = (THCudaTensor *)luaT_checkudata(L, 3, "torch.CudaTensor");
// THCudaTensor *gradGrids = (THCudaTensor *)luaT_checkudata(L, 5, "torch.CudaTensor");
// THCudaTensor *gradOutput = (THCudaTensor *)luaT_checkudata(L, 6, "torch.CudaTensor");

int success = 0;
success = BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda_kernel(
gradOutput->size[2],
gradOutput->size[1],
gradOutput->size[0],
THCudaTensor_size(state, inputImages, 3),
THCudaTensor_size(state, inputImages, 1),
THCudaTensor_size(state, inputImages, 2),
THCudaTensor_size(state, gradOutput, 2),
THCudaTensor_data(state, inputImages),
THCudaTensor_stride(state, inputImages, 0),
THCudaTensor_stride(state, inputImages, 3),
THCudaTensor_stride(state, inputImages, 1),
THCudaTensor_stride(state, inputImages, 2),
THCudaTensor_data(state, grids),
THCudaTensor_stride(state, grids, 0),
THCudaTensor_stride(state, grids, 3),
THCudaTensor_stride(state, grids, 1),
THCudaTensor_stride(state, grids, 2),
THCudaTensor_data(state, gradGrids),
THCudaTensor_stride(state, gradGrids, 0),
THCudaTensor_stride(state, gradGrids, 3),
THCudaTensor_stride(state, gradGrids, 1),
THCudaTensor_stride(state, gradGrids, 2),
THCudaTensor_data(state, gradOutput),
THCudaTensor_stride(state, gradOutput, 0),
THCudaTensor_stride(state, gradOutput, 3),
THCudaTensor_stride(state, gradOutput, 1),
THCudaTensor_stride(state, gradOutput, 2),
THCState_getCurrentStream(state));

//check for errors
if (!success) {
THError("aborting");
}
return 1;
}
11 changes: 11 additions & 0 deletions script/src/my_lib_cuda.h
@@ -0,0 +1,11 @@
// Bilinear sampling is done in BHWD (coalescing is not obvious in BDHW)
// we assume BHWD format in inputImages
// we assume BHW(YX) format on grids

int BilinearSamplerBHWD_updateOutput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *output);

int BilinearSamplerBHWD_updateGradInput_cuda(THCudaTensor *inputImages, THCudaTensor *grids, THCudaTensor *gradInputImages,
THCudaTensor *gradGrids, THCudaTensor *gradOutput);

int BilinearSamplerBHWD_updateGradInputOnlyGrid_cuda(THCudaTensor *inputImages, THCudaTensor *grids,
THCudaTensor *gradGrids, THCudaTensor *gradOutput);

0 comments on commit 3c18086

Please sign in to comment.