Skip to content

Commit

Permalink
[RelEng] Define BUILD_BUNDLE_PTXAS (#119750) (#119988)
Browse files Browse the repository at this point in the history
Co-authored-by: Nikita Shulga <nshulga@meta.com>
Fixes #119054
resolved: #119750
  • Loading branch information
atalman and malfet committed Feb 15, 2024
1 parent f00f0ab commit 6c8c5ad
Show file tree
Hide file tree
Showing 2 changed files with 26 additions and 0 deletions.
11 changes: 11 additions & 0 deletions CMakeLists.txt
Expand Up @@ -349,6 +349,8 @@ cmake_dependent_option(
"NOT INTERN_BUILD_MOBILE" OFF)
cmake_dependent_option(
BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF)
# BUILD_BUNDLE_PTXAS is offered (default OFF) only while USE_CUDA is enabled;
# otherwise cmake_dependent_option forces it to OFF.
cmake_dependent_option(
    BUILD_BUNDLE_PTXAS "Bundle ptxas into torch/bin folder" OFF "USE_CUDA" OFF)

option(USE_MIMALLOC "Use mimalloc" OFF)
# Enable third party mimalloc library to improve memory allocation performance on Windows.
Expand Down Expand Up @@ -1230,3 +1232,12 @@ if(DEFINED USE_CUSTOM_DEBINFO)
set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -g")
set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -g")
endif()

# Bundle PTXAS if needed: stage the CUDA toolkit's ptxas binary in the build
# tree, then install it into ${CMAKE_INSTALL_BINDIR}.
if(BUILD_BUNDLE_PTXAS AND USE_CUDA)
  # Guard on the exact path that is staged and installed below, so the copy is
  # skipped only when the file the install() rule needs is already present.
  if(NOT EXISTS "${PROJECT_BINARY_DIR}/ptxas")
    # Fail with a clear message instead of a generic file(COPY) error when the
    # toolkit location is unset or does not contain ptxas.
    if(NOT EXISTS "${CUDAToolkit_BIN_DIR}/ptxas")
      message(FATAL_ERROR
        "BUILD_BUNDLE_PTXAS is ON but ptxas was not found at "
        "'${CUDAToolkit_BIN_DIR}/ptxas'")
    endif()
    message(STATUS "Copying PTXAS into the bin folder")
    file(COPY "${CUDAToolkit_BIN_DIR}/ptxas" DESTINATION "${PROJECT_BINARY_DIR}")
  endif()
  install(PROGRAMS "${PROJECT_BINARY_DIR}/ptxas" DESTINATION "${CMAKE_INSTALL_BINDIR}")
endif()
15 changes: 15 additions & 0 deletions torch/_inductor/codecache.py
Expand Up @@ -2277,6 +2277,20 @@ def caching_device_properties():
device_interface.Worker.get_device_properties()


def _set_triton_ptxas_path() -> None:
if os.environ.get("TRITON_PTXAS_PATH") is not None:
return
ptxas_path = os.path.abspath(
os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas")
)
if not os.path.exists(ptxas_path):
return
if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK):
os.environ["TRITON_PTXAS_PATH"] = ptxas_path
else:
warnings.warn(f"{ptxas_path} exists but is not an executable")


def _worker_compile(
kernel_name: str, source_code: str, cc: int, device: torch.device
) -> None:
Expand All @@ -2287,6 +2301,7 @@ def _worker_compile(


def _load_kernel(kernel_name: str, source_code: str) -> ModuleType:
    """Load a Triton kernel via ``TritonCodeCache``, precompile it, and return it.

    ``_set_triton_ptxas_path()`` is called first, before the kernel is loaded
    and precompiled.
    """
    _set_triton_ptxas_path()
    loaded = TritonCodeCache.load(kernel_name, source_code)
    loaded.precompile()
    return loaded
Expand Down

0 comments on commit 6c8c5ad

Please sign in to comment.