
More examples #3

fxia22 opened this issue Feb 17, 2017 · 9 comments

Comments


fxia22 commented Feb 17, 2017

Hi pytorch team,

I am looking to port https://github.com/qassemoquab/stnbhwd to pytorch with ffi. Do you know if this is possible? Is the mechanism for writing extensions similar between torch and pytorch; in other words, can I reuse some of the code from that repo? Thanks.


apaszke commented Feb 17, 2017

Yes, it should be quite easy to reuse it. You'd only need to copy over the C files and change the functions to accept tensors as arguments instead of parsing them out of the Lua state. Then just use the package example and that should be it.
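
For reference, the corresponding function in the extension-ffi example looks roughly like this (a minimal sketch following the pytorch/extension-ffi repo; the key point is that the tensors arrive as plain arguments, with no lua_State involved):

#include <TH/TH.h>

/* CPU version: tensors are passed directly as arguments instead of
   being parsed out of a lua_State as in the original Torch layers. */
int my_lib_add_forward(THFloatTensor *input1, THFloatTensor *input2,
                       THFloatTensor *output)
{
  if (!THFloatTensor_isSameSizeAs(input1, input2))
    return 0;
  THFloatTensor_resizeAs(output, input1);
  /* output = input1 + 1.0 * input2 */
  THFloatTensor_cadd(output, input1, 1.0, input2);
  return 1;
}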


fxia22 commented Feb 17, 2017

Awesome, thank you! I will let you know how things go.


fxia22 commented Feb 17, 2017

@apaszke I am trying to get the data out of a CudaTensor. I changed the example library to the following, but it gives me a segfault:

int my_lib_add_forward_cuda(THCudaTensor *input1, THCudaTensor *input2,
               THCudaTensor *output)
{
  if (!THCudaTensor_isSameSizeAs(state, input1, input2))
    return 0;
  float * input_data = THCudaTensor_data(state, input1);
  printf("data %f\n", input_data[0]);
  THCudaTensor_resizeAs(state, output, input1);
  THCudaTensor_cadd(state, output, input1, 1.0, input2);
  return 1;
}

I need similar operations for the spatial transformer network to work with CUDA (the CPU version already works). Can you share how to do this extraction? Thanks in advance.


fxia22 commented Feb 18, 2017

I guess my question is really about how to reuse CUDA code. When I attempted to do so, it tells me threadIdx.x is not defined.


soumith commented Feb 18, 2017

You cannot printf a CUDA device pointer from host code, it will segfault. Maybe you can lightly read the CUDA programming guide: docs.nvidia.com/cuda/cuda-c-programming-guide
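
(A sketch of the fix, assuming the function posted above: THCudaTensor_data returns a device pointer, so the element has to be copied back to host memory before it can be printed.)

/* replaces the printf line in my_lib_add_forward_cuda above */
float first;
cudaMemcpy(&first, THCudaTensor_data(state, input1), sizeof(float),
           cudaMemcpyDeviceToHost);
printf("data %f\n", first);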


apaszke commented Feb 18, 2017

Can't you just copy the code from the original repo? You shouldn't need to change any of the code that computes the function, only the argument parsing.


fxia22 commented Feb 19, 2017

Thanks for your reply.

@apaszke Yes, I finished porting the CPU version and it was quite intuitive. I also read the CUDA programming guide. But how can I build a .cu extension with extension-ffi? I am able to use some torch CUDA functions like THCudaTensor_cadd, but how can I write my own CUDA functions?

For example, when I try to write my own add function, it gives me this error:

/home/fei/Development/extension-ffi/script/src/my_lib_cuda.c: In function ‘VecAdd’:
/home/fei/Development/extension-ffi/script/src/my_lib_cuda.c:9:64: error: ‘threadIdx’ undeclared (first use in this function)
 __global__ void VecAdd(float* A, float* B, float* C) { int i = threadIdx.x; C[i] = A[i] + B[i];  }
                                                                ^

mattmacy commented

@fxia22 torch.utils.ffi doesn't appear to have any knowledge of nvcc or .cu. I think you need to build your cuda sources separately (see the example Makefiles that come with the CUDA SDK) and then add the built object(s) to 'extra_objects' through kwargs when creating an extension.

See:
https://docs.python.org/3/distutils/apiref.html#distutils.core.Extension
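
To sketch what that split might look like (file and function names here are hypothetical): the kernel lives in a .cu file compiled by nvcc, e.g. nvcc -c -Xcompiler -fPIC -o my_lib_kernels.cu.o my_lib_kernels.cu, and exports a plain C launcher that the gcc-compiled glue in my_lib_cuda.c can call; the resulting object file is then what goes in extra_objects.

/* my_lib_kernels.cu -- compiled separately with nvcc.
   __global__ and threadIdx only exist under nvcc, which is why the
   kernel cannot live in the .c file that the ffi builds with gcc. */
__global__ void VecAdd(const float *A, const float *B, float *C, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i < n)
        C[i] = A[i] + B[i];
}

/* Plain C entry point for the glue code in my_lib_cuda.c, which can
   pass THCState_getCurrentStream(state) as the stream. */
extern "C" void vec_add_launcher(const float *A, const float *B, float *C,
                                 int n, cudaStream_t stream)
{
    int threads = 256;
    int blocks = (n + threads - 1) / threads;
    VecAdd<<<blocks, threads, 0, stream>>>(A, B, C, n);
}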


fxia22 commented Mar 1, 2017

@mattmacy Thanks, I will give it a shot!
