Skip to content

Commit

Permalink
[ROCm] Add extra cuda_to_hip_mappings.py (pytorch#125108)
Browse files Browse the repository at this point in the history
Adding extra mappings discovered when hipifying the backward CUDA kernel of the Mamba model (https://github.com/state-spaces/mamba/).

Co-authored-by: Jeff Daily <jeff.daily@amd.com>
Pull Request resolved: pytorch#125108
Approved by: https://github.com/Skylion007, https://github.com/jeffdaily
  • Loading branch information
amoskvic authored and pytorchmergebot committed May 1, 2024
1 parent c8d2a55 commit bf6acf9
Showing 1 changed file with 13 additions and 0 deletions.
13 changes: 13 additions & 0 deletions torch/utils/hipify/cuda_to_hip_mappings.py
Expand Up @@ -643,7 +643,11 @@
("thrust/system/cuda", ("thrust/system/hip", CONV_INCLUDE, API_BLAS)),
("cub/util_allocator.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/block/block_reduce.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/block/block_raking_layout.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/cub.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/config.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/util_ptx.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/util_type.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/device/device_run_length_encode.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/block/block_load.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
("cub/block/block_store.cuh", ("hipcub/hipcub.hpp", CONV_INCLUDE, API_BLAS)),
Expand Down Expand Up @@ -7954,6 +7958,9 @@
("cub::BlockScan", ("hipcub::BlockScan", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::BlockLoad", ("hipcub::BlockLoad", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::BlockStore", ("hipcub::BlockStore", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::BlockRakingLayout", ("hipcub::BlockRakingLayout", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::Uninitialized", ("hipcub::Uninitialized", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::RowMajorTid", ("hipcub::RowMajorTid", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::CachingDeviceAllocator", ("hipcub::CachingDeviceAllocator", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::CountingInputIterator", ("hipcub::CountingInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::DeviceRadixSort", ("hipcub::DeviceRadixSort", CONV_SPECIAL_FUNC, API_RUNTIME)),
Expand All @@ -7967,9 +7974,15 @@
("cub::Max", ("hipcub::Max", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::Min", ("hipcub::Min", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::Sum", ("hipcub::Sum", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::Log2", ("hipcub::Log2", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::LaneId", ("hipcub::LaneId", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::WarpMask", ("hipcub::WarpMask", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::ShuffleIndex", ("hipcub::ShuffleIndex", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::ShuffleDown", ("hipcub::ShuffleDown", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::ArgIndexInputIterator", ("hipcub::ArgIndexInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::TransformInputIterator", ("hipcub::TransformInputIterator", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::WarpReduce", ("hipcub::WarpReduce", CONV_SPECIAL_FUNC, API_RUNTIME)),
("cub::CTA_SYNC", ("hipcub::CTA_SYNC", CONV_SPECIAL_FUNC, API_RUNTIME)),
("nvtxMark", ("roctxMark", CONV_OTHER, API_ROCTX)),
("nvtxMarkA", ("roctxMarkA", CONV_OTHER, API_ROCTX)),
("nvtxRangePushA", ("roctxRangePushA", CONV_OTHER, API_ROCTX)),
Expand Down

0 comments on commit bf6acf9

Please sign in to comment.