Skip to content

Commit

Permalink
[NVFuser] Upstream push 0907
Browse files Browse the repository at this point in the history
Syncing nvfuser devel branch to upstream master. https://github.com/csarofeen/pytorch/

Codegen changes include:

- codegen improvement:
i. improved view support on pointwise and transpose scheduler
ii. grouped grid welford added for better outer-norm grid persistence in normalization

- misc:
i. new composite ops added: variance_mean , arange,
ii. fixes misaligned address for transpose scheduler
iii. refactor on separation of compilation API from execution API to prepare us for async compilation
iv. double type support on expression evaluator
v. PYTORCH_NVFUSER_DUMP refactor to save PTX and CUBIN

Commits that's in this PR from the devel branch:
```
89330aa Tensor factories must set the output shape as its input (#1939)
b2fd01e arange support (#1933)
56c00fd Double support on all expression evaluators (#1937)
371f282 Improve trivial reduction merge support (#1931)
1d0c267 Test `rand` in a fusion with zero tensor input (#1932)
0dab160 Fix softmax bwd sizes. (#1890)
ef98f36 Fix a bug (#1936)
63132a0 Propagate permissive mapping information into indexing pass (#1929)
b4ac2c8 Map IterationDomains through view operations. (#1919)
c0a187a do not use deprecated functions (#1935)
88de85e Upstream cherry pick fixes 0811 (#1934)
b247dcf Separate kernel compilation API from kernel execution API (#1914)
b34e3b9 Fix `ir_utils::hasBlockSync` + misc fixes in transpose scheduler (#1924)
14a53e6 Nullary RNGOp (#1892)
3c3c89e Misc fixes/tuning for transpose scheduler (#1912)
20cf109 Grouped grid welford (#1921)
6cf7eb0 Transpose scheduler small dim sizes better support (#1910)
9341ea9 Disabled ViewPersistentShmoo sizes that results in NAN (#1922)
057237f Fix CUDA driver error: misaligned address for transpose scheduler  (#1918)
3fb3d80 Add variance_mean function using Welford (#1907)
98febf6 Remove DisableOption::UnrollWithRng (#1913)
ee8ef33 Minor fix for the debug interface of using PTX directly (#1917)
6e8f953 Add PYTORCH_NVFUSER_DUMP options to save PTX and CUBIN (#1916)
5eefa9a dopt is only available since nvrtc 11.7 (#1915)
2ec8fc7 Kill computeAtBetween (#1911)
d0d106a Improve view support on pointwise and transpose scheduler (#1906)
e71e1ec Fix name clash of RNG with shared memory (#1904)
3381793 Fix mutator and sameAs for expanded IterDomain (#1902)
```

RUN_TORCHBENCH: nvfuser

ghstack-source-id: d0d88cff0c908b2f0ebf6defaab10bc3e7b437b5
Pull Request resolved: #84626
  • Loading branch information
jjsjann123 committed Sep 9, 2022
1 parent 8bd9fe3 commit c8a6173
Show file tree
Hide file tree
Showing 111 changed files with 8,278 additions and 3,257 deletions.
14 changes: 10 additions & 4 deletions benchmarks/cpp/nvfuser/heuristic_lookup.cpp
Expand Up @@ -99,12 +99,15 @@ static void LayerNormBackward_HeuristicLookup(

auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);

TORCH_INTERNAL_ASSERT(
runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
runtime->getMaybeHeuristicsFor(args).has_value());

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
runtime->getMaybeHeuristicsFor(aten_inputs);
runtime->getMaybeHeuristicsFor(args);
}
}

Expand Down Expand Up @@ -152,12 +155,15 @@ static void LayerNormForward_HeuristicLookup(

auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);

TORCH_INTERNAL_ASSERT(
runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
runtime->getMaybeHeuristicsFor(args).has_value());

for (auto _ : benchmark_state) {
// Setup (not included in the measurement)
runtime->getMaybeHeuristicsFor(aten_inputs);
runtime->getMaybeHeuristicsFor(args);
}
}

Expand Down
9 changes: 7 additions & 2 deletions benchmarks/cpp/nvfuser/shape_inference.cpp
Expand Up @@ -100,8 +100,11 @@ void LayerNormBackward_ShapeInference_Base(

auto runtime = getLayerBackwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);

TORCH_INTERNAL_ASSERT(
runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
runtime->getMaybeHeuristicsFor(args).has_value());

fec->profile(true);
fec->disableKernelLaunch();
Expand Down Expand Up @@ -172,8 +175,10 @@ void LayerNormForward_ShapeInferenceBase(
auto runtime = getLayerForwardNormRuntime(
std::move(fusion_ptr), fec, aten_inputs, shape, norm_shape);

KernelArgumentHolder args = KernelArgumentHolder::createKernelArgumentHolder(aten_inputs);

TORCH_INTERNAL_ASSERT(
runtime->getMaybeHeuristicsFor(aten_inputs).has_value());
runtime->getMaybeHeuristicsFor(args).has_value());

fec->profile(true);
fec->disableKernelLaunch();
Expand Down
32 changes: 16 additions & 16 deletions benchmarks/cpp/nvfuser/softmax_backward.cpp
Expand Up @@ -177,13 +177,13 @@ NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp32)

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp32)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp32)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -201,13 +201,13 @@ NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp16)

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp16)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Outer_fp16)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -225,13 +225,13 @@ NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp32)

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp32)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp32)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -249,13 +249,13 @@ NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp16)

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp16)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

NVFUSER_BENCHMARK_RUN(NvFuserScheduler_Softmax_BWD_Inner_fp16)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -275,13 +275,13 @@ BENCHMARK(Baseline_Softmax_BWD_Outer_fp32)

BENCHMARK(Baseline_Softmax_BWD_Outer_fp32)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

BENCHMARK(Baseline_Softmax_BWD_Outer_fp32)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -299,13 +299,13 @@ BENCHMARK(Baseline_Softmax_BWD_Outer_fp16)

BENCHMARK(Baseline_Softmax_BWD_Outer_fp16)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

BENCHMARK(Baseline_Softmax_BWD_Outer_fp16)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -323,13 +323,13 @@ BENCHMARK(Baseline_Softmax_BWD_Inner_fp32)

BENCHMARK(Baseline_Softmax_BWD_Inner_fp32)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

BENCHMARK(Baseline_Softmax_BWD_Inner_fp32)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand All @@ -347,13 +347,13 @@ BENCHMARK(Baseline_Softmax_BWD_Inner_fp16)

BENCHMARK(Baseline_Softmax_BWD_Inner_fp16)
// ->RangeMultiplier(2)
->Ranges({{32768, 32 * 1024 * 1024}, {2, 16}})
->Ranges({{32768, 16 * 1024 * 1024}, {2, 16}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

BENCHMARK(Baseline_Softmax_BWD_Inner_fp16)
// ->RangeMultiplier(2)
->Ranges({{2, 16}, {32768, 32 * 1024 * 1024}})
->Ranges({{2, 16}, {32768, 16 * 1024 * 1024}})
->Unit(benchmark::kMicrosecond)
->UseManualTime();

Expand Down
17 changes: 15 additions & 2 deletions benchmarks/cpp/nvfuser/utils.cpp
Expand Up @@ -6,7 +6,7 @@

using namespace torch::jit::fuser::cuda;

std::string toString(ReductionParams rparams) {
std::string toString(const ReductionParams& rparams) {
std::stringstream ss;
ss << (rparams.fastest_dim ? "Red On Fastest Dim // " : "Red On Slow Dim // ")
<< (rparams.persistent_kernel ? "Persistent Kernel // " : "")
Expand Down Expand Up @@ -65,7 +65,7 @@ std::string toString(ReductionParams rparams) {
return ss.str();
}

std::string toString(PointwiseParams params) {
std::string toString(const PointwiseParams& params) {
std::stringstream ss;
if (params.break_point) {
ss << "2D Schedule at " << params.break_point << "/";
Expand All @@ -89,6 +89,15 @@ std::string toString(PointwiseParams params) {
return ss.str();
}

std::string toString(const TransposeParams& params) {
std::stringstream ss;
ss << "Tile size: (" << params.tile_size1 << "," << params.tile_size2
<< ")/";
ss << "Vectorize size: (" << params.vectorize_factor1 << ","
<< params.vectorize_factor2 << ")";
return ss.str();
}

std::string toString(const std::shared_ptr<HeuristicParams>& params) {
auto rparams = std::dynamic_pointer_cast<ReductionParams>(params);
if (rparams) {
Expand All @@ -98,6 +107,10 @@ std::string toString(const std::shared_ptr<HeuristicParams>& params) {
if (pparams) {
return toString(*pparams);
}
auto tparams = std::dynamic_pointer_cast<TransposeParams>(params);
if (tparams) {
return toString(*tparams);
}
TORCH_INTERNAL_ASSERT(
false,
"Unknown heuristic parameter type. Did you just added a new heuristic parameter type but forget to update here?");
Expand Down
5 changes: 3 additions & 2 deletions benchmarks/cpp/nvfuser/utils.h
Expand Up @@ -36,8 +36,9 @@ TensorView* makeContigConcreteTensor(
std::vector<int64_t> shape,
DataType dtype = DataType::Float);

std::string toString(ReductionParams rparams);
std::string toString(PointwiseParams params);
std::string toString(const ReductionParams& rparams);
std::string toString(const PointwiseParams& params);
std::string toString(const TransposeParams& params);
std::string toString(const std::shared_ptr<HeuristicParams>& params);
std::string toString(LaunchParams lparams);

Expand Down
2 changes: 2 additions & 0 deletions build_variables.bzl
Expand Up @@ -29,6 +29,8 @@ libtorch_nvfuser_runtime_sources = [
"torch/csrc/jit/codegen/cuda/runtime/broadcast.cu",
"torch/csrc/jit/codegen/cuda/runtime/fp16_support.cu",
"torch/csrc/jit/codegen/cuda/runtime/fused_reduction.cu",
"torch/csrc/jit/codegen/cuda/runtime/fused_welford_helper.cu",
"torch/csrc/jit/codegen/cuda/runtime/fused_welford_impl.cu",
"torch/csrc/jit/codegen/cuda/runtime/grid_broadcast.cu",
"torch/csrc/jit/codegen/cuda/runtime/grid_reduction.cu",
"torch/csrc/jit/codegen/cuda/runtime/grid_sync.cu",
Expand Down
8 changes: 8 additions & 0 deletions c10/util/hash.h
Expand Up @@ -304,6 +304,14 @@ struct hash<std::tuple<Types...>> {
}
};

template <typename T1, typename T2>
struct hash<std::pair<T1, T2>> {
size_t operator()(const std::pair<T1, T2>& pair) const {
std::tuple<T1, T2> tuple = std::make_tuple(pair.first, pair.second);
return _hash_detail::simple_get_hash(tuple);
}
};

template <typename T>
struct hash<c10::ArrayRef<T>> {
size_t operator()(c10::ArrayRef<T> v) const {
Expand Down
2 changes: 2 additions & 0 deletions test/cpp/jit/CMakeLists.txt
Expand Up @@ -97,12 +97,14 @@ set(JIT_TEST_SRCS

if(USE_CUDA)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensor_factories.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_fused_reduction.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_shift.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_tensorcore.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_view.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_transpose.cpp)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_rng.cu)
list(APPEND JIT_TEST_SRCS ${TORCH_SRC_DIR}/csrc/jit/codegen/cuda/test/test_gpu_scheduler_utils.cpp)
endif()

add_executable(test_jit
Expand Down
8 changes: 6 additions & 2 deletions test/test_jit_cuda_fuser.py
Expand Up @@ -41,8 +41,12 @@
if RUN_NVFUSER and torch.version.cuda is not None:
CUDA_MAJOR, CUDA_MINOR = (int(x) for x in torch.version.cuda.split('.')[:2])

os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition'
os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,unroll_with_rng'
if 'PYTORCH_NVFUSER_ENABLE' not in os.environ:
os.environ['PYTORCH_NVFUSER_ENABLE'] = ""
os.environ['PYTORCH_NVFUSER_ENABLE'] = 'linear_decomposition,conv_decomposition,' + os.environ['PYTORCH_NVFUSER_ENABLE']
if 'PYTORCH_NVFUSER_DISABLE' not in os.environ:
os.environ['PYTORCH_NVFUSER_DISABLE'] = ""
os.environ['PYTORCH_NVFUSER_DISABLE'] = 'fallback,fma,' + os.environ['PYTORCH_NVFUSER_DISABLE']
os.environ['PYTORCH_NVFUSER_JIT_OPT_LEVEL'] = '0'
# TODO: enable complex when we fixes the extremal cases in OpInfo
# see issue https://github.com/csarofeen/pytorch/issues/1730"
Expand Down

0 comments on commit c8a6173

Please sign in to comment.