Skip to content

Commit

Permalink
Update base for Update on "[NVFuser] Upstream push 0907"
Browse files Browse the repository at this point in the history
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

- codegen improvement:
i. improved view support on pointwise and transpose scheduler
ii. grouped grid welford added for better outer-norm grid persistence in normalization

- misc:
i. new composite ops added: variance_mean , arange, 
ii. fixes misaligned address for transpose scheduler
iii. refactor on separation of compilation API from execution API to prepare us for async compilation
iv. double type support on expression evaluator
v. PYTORCH_NVFUSER_DUMP refactor to save PTX and CUBIN

Commits that's in this PR from the devel branch:
```
89330aa Tensor factories must set the output shape as its input (#1939)
b2fd01e arange support (#1933)
56c00fd Double support on all expression evaluators (#1937)
371f282 Improve trivial reduction merge support (#1931)
1d0c267 Test `rand` in a fusion with zero tensor input (#1932)
0dab160 Fix softmax bwd sizes. (#1890)
ef98f36 Fix a bug (#1936)
63132a0 Propagate permissive mapping information into indexing pass (#1929)
b4ac2c8 Map IterationDomains through view operations. (#1919)
c0a187a do not use deprecated functions (#1935)
88de85e Upstream cherry pick fixes 0811 (#1934)
b247dcf Separate kernel compilation API from kernel execution API (#1914)
b34e3b9 Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924)
14a53e6 Nullary RNGOp (#1892)
3c3c89e Misc fixes/tuning for transpose scheduler (#1912)
20cf109 Grouped grid welford (#1921)
6cf7eb0 Transpose scheduler small dim sizes better support (#1910)
9341ea9 Disabled ViewPersistentShmoo sizes that results in NAN (#1922)
057237f Fix CUDA driver error: misaligned address for transpose scheduler  (#1918)
3fb3d80 Add variance_mean function using Welford (#1907)
98febf6 Remove DisableOption::UnrollWithRng (#1913)
ee8ef33 Minor fix for the debug interface of using PTX directly (#1917)
6e8f953 Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916)
5eefa9a dopt is only available since nvrtc 11.7 (#1915)
2ec8fc7 Kill computeAtBetween (#1911)
d0d106a Improve view support on pointwise and transpose scheduler (#1906)
e71e1ec Fix name clash of RNG with shared memory (#1904)
3381793 Fix mutator and sameAs for expanded IterDomain (#1902)
```

RUN_TORCHBENCH: nvfuser

Differential Revision: [D39324552](https://our.internmc.facebook.com/intern/diff/D39324552)

[ghstack-poisoned]
  • Loading branch information
jjsjann123 committed Sep 21, 2022
2 parents 9875376 + 2f4a517 commit 111fe61
Show file tree
Hide file tree
Showing 160 changed files with 4,013 additions and 1,248 deletions.
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/vision.txt
@@ -1 +1 @@
a4f53308b2d0f1aa9191686e326f45c26053f686
841b9a19a5ae6fb92b517438061b9846153cc0c6
2 changes: 1 addition & 1 deletion .github/ci_commit_pins/xla.txt
@@ -1 +1 @@
4dec902617aea14ca4013e402eea56e92701cac9
307af4313d2b0b0236618ef837959a41068cc272
6 changes: 4 additions & 2 deletions aten/src/ATen/DLConvertor.cpp
Expand Up @@ -252,8 +252,10 @@ Tensor fromDLPack(const DLManagedTensor* src) {
Device device = getATenDevice(src->dl_tensor.device);
ScalarType stype = toScalarType(src->dl_tensor.dtype);
auto deleter = [src](void* self) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
src->deleter(const_cast<DLManagedTensor*>(src));
if (src->deleter) {
// NOLINTNEXTLINE(cppcoreguidelines-pro-type-const-cast)
src->deleter(const_cast<DLManagedTensor*>(src));
}
};
if (!src->dl_tensor.strides) {
return at::from_blob(src->dl_tensor.data,
Expand Down
42 changes: 21 additions & 21 deletions aten/src/ATen/SparseCsrTensorImpl.cpp
Expand Up @@ -8,22 +8,10 @@
#include <ATen/native/Resize.h>

namespace at {
namespace {
DeviceType SparseCsrTensorSetToDeviceType(DispatchKeySet key_set) {
if (key_set.has(DispatchKey::SparseCsrCPU)) {
return kCPU;
} else if (key_set.has(DispatchKey::SparseCsrCUDA)) {
return kCUDA;
} else {
TORCH_CHECK(false,
"Cannot construct SparseCsrTensor with non-sparse tensor type ID ",
key_set);
}
}
} // namespace

SparseCsrTensorImpl::SparseCsrTensorImpl(
at::DispatchKeySet key_set,
at::Device device,
at::Layout layout,
const caffe2::TypeMeta data_type)
: SparseCsrTensorImpl(
Expand All @@ -32,19 +20,19 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.device(device)
.dtype(ScalarType::Int)) // crow_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.device(device)
.dtype(ScalarType::Int)) // col_indices
,
at::empty(
{0},
at::initialTensorOptions()
.device(SparseCsrTensorSetToDeviceType(key_set))
.device(device)
.dtype(data_type)) // values
,
layout
Expand All @@ -66,15 +54,24 @@ SparseCsrTensorImpl::SparseCsrTensorImpl(
TORCH_WARN_ONCE("Sparse ", at::sparse_csr::layoutToString(layout_, /*upper=*/true), " tensor support is in beta state. "
"If you miss a functionality in the sparse tensor support, please submit a feature request "
"to https://github.com/pytorch/pytorch/issues.");

TORCH_INTERNAL_ASSERT(((key_set.has(DispatchKey::SparseCsrCPU) && device().type() == kCPU)
|| (key_set.has(DispatchKey::SparseCsrCUDA) && device().type() == kCUDA)),
"Inconsistent key_set (=", key_set, ") and device (=", device(), ")");

set_storage_access_should_throw();
is_non_overlapping_and_dense_ = false;
set_custom_sizes_strides(SizesStridesPolicy::CustomStrides);
// TODO: If this check ever shows up as a bottleneck, which is unlikely given that
// comparing devices only involves comparing the type and index (two integers), we
// can move this to a DEBUG only assert. Until then this confirms and maintains a
// crucial invariance.
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
TORCH_INTERNAL_ASSERT(values_.device() == device(),
"Values and compressed sparse tensor instance need to have the same device.");
}

const char* SparseCsrTensorImpl::tensorimpl_type_name() const {
Expand Down Expand Up @@ -183,7 +180,6 @@ void SparseCsrTensorImpl::set_member_tensors(
") must match dtype of sparse tensor (",
typeMetaToScalarType(dtype()),
")");

crow_indices_ = crow_indices;
col_indices_ = col_indices;
values_ = values;
Expand All @@ -194,8 +190,12 @@ void SparseCsrTensorImpl::set_member_tensors(
// comparing devices only involves comparing the type and index (two integers), we
// can move this to a DEBUG only assert. Until then this confirms and maintains a
// crucial invariance.
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and crow_indices need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and col_indices need to be on the same device.");
TORCH_CHECK(values_.device() == crow_indices_.device(), "Values and ",
at::sparse_csr::compressedIndicesName(layout_), " need to be on the same device.");
TORCH_CHECK(values_.device() == col_indices_.device(), "Values and ",
at::sparse_csr::plainIndicesName(layout_), " need to be on the same device.");
TORCH_CHECK(values_.device() == device(),
"Values and compressed tensor instance need to be on the same device.");
}

IntArrayRef SparseCsrTensorImpl::strides_custom() const {
Expand Down
6 changes: 3 additions & 3 deletions aten/src/ATen/SparseCsrTensorImpl.h
Expand Up @@ -3,7 +3,6 @@
#include <ATen/Tensor.h>
#include <c10/core/TensorImpl.h>
#include <c10/util/Exception.h>

namespace at {

// Struct implementing a sparse CSR tensor. It uses three 1-D tensors for
Expand Down Expand Up @@ -33,6 +32,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
public:
explicit SparseCsrTensorImpl(
at::DispatchKeySet,
at::Device device,
Layout layout,
const caffe2::TypeMeta);

Expand Down Expand Up @@ -110,7 +110,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
const c10::VariableVersion& version_counter,
bool allow_tensor_metadata_change) const override {
auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
key_set(), layout_impl(), dtype());
key_set(), device(), layout_impl(), dtype());
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
Expand All @@ -130,7 +130,7 @@ struct TORCH_API SparseCsrTensorImpl : public TensorImpl {
c10::VariableVersion&& version_counter,
bool allow_tensor_metadata_change) const override {
auto impl = c10::make_intrusive<SparseCsrTensorImpl>(
key_set(), layout_impl(), dtype());
key_set(), device(), layout_impl(), dtype());
copy_tensor_metadata(
/*src_impl=*/this,
/*dest_impl=*/impl.get(),
Expand Down
2 changes: 1 addition & 1 deletion aten/src/ATen/TensorIterator.cpp
Expand Up @@ -1241,7 +1241,7 @@ void TensorIteratorBase::compute_shape(const TensorIteratorConfig& config) {

void TensorIteratorBase::compute_strides(const TensorIteratorConfig& config) {
for (auto& op : operands_) {
if (op.tensor_base().defined()) {
if (op.tensor_base().defined() && !op.will_resize) {
IntArrayRef original_shape = config.static_shape_ ? shape_ : op.tensor_base().sizes();
auto original_stride = op.tensor_base().strides();
auto element_size_in_bytes = op.tensor_base().element_size();
Expand Down
2 changes: 2 additions & 0 deletions aten/src/ATen/cuda/CUDAContext.h
Expand Up @@ -72,6 +72,8 @@ TORCH_CUDA_CPP_API Allocator* getCUDADeviceAllocator();
TORCH_CUDA_CPP_API cusparseHandle_t getCurrentCUDASparseHandle();
TORCH_CUDA_CPP_API cublasHandle_t getCurrentCUDABlasHandle();

TORCH_CUDA_CPP_API void clearCublasWorkspaces();

#ifdef CUDART_VERSION
TORCH_CUDA_CPP_API cusolverDnHandle_t getCurrentCUDASolverDnHandle();
#endif
Expand Down
55 changes: 55 additions & 0 deletions aten/src/ATen/cuda/CublasHandlePool.cpp
@@ -1,7 +1,18 @@
#include <ATen/cuda/CUDAContext.h>
#include <ATen/cuda/detail/DeviceThreadHandles.h>

#include <c10/cuda/CUDACachingAllocator.h>

#include <regex>

namespace at { namespace cuda {

static std::map<std::tuple<void *, void *>, at::DataPtr> cublas_handle_stream_to_workspace;

void clearCublasWorkspaces() {
cublas_handle_stream_to_workspace.clear();
}

namespace {

void createCublasHandle(cublasHandle_t *handle) {
Expand All @@ -25,6 +36,40 @@ using CuBlasPoolType = DeviceThreadHandlePool<cublasHandle_t, createCublasHandle

} // namespace

size_t parseChosenWorkspaceSize() {
const char * val = getenv("CUBLAS_WORKSPACE_CONFIG");
if (val) {
size_t total_size = 0;
const std::string config(val);
std::regex exp(":([0-9]+):([0-9]+)");
std::sregex_iterator next(config.begin(), config.end(), exp);
std::sregex_iterator end;
if (next == end) {
TORCH_WARN("Could not parse CUBLAS_WORKSPACE_CONFIG, using default workspace size of 4096.");
return 4096 * 1024;
}
while (next != end) {
std::smatch match = *next;
TORCH_CHECK(match.size() == 3, "Expected CUBLAS_WORKSPACE_SPACE_CONFIG match of size 3 (Format :SIZE:COUNT)");
size_t curr_size = (size_t) std::stoi(match.str(1));
total_size += curr_size * 1024;
next++;
}
return total_size;
} else /* :4096:8 */ {
return 4096 * 1024;
}
}

size_t getChosenWorkspaceSize() {
static size_t pool_size = parseChosenWorkspaceSize();
return pool_size;
}

at::DataPtr getNewWorkspace() {
return c10::cuda::CUDACachingAllocator::get()->allocate(getChosenWorkspaceSize());
}

cublasHandle_t getCurrentCUDABlasHandle() {
int device;
AT_CUDA_CHECK(cudaGetDevice(&device));
Expand All @@ -47,6 +92,16 @@ cublasHandle_t getCurrentCUDABlasHandle() {
auto handle = myPoolWindow->reserve(device);
auto stream = c10::cuda::getCurrentCUDAStream();
TORCH_CUDABLAS_CHECK(cublasSetStream(handle, stream));
#if !defined(USE_ROCM) && CUDA_VERSION >= 11000
// cublasSetWorkspace not available on CUDA 10.2
cudaStream_t _stream = stream;
auto key = std::make_tuple(static_cast<void *>(handle), static_cast<void *>(_stream));
if (cublas_handle_stream_to_workspace.find(key) == cublas_handle_stream_to_workspace.end()) {
auto workspace_ptr = getNewWorkspace();
cublas_handle_stream_to_workspace[key] = std::move(workspace_ptr);
}
TORCH_CUDABLAS_CHECK(cublasSetWorkspace(handle, cublas_handle_stream_to_workspace[key].get(), getChosenWorkspaceSize()));
#endif
#if defined(CUDA_VERSION) && CUDA_VERSION >= 11000
// On CUDA >= 11, and architecture >= Ampere, cuBLAS can use TF32 to speedup
// FP32 data type calculations based on the value of the allow_tf32 flag.
Expand Down
10 changes: 8 additions & 2 deletions aten/src/ATen/native/BatchLinearAlgebra.cpp
Expand Up @@ -2567,6 +2567,10 @@ Tensor& ormqr_out(const Tensor& input, const Tensor& tau, const Tensor& other, b
left_size_condition,
"] must be equal to input.shape[-2]");

TORCH_CHECK(
tau.size(-1) <= input.size(-1),
"torch.ormqr: tau.shape[-1] must be less than or equal to input.shape[-1]");

TORCH_CHECK(
input.dim() - tau.dim() == 1,
"torch.ormqr: ",
Expand Down Expand Up @@ -3512,14 +3516,16 @@ static void linalg_lstsq_out_info(
at::sum_out(residuals, raw_residuals, /*dim=*/-2, /*keepdim=*/false, /*dtype*/real_dtype);
}
}
solution = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
auto solution_view = solution.narrow(/*dim=*/-2, /*start=*/0, /*length*/n);
// manually restride original
solution.set_(solution.storage(), solution_view.storage_offset(), solution_view.sizes(), solution_view.strides());
if (m == 0) {
solution.zero_();
}

// for 1-dimensional 'other', we need to squeeze the solution after "apply_lstsq"
if (vector_case) {
solution = solution.squeeze_(-1);
solution.squeeze_(-1);
}
}

Expand Down
32 changes: 32 additions & 0 deletions aten/src/ATen/native/ComparisonUtils.cpp
@@ -0,0 +1,32 @@
#include <ATen/core/TensorBase.h>
#include <algorithm>
#include <vector>
#include <ATen/core/TensorBody.h>
#include <c10/util/OptionalArrayRef.h>

namespace at {

class Tensor;

namespace native {

template<typename O, typename C>
void _assert_match(const O& original, const C& compared, const std::string& name) {
if (compared) {
bool equal = (original == compared.value());
if (!equal) {
std::stringstream msg;
msg << "Tensor " << name << " mismatch!";
AT_ASSERT(equal, msg.str());
}
}
}

void _assert_tensor_metadata(at::Tensor const& tensor, at::OptionalIntArrayRef sizes, at::OptionalIntArrayRef strides, c10::optional<c10::ScalarType> dtype) {
_assert_match(tensor.sizes(), sizes, "sizes");
_assert_match(tensor.strides(), strides, "strides");
_assert_match(tensor.dtype(), dtype, "dtype");
}

}
} // namespace at::native
1 change: 1 addition & 0 deletions aten/src/ATen/native/Convolution.cpp
Expand Up @@ -1309,6 +1309,7 @@ at::Tensor _convolution(
int64_t dim = k - 2;

TORCH_CHECK(dim > 0, "weight should have at least three dimensions");
TORCH_CHECK(groups_ > 0, "non-positive groups is not supported");

ConvParams params;
params.stride = expand_param_if_needed(stride_, "stride", dim);
Expand Down
7 changes: 4 additions & 3 deletions aten/src/ATen/native/LinearAlgebra.cpp
Expand Up @@ -48,14 +48,16 @@ namespace detail {
namespace meta {

#define ADDMM_META() \
TORCH_CHECK(self.scalar_type() == mat2.scalar_type(), "self and mat2 must have the same dtype"); \
TORCH_CHECK(mat1.scalar_type() == mat2.scalar_type(), "mat1 and mat2 must have the same dtype"); \
TORCH_CHECK(mat1.dim() == 2, "mat1 must be a matrix, got ", mat1.dim(), "-D tensor"); \
TORCH_CHECK(mat2.dim() == 2, "mat2 must be a matrix, got ", mat2.dim(), "-D tensor"); \
TORCH_CHECK( \
mat1.sizes()[1] == mat2.sizes()[0], "mat1 and mat2 shapes cannot be multiplied (", \
mat1.sizes()[0], "x", mat1.sizes()[1], " and ", mat2.sizes()[0], "x", mat2.sizes()[1], ")"); \
\
auto names = at::namedinference::propagate_names_for_addmm(mat1, mat2, self); \
set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, self.options(), names);
set_output_raw_strided(0, {mat1.sizes()[0], mat2.sizes()[1]}, {}, mat1.options(), names);

TORCH_META_FUNC(addmm)(const Tensor& self, const Tensor& mat1, const Tensor& mat2, const Scalar& beta, const Scalar& alpha) {
ADDMM_META();
Expand Down Expand Up @@ -2285,8 +2287,7 @@ void compute_T18_scale_square(
for (const auto i : c10::irange(mexp_scaled.size(0))) {
auto s_val = s_cpu.select(0, i).template item<int64_t>();
auto mexp = mexp_scaled.select(0, i);
for (const auto p : c10::irange(s_val)) {
(void)p; //Suppress unused variable warning
for (const auto p C10_UNUSED : c10::irange(s_val)) {
mexp = at::matmul(mexp, mexp);
}
mexp_out.select(0, i).copy_(mexp);
Expand Down
3 changes: 1 addition & 2 deletions aten/src/ATen/native/LinearAlgebraUtils.h
Expand Up @@ -241,8 +241,7 @@ void batch_iterator_with_broadcasting(const Tensor& a, const Tensor& b, const fu
auto* b_batch_idx_ptr = data[0];
auto* a_batch_idx_ptr = data[1];

for (const auto elem : c10::irange(nelems)) {
(void)elem; //Suppress unused variable warning
for (const auto elem C10_UNUSED : c10::irange(nelems)) {
auto b_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(b_batch_idx_ptr);
auto a_curr_linear_batch_idx = *reinterpret_cast<int64_t*>(a_batch_idx_ptr);

Expand Down

0 comments on commit 111fe61

Please sign in to comment.