Wrap cub in its own namespace (#55292) (#61605)
Summary:
Tentative fix for #55027.
Wraps the cub include in its own namespace so that static variables used by cub and thrust don't conflict if they end up in different libraries when torch is built with BUILD_SPLIT_CUDA. The cub variables end up in their own namespace, while the thrust variables stay unwrapped, so they don't clash.
This also allows extensions to use cub without wrapping it (thrust will still be problematic). The solution for allowing extensions to use thrust is to stop using thrust in pytorch completely.
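For illustration (not part of this diff), a minimal sketch of how a .cu file inside `at::native` would consume cub through the new wrapper header; the helper `sum_ints` is made up, but the include and the two-pass cub call mirror what Nonzero.cu below does:
```
// Hypothetical example: sum N device ints using the namespaced cub.
#include <ATen/cuda/CubUtils.cuh>  // instead of #include <cub/cub.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>

namespace at { namespace native {

void sum_ints(const int* d_in, int* d_out, int N) {
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  auto& allocator = *c10::cuda::CUDACachingAllocator::get();
  size_t temp_storage_bytes = 0;
  // First call only computes the required temporary storage size.
  cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, d_in, d_out, N, stream);
  auto temp_storage = allocator.allocate(temp_storage_bytes);
  // Second call performs the reduction; cub:: here resolves to at::native::cub::.
  cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, d_in, d_out, N, stream);
}

}} // namespace at::native
```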
With this change, including cub and including thrust cannot coexist in the same file, so I had to move nonzero to its own file and remove its reliance on thrust functions. Nonzero now uses cub only.
Also, we cannot selectively include just some of the cub headers; we are forced to include `cub/cub.cuh`, which is not great.
Caffe2 ops using cub are not touched (there are too many), so mixing caffe2 and torch can still result in the same bug. We are moving towards disabling c2 ops, so I think this is fine.
Still, even with this the compiler (correctly) warns about the redefinition of `CUB_NS_PREFIX`, because including `ATen/ATen.h` transitively includes `thrust/complex.h`, which in turn includes the original (empty) definition of `CUB_NS_PREFIX`. We can probably just ignore this warning. Here's an example warning:
```
In file included from /data/users/ngimel/pytorch/aten/src/ATen/native/cuda/Nonzero.cu:9:
/data/users/ngimel/pytorch/aten/src/ATen/cuda/CubUtils.cuh:4: warning: "CUB_NS_PREFIX" redefined
 #define CUB_NS_PREFIX namespace at{ namespace native{

In file included from /home/ngimel/local/cuda/include/thrust/system/cuda/config.h:76,
                 from /home/ngimel/local/cuda/include/thrust/system/cuda/detail/execution_policy.h:33,
                 from /home/ngimel/local/cuda/include/thrust/iterator/detail/device_system_tag.h:23,
                 from /home/ngimel/local/cuda/include/thrust/iterator/iterator_traits.h:111,
                 from /home/ngimel/local/cuda/include/thrust/detail/type_traits/pointer_traits.h:23,
                 from /home/ngimel/local/cuda/include/thrust/type_traits/is_contiguous_iterator.h:27,
                 from /home/ngimel/local/cuda/include/thrust/type_traits/is_trivially_relocatable.h:19,
                 from /home/ngimel/local/cuda/include/thrust/detail/complex/complex.inl:20,
                 from /home/ngimel/local/cuda/include/thrust/complex.h:1031,
                 from /data/users/ngimel/pytorch/c10/util/complex.h:9,
                 from /data/users/ngimel/pytorch/c10/core/ScalarType.h:4,
                 from /data/users/ngimel/pytorch/c10/core/Scalar.h:10,
                 from /data/users/ngimel/pytorch/build/aten/src/ATen/core/TensorBody.h:8,
                 from /data/users/ngimel/pytorch/aten/src/ATen/Tensor.h:3,
                 from /data/users/ngimel/pytorch/aten/src/ATen/Context.h:4,
                 from /data/users/ngimel/pytorch/aten/src/ATen/ATen.h:9,
                 from /data/users/ngimel/pytorch/aten/src/ATen/native/cuda/Nonzero.cu:1:
/home/ngimel/local/cuda/include/cub/util_namespace.cuh:43: note: this is the location of the previous definition
 #define CUB_NS_PREFIX

```
We will need a lint rule to prevent people from including `cub/cub.cuh` directly, because that would make #55027 reappear for some sequence of operations (and would lead to errors with cub code in extensions).
Also, for this to work reliably we'll need to make sure that everything calling cub ends up in only one of libtorch_cuda_cu or libtorch_cuda_cpp; otherwise even the namespace won't help (the same symbols will still be present in both libraries).

Upd: libtorch_cuda_cpp and libtorch_cuda_cu still contain the same symbols, which means there exists a sequence of operations that will make the cache bug reappear, so this alone is not a solution; we need to adjust the file lists for BUILD_SPLIT_CUDA:
```
(pytorch) [ngimel@ ~/local/pytorch/build/lib] nm libtorch_cuda_cu.so | grep PerDeviceAttributeCache | c++filt
000000000c6bf808 u guard variable for at::native::cub::GetPerDeviceAttributeCache<at::native::cub::PtxVersionCacheTag>()::cache
000000000c600830 u guard variable for cub::GetPerDeviceAttributeCache<cub::PtxVersionCacheTag>()::cache
00000000018625e0 t at::native::cub::PerDeviceAttributeCache::DevicePayload at::native::cub::PerDeviceAttributeCache::operator()<at::native::cub::PtxVersion(int&)::{lambda(int&)#1}>(at::native::cub::PtxVersion(int&)::{lambda(int&)#1}&&, int)
00000000009ce630 t cub::PerDeviceAttributeCache::DevicePayload cub::PerDeviceAttributeCache::operator()<cub::PtxVersion(int&)::{lambda(int&)#1}>(cub::PtxVersion(int&)::{lambda(int&)#1}&&, int)
000000000c6bf820 u at::native::cub::GetPerDeviceAttributeCache<at::native::cub::PtxVersionCacheTag>()::cache
000000000c600840 u cub::GetPerDeviceAttributeCache<cub::PtxVersionCacheTag>()::cache
(pytorch) [ngimel@ ~/local/pytorch/build/lib] nm libtorch_cuda_cpp.so | grep PerDeviceAttributeCache | c++filt
0000000000ad2d98 u guard variable for at::native::cub::GetPerDeviceAttributeCache<at::native::cub::PtxVersionCacheTag>()::cache
0000000000ad2da0 u at::native::cub::GetPerDeviceAttributeCache<at::native::cub::PtxVersionCacheTag>()::cache
```
Upd2:
Moved TensorFactories.cu to the torch_cuda_cu sources (see the change to caffe2/CMakeLists.txt), so now the cub-related symbols are only in libtorch_cuda_cu. We'd need a test for that; any suggestions on how best to test it?
cc zasdfgbnm malfet

Pull Request resolved: #55292

Reviewed By: anjali411

Differential Revision: D27576442

Pulled By: ngimel

fbshipit-source-id: 1ef29503a342bb214794d34a42a47052092a66c1

Co-authored-by: Natalia Gimelshein <ngimel@fb.com>
xiaoyu-work and ngimel committed Jul 23, 2021
1 parent 8f69e49 commit e0495a7
Showing 8 changed files with 218 additions and 168 deletions.
3 changes: 3 additions & 0 deletions .github/workflows/lint.yml
@@ -60,6 +60,9 @@ jobs:
run: |
set -eux
python torch/testing/check_kernel_launches.py |& tee ${GITHUB_WORKSPACE}/cuda_kernel_launch_checks.txt
- name: Ensure no direct cub include
run: |
(! git grep -I -no $'#include <cub/' -- ./aten ':(exclude)aten/src/ATen/cuda/CubUtils.cuh' || (echo "The above files have direct cub include; please include ATen/cuda/CubUtils.cuh instead and wrap your cub calls in at::native namespace if necessary"; false))
flake8-py3:
runs-on: ubuntu-18.04
10 changes: 10 additions & 0 deletions aten/src/ATen/cuda/CubUtils.cuh
@@ -0,0 +1,10 @@
#pragma once

// include cub in a safe manner
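// CUB_NS_PREFIX / CUB_NS_POSTFIX are cub's hooks for wrapping its whole contents
// in an extra namespace: with the definitions below, everything pulled in by
// <cub/cub.cuh> ends up in at::native::cub instead of the global ::cub.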
#undef CUB_NS_POSTFIX //undef to avoid redefinition warnings
#undef CUB_NS_PREFIX
#define CUB_NS_PREFIX namespace at{ namespace native{
#define CUB_NS_POSTFIX }}
#include <cub/cub.cuh>
#undef CUB_NS_POSTFIX
#undef CUB_NS_PREFIX
89 changes: 0 additions & 89 deletions aten/src/ATen/native/cuda/Indexing.cu
@@ -17,10 +17,8 @@
#include <THC/THCThrustAllocator.cuh>
#include <thrust/execution_policy.h>
#include <thrust/sort.h>
#include <thrust/transform.h>
#include <THC/THCAtomics.cuh>

#include <cub/cub.cuh>

#include <c10/macros/Macros.h>

@@ -848,92 +846,5 @@ Tensor index_select_cuda(const Tensor& self, int64_t dim, const Tensor& index) {
return out;
}

template<typename T>
struct NonZeroOp
{
__host__ __device__ __forceinline__ bool operator()(const T& a) const {
return (a!=T(0));
}
};

template<typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
Tensor self_ = self.contiguous();
int N = self_.numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// compute number of nonzero elements
size_t temp_storage_bytes=0;
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto num_nonzeros = allocator.allocate(sizeof(int));
cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t*> itr(self_.data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
auto temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
int num_nonzeros_h;
C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream));
//need to synchronize to make sure data is available on the host
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
//expected output size is num_nonzeros x ndim
//we are producing output with size {num_nonzeros, ndim} and strides {num_nonzeros, 1} (that is, transposed ndim x num_nonzeros output)
//we are able to directly use passed output with this size and strides, and we can also (per contract)
//resize passed output with incorrect sizes anyway we want.
//However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
at::Tensor out_temp = need_to_copy ?
at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()),
out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) :
out.resize_({self.dim(), num_nonzeros_h});
//Scalars are expected to produce output of size (1,0), so we can't write to it
if (self.dim() > 0) {
cub::CountingInputIterator<int64_t> counting_itr(0);
temp_storage_bytes = 0;
cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,
out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr,
out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
if (num_nonzeros_h > 0 && self.dim() > 1){
int64_t div = 1;
auto thrust_allocator = THCThrustAllocator(globalContext().lazyInitCUDA());
for (int dim = self.dim()-1; dim >= 0; dim--){
int64_t dim_size = self.sizes()[dim];
thrust::transform(
thrust::cuda::par(thrust_allocator).on(stream),
thrust::device_ptr<int64_t>(out_temp.data_ptr<int64_t>()),
thrust::device_ptr<int64_t>(out_temp.data_ptr<int64_t>()) + num_nonzeros_h,
thrust::device_ptr<int64_t>(out_temp.data_ptr<int64_t>()) + num_nonzeros_h * dim,
[=] C10_HOST_DEVICE (const int64_t val) {return (val/div) % dim_size;}
);
div *= dim_size;
}
}
}
if (need_to_copy) {
out.copy_(out_temp.t());
} else {
//transpose out so it is correct size
Tensor out_ = out_temp.t();
out.set_(out_);
}
}

Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self){
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
file a support request");
TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ",
out.device(), " and self on ", self.device());
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
self.scalar_type(), "nonzero_cuda",
[&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
return out;
}

Tensor nonzero_cuda(const Tensor& self){
Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
return nonzero_out_cuda(out, self);
}


} // native
} // at
118 changes: 118 additions & 0 deletions aten/src/ATen/native/cuda/Nonzero.cu
@@ -0,0 +1,118 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDACachingAllocator.h>
#include <ATen/cuda/detail/KernelUtils.h>
#include <ATen/cuda/detail/OffsetCalculator.cuh> //for MAX_DIMS
#include <ATen/cuda/CubUtils.cuh>


namespace at {
namespace native {

namespace{
template<typename T>
struct NonZeroOp
{
__host__ __device__ __forceinline__ bool operator()(const T& a) const {
return (a!=T(0));
}
};

//TODO: actually support int64_t index_t
template<typename index_t>
struct TensorDims {
index_t sizes[MAX_DIMS];
};

template<typename index_t>
__global__ void write_indices(int64_t * inp, TensorDims<index_t> dims, int ndim, index_t n){
CUDA_KERNEL_LOOP(index, n) { // this assumes an int (not int64_t) index
index_t div = 1;
int64_t idx_flat = inp[index];
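// inp initially holds the flat index of each nonzero element in its first n entries;
// decompose it into per-dimension coordinates, writing coordinate `dim` of element
// `index` to inp[index + dim*n] (row `dim` of the {ndim, n} buffer transposed later).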
for (int dim = ndim-1; dim >= 0; dim--){
auto dim_size = dims.sizes[dim];
inp[index + dim*n] = (idx_flat/div) % dim_size;
div *= dim_size;
}
}
}


} //anonymous namespace

template<typename scalar_t>
void nonzero_cuda_out_impl(const Tensor& self, Tensor& out){
Tensor self_ = self.contiguous();
int N = self_.numel();
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
// compute number of nonzero elements
size_t temp_storage_bytes=0;
auto& allocator = *c10::cuda::CUDACachingAllocator::get();
auto num_nonzeros = allocator.allocate(sizeof(int));
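// itr lazily maps each input element to (element != 0); summing over it with
// cub::DeviceReduce::Sum counts the nonzeros without materializing a boolean mask.
// The first Sum call (null temp storage) only computes the workspace size.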
cub::TransformInputIterator<bool, NonZeroOp<scalar_t>, scalar_t*> itr(self_.data_ptr<scalar_t>(), NonZeroOp<scalar_t>());
cub::DeviceReduce::Sum(nullptr, temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
auto temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceReduce::Sum(temp_storage.get(), temp_storage_bytes, itr, (int*)num_nonzeros.get(), N, stream);
int num_nonzeros_h;
C10_CUDA_CHECK(cudaMemcpyAsync(&num_nonzeros_h, num_nonzeros.get(), sizeof(int), cudaMemcpyDeviceToHost, stream));
//need to synchronize to make sure data is available on the host
C10_CUDA_CHECK(cudaStreamSynchronize(stream));
//expected output size is num_nonzeros x ndim
//we are producing output with size {num_nonzeros, ndim} and strides {num_nonzeros, 1} (that is, transposed ndim x num_nonzeros output)
//we are able to directly use passed output with this size and strides, and we can also (per contract)
//resize passed output with incorrect sizes anyway we want.
//However, out with correct sizes and incorrect strides will have to be copied to from the intermediate we've produced.
bool need_to_copy = out.dim() == 2 && out.sizes()[0] == num_nonzeros_h && out.sizes()[1] == self.dim() && !out.t().is_contiguous();
at::Tensor out_temp = need_to_copy ?
at::native::empty_cuda({self.dim(), num_nonzeros_h}, optTypeMetaToScalarType(out.options().dtype_opt()),
out.options().layout_opt(), out.options().device_opt(), out.options().pinned_memory_opt()) :
out.resize_({self.dim(), num_nonzeros_h});
//Scalars are expected to produce output of size (1,0), so we can't write to it
if (self.dim() > 0) {
cub::CountingInputIterator<int64_t> counting_itr(0);
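// DeviceSelect::Flagged streams flat indices 0..N-1 from the counting iterator and
// keeps those whose flag from itr is true, writing them to out_temp and the final
// count to num_nonzeros; again the first call only sizes the temp workspace.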
temp_storage_bytes = 0;
cub::DeviceSelect::Flagged(nullptr, temp_storage_bytes, counting_itr, itr,
out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
temp_storage = allocator.allocate(temp_storage_bytes);
cub::DeviceSelect::Flagged(temp_storage.get(), temp_storage_bytes, counting_itr, itr,
out_temp.data_ptr<int64_t>(), (int*)num_nonzeros.get(), N, stream);
if (num_nonzeros_h > 0 && self.dim() > 1){
TensorDims<int> dims;
for (int i=0; i<self.dim(); i++){
dims.sizes[i] = self.sizes()[i];
}
const int nthreads = 256;
const int nblocks = (num_nonzeros_h + nthreads -1)/nthreads;
write_indices<<<nblocks, nthreads, 0, stream>>>(out_temp.data_ptr<int64_t>(),
dims, self.dim(), num_nonzeros_h);
C10_CUDA_KERNEL_LAUNCH_CHECK();
}
}
if (need_to_copy) {
out.copy_(out_temp.t());
} else {
//transpose out so it is correct size
Tensor out_ = out_temp.t();
out.set_(out_);
}
}

Tensor& nonzero_out_cuda(Tensor& out, const Tensor& self){
TORCH_CHECK(self.numel() < std::numeric_limits<int>::max(), "nonzero is not supported for tensors with more than INT_MAX elements, \
file a support request");
TORCH_CHECK(out.dtype() == at::kLong, "Expected object of scalar type ", at::kLong, " as out, but got ", out.dtype());
TORCH_CHECK(self.device() == out.device(), "expected self and out to be on the same device, but got out on ",
out.device(), " and self on ", self.device());
TORCH_CHECK(self.dim() <= MAX_DIMS, "nonzero is not supported for tensor with more than ", MAX_DIMS, " dimensions");
AT_DISPATCH_ALL_TYPES_AND3(at::ScalarType::Bool, at::ScalarType::BFloat16, at::ScalarType::Half,
self.scalar_type(), "nonzero_cuda",
[&] {nonzero_cuda_out_impl<scalar_t>(self, out);});
return out;
}

Tensor nonzero_cuda(const Tensor& self){
Tensor out = at::native::empty_cuda({0}, kLong, self.options().layout_opt(), self.options().device_opt(), self.options().pinned_memory_opt());
return nonzero_out_cuda(out, self);
}
} //namespace::native
} //namespace::at
83 changes: 83 additions & 0 deletions aten/src/ATen/native/cuda/Randperm.cu
@@ -0,0 +1,83 @@
#include <ATen/ATen.h>
#include <ATen/cuda/CUDAApplyUtils.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <ATen/native/TensorFactories.h>
#include <ATen/cuda/CubUtils.cuh>

#include <limits>

namespace at {
namespace native {

Tensor& randperm_out_cuda(Tensor& result, int64_t n, c10::optional<Generator> generator) {
TORCH_CHECK(n >= 0, "n must be non-negative, got", n);
TORCH_CHECK(!generator.has_value() || (generator.has_value() && result.device() == generator->device()), "Expected a '", result.device(), "' generator device but found '", generator->device(), "'");
check_supported_max_int_with_precision(n, result);

result.resize_({n});

if (n < 30000) { // For small inputs, we offload it to CPU instead.
auto result_cpu = at::empty({n}, result.options().device(kCPU));
randperm_out(result_cpu, n, generator);
return result.copy_(result_cpu);
}

#if 0
// This if condition should never be true because if n >= 30000 and the tensor has a Half type,
// check_supported_max_int_with_precision should have reported an error. This snippet is commented out but left here
// for the sake of clarity, because Half in thrust is spotty, and we do not want future change unaware of this.
if (result.scalar_type() == at::ScalarType::Half) { // Half in thrust is spotty. Avoid.
auto result_float = at::empty({n}, initialTensorOptions().device(Device(DeviceType::CUDA)));
return result.copy_(randperm_out_cuda(result_float, n, generator));
}
#endif

// Generate random values for the keys array
AT_DISPATCH_ALL_TYPES(
result.scalar_type(), "randperm_out_cuda", [&] {
TORCH_CHECK(n <= std::numeric_limits<int>::max(),
"randperm of tensors larger than INT_MAX is not supported yet in pytorch");

auto keys = at::empty(result.sizes(), result.options()).random_(generator);
auto range = at::arange(n, result.options());
auto keys_tmp = at::empty_like(keys);

// shuffled_data points to the underlying data of the output tensor if the tensor is contiguous; otherwise it
// points to a new tensor.
Tensor shuffled;
scalar_t *shuffled_data;
if (result.is_contiguous()) {
shuffled_data = result.data_ptr<scalar_t>();
} else {
shuffled = at::empty(n, result.options());
shuffled_data = shuffled.data_ptr<scalar_t>();
}

// Use the sorted order of keys to rearrange the result array
size_t temp_storage_bytes = 0;
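// The first SortPairs call with a null workspace pointer only computes
// temp_storage_bytes; the second call below performs the actual key/value
// radix sort on the current stream.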

cub::DeviceRadixSort::SortPairs(
nullptr, temp_storage_bytes,
keys.data_ptr<scalar_t>(), keys_tmp.data_ptr<scalar_t>(),
range.data_ptr<scalar_t>(), shuffled_data, n,
0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream());
auto& allocator = *::c10::cuda::CUDACachingAllocator::get();
auto dataPtr = allocator.allocate(temp_storage_bytes);
cub::DeviceRadixSort::SortPairs(
dataPtr.get(), temp_storage_bytes,
keys.data_ptr<scalar_t>(), keys_tmp.data_ptr<scalar_t>(),
range.data_ptr<scalar_t>(), shuffled_data, n,
0, sizeof(scalar_t) * 8, at::cuda::getCurrentCUDAStream());

if (!result.is_contiguous()) {
result.copy_(shuffled);
}
}
);

return result;
}



}} // namespace at::native
2 changes: 1 addition & 1 deletion aten/src/ATen/native/cuda/ScanKernels.cu
@@ -4,7 +4,7 @@
#include <THC/THCNumerics.cuh>
#include <ATen/cuda/CUDAContext.h>
#include <THC/THCGeneral.h>
#include <cub/device/device_scan.cuh>
#include <ATen/cuda/CubUtils.cuh>


namespace at { namespace native {
