Fails to build against CUDA9 #492

Open
eigengrau opened this issue May 19, 2019 · 1 comment

It looks like CUDA9 deprecates __shfl and __any. I was able to compile using the following quick-and-dirty patch:

--- LookupTable.cu	2019-05-18 11:03:38.615935768 +0200
+++ LookupTable.cu	2019-05-18 11:08:42.189278728 +0200
@@ -6,54 +6,54 @@
 #include <thrust/execution_policy.h>
 #include <thrust/iterator/constant_iterator.h>
 #include <thrust/transform_reduce.h>
 #if CUDA_VERSION >= 7000
 #include <thrust/system/cuda/execution_policy.h>
 #endif
 #include <thrust/unique.h>
 #include "THCHalf.h"
 #include "THCHalfAutoNumerics.cuh"
 #include "THCTensorSort.cuh"
-
+#define FULL_MASK 0xffffffff
 const int WARP_SIZE = 32;
 
 __device__ __forceinline__ bool warpHasCollision(int val)
 {
   // Compare our value to the values stored in the next 16 lanes,
   // wrapping around at 32. If any pair of values is the same than
   // there is a collision in the warp.
   bool dup = 0;
   const int laneId = threadIdx.x % 32;
 
 #if __CUDA_ARCH__ >= 300
 
   #pragma unroll
   for (int i = 1; i <= 16; i++)
   {
-    dup |= (__shfl(val, (laneId + i) % 32) == val);
+    dup |= (__shfl_sync(FULL_MASK, val, (laneId + i) % 32) == val);
   }
 
 #else
 
   volatile __shared__ int values[128];
   values[threadIdx.x] = val;
   const int offset = threadIdx.x - laneId;
 
   #pragma unroll
   for (int i = 1; i <= 16; i++)
   {
     dup |= (values[offset + ((laneId + i) % 32)] == val);
   }
 
 #endif
 
-  return __any(dup) != 0;
+  return __any_sync(FULL_MASK, dup) != 0;
 }
 
 template <typename Dtype>
 __global__ void cunn_LookupTable_accGradParametersKernelByFeature(
   long *input, Dtype *gradOutput, Dtype *gradWeight, Dtype scale, ptrdiff_t numel,
   long stride, int paddingValue) {
 
   const int featureDim = blockIdx.x * 4 + threadIdx.x / 32;
   if (featureDim >= stride) {
     return;
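
For builds that also have to keep working on pre-CUDA9 toolkits, the same fix can be guarded on CUDA_VERSION so the old intrinsics are still used where the _sync variants do not exist. A minimal sketch along those lines (the WARP_SHFL and WARP_ANY names are illustrative, not part of the patch above):

#if CUDA_VERSION >= 9000
  // CUDA 9+: the old warp intrinsics are deprecated; route them to the
  // *_sync replacements with a full, all-lanes-active mask.
  #define FULL_MASK 0xffffffff
  #define WARP_SHFL(val, src) __shfl_sync(FULL_MASK, (val), (src))
  #define WARP_ANY(pred)      __any_sync(FULL_MASK, (pred))
#else
  // Pre-CUDA9 toolkits: keep the original intrinsics unchanged.
  #define WARP_SHFL(val, src) __shfl((val), (src))
  #define WARP_ANY(pred)      __any((pred))
#endif

With that in place, the two changed lines in warpHasCollision would read WARP_SHFL(val, (laneId + i) % 32) and WARP_ANY(dup) and compile on either toolkit.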

everdom commented Sep 9, 2019

I had the same problem when building against CUDA10.1. I was able to compile using your patch, but would it cause other runtime problems? That's what I'm really worried about.

ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfmLin1EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfmLi2EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfmLi1EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfmLin2EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfjLin1EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfjLi2EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfjLi1EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
ptxas warning : Value of threads per SM for entry _Z24THNN_CudaHalfLSTMForwardI6__halfjLin2EEv10TensorInfoIT_T0_ES4_S4_S4_S4_S4_S4_S3_S3_ is out of range. .minnctapersm will be ignored
[ 26%] Building NVCC (Device) object lib/THCUNN/CMakeFiles/THCUNN.dir/THCUNN_generated_LookupTableBag.cu.o
/tmp/luarocks_cunn-scm-1-5240/cunn/lib/THCUNN/LookupTable.cu(32): error: identifier "__shfl" is undefined

/tmp/luarocks_cunn-scm-1-5240/cunn/lib/THCUNN/LookupTable.cu(49): warning: function "__any"
/usr/local/cuda/include/device_atomic_functions.h(178): here was declared deprecated ("__any() is not valid on compute_70 and above, and should be replaced with __any_sync().To continue using __any(), specify virtual architecture compute_60 when targeting sm_70 and above, for example, using the pair of compiler options: -arch=compute_60 -code=sm_70.")

[ 28%] Building NVCC (Device) object lib/THCUNN/CMakeFiles/THCUNN.dir/THCUNN_generated_MSECriterion.cu.o
[ 29%] Building NVCC (Device) object lib/THCUNN/CMakeFiles/THCUNN.dir/THCUNN_generated_MarginCriterion.cu.o
1 error detected in the compilation of "/tmp/tmpxft_00002ab6_00000000-6_LookupTable.cpp1.ii".
CMake Error at THCUNN_generated_LookupTable.cu.o.Release.cmake:279 (message):
Error generating file
/tmp/luarocks_cunn-scm-1-5240/cunn/build/lib/THCUNN/CMakeFiles/THCUNN.dir//./THCUNN_generated_LookupTable.cu.o

lib/THCUNN/CMakeFiles/THCUNN.dir/build.make:175: recipe for target 'lib/THCUNN/CMakeFiles/THCUNN.dir/THCUNN_generated_LookupTable.cu.o' failed
make[2]: *** [lib/THCUNN/CMakeFiles/THCUNN.dir/THCUNN_generated_LookupTable.cu.o] Error 1
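
On the runtime question: the full-mask _sync intrinsics behave like the old __shfl/__any only when every lane of the warp actually reaches the call. In the caller shown above that appears to hold, since featureDim = blockIdx.x * 4 + threadIdx.x / 32 is the same for all 32 lanes of a warp, so the early return is warp-uniform and the whole warp reaches warpHasCollision together. A tiny standalone check of the full-mask behaviour (illustrative only, not part of the cunn sources):

// any_sync_demo.cu -- with one full warp active, __any_sync with mask
// 0xffffffff reports the same warp-wide predicate the old __any did.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void anySyncDemo(int *out)
{
  // Only lane 7 asserts the predicate; every lane should still read 1.
  int pred = (threadIdx.x == 7);
  out[threadIdx.x] = __any_sync(0xffffffff, pred);
}

int main()
{
  int *d_out = nullptr;
  cudaMalloc(&d_out, 32 * sizeof(int));
  anySyncDemo<<<1, 32>>>(d_out);   // exactly one full warp
  int h_out[32];
  cudaMemcpy(h_out, d_out, sizeof(h_out), cudaMemcpyDeviceToHost);
  printf("lane 0 sees any = %d\n", h_out[0]);  // expected: 1
  cudaFree(d_out);
  return 0;
}

Compiled with a CUDA 9 or newer toolkit, it prints any = 1 for every launch, matching what __any(pred) returned before the rename.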

ashwin2802 added a commit to ashwin2802/cunn that referenced this issue Mar 6, 2021
kyoto7250 added a commit to kyoto7250/cunn that referenced this issue Feb 22, 2023