diff --git a/CMakeLists.txt b/CMakeLists.txt index 9194e520bb00..386b3a208c9a 100644 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -349,6 +349,8 @@ cmake_dependent_option( "NOT INTERN_BUILD_MOBILE" OFF) cmake_dependent_option( BUILD_FUNCTORCH "Build Functorch" ON "BUILD_PYTHON" OFF) +cmake_dependent_option( + BUILD_BUNDLE_PTXAS "Bundle PTX into torch/bin fodler" OFF "USE_CUDA" OFF) option(USE_MIMALLOC "Use mimalloc" OFF) # Enable third party mimalloc library to improve memory allocation performance on Windows. @@ -1230,3 +1232,12 @@ if(DEFINED USE_CUSTOM_DEBINFO) set(CMAKE_EXE_LINKER_FLAGS "${CMAKE_EXE_LINKER_FLAGS} -g") set(CMAKE_SHARED_LINKER_FLAGS "${CMAKE_SHARED_LINKER_FLAGS} -g") endif() + +# Bundle PTXAS if needed +if(BUILD_BUNDLE_PTXAS AND USE_CUDA) + if(NOT EXISTS "${PROJECT_SOURCE_DIR}/build/bin/ptxas") + message(STATUS "Copying PTXAS into the bin folder") + file(COPY "${CUDAToolkit_BIN_DIR}/ptxas" DESTINATION "${PROJECT_BINARY_DIR}") + endif() + install(PROGRAMS "${PROJECT_BINARY_DIR}/ptxas" DESTINATION "${CMAKE_INSTALL_BINDIR}") +endif() diff --git a/torch/_inductor/codecache.py b/torch/_inductor/codecache.py index f26b8fa4993e..4a41e8d5b887 100644 --- a/torch/_inductor/codecache.py +++ b/torch/_inductor/codecache.py @@ -2277,6 +2277,20 @@ def caching_device_properties(): device_interface.Worker.get_device_properties() +def _set_triton_ptxas_path() -> None: + if os.environ.get("TRITON_PTXAS_PATH") is not None: + return + ptxas_path = os.path.abspath( + os.path.join(os.path.dirname(__file__), "..", "bin", "ptxas") + ) + if not os.path.exists(ptxas_path): + return + if os.path.isfile(ptxas_path) and os.access(ptxas_path, os.X_OK): + os.environ["TRITON_PTXAS_PATH"] = ptxas_path + else: + warnings.warn(f"{ptxas_path} exists but is not an executable") + + def _worker_compile( kernel_name: str, source_code: str, cc: int, device: torch.device ) -> None: @@ -2287,6 +2301,7 @@ def _worker_compile( def _load_kernel(kernel_name: str, source_code: str) -> ModuleType: + _set_triton_ptxas_path() kernel = TritonCodeCache.load(kernel_name, source_code) kernel.precompile() return kernel