Skip to content

Commit

Permalink
PR #4912: [ROCm] support for GraphAddKernelNode()
Browse files Browse the repository at this point in the history
Imported from GitHub PR openxla/xla#4912

patch for this PR openxla/xla#4894 (comment)

@akuegel @ezhulenev Thanks in advance!
Copybara import of the project:

--
27fe0e2e859b5188704f21613283e68d594f4d92 by Chao Chen <cchen104@amd.com>:

rocm graph adds GraphAddKernelNode()

Merging this change closes #4912

PiperOrigin-RevId: 556107848
  • Loading branch information
i-chaochen authored and tensorflower-gardener committed Aug 11, 2023
1 parent 8146900 commit 89fac52
Show file tree
Hide file tree
Showing 3 changed files with 88 additions and 23 deletions.
77 changes: 55 additions & 22 deletions tensorflow/compiler/xla/stream_executor/gpu/gpu_driver.h
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

// CUDA userspace driver library wrapper functionality.
// CUDA/ROCm userspace driver library wrapper functionality.

#ifndef TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
#define TENSORFLOW_COMPILER_XLA_STREAM_EXECUTOR_GPU_GPU_DRIVER_H_
Expand Down Expand Up @@ -49,63 +49,72 @@ class GpuContext;
// The calls log any specific errors internally and return whether the operation
// was successful to the caller.
//
// The order of parameters is generally kept symmetric with the underlying CUDA
// driver API.
// The order of parameters is generally kept symmetric with the underlying
// CUDA/ROCm driver API.
//
// Links on functions are to specific documentation under
// http://docs.nvidia.com/cuda/cuda-driver-api/
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html
//
// Thread safety: these functions should not be used from signal handlers.
class GpuDriver {
public:
// Wraps a call to cuInit with logging to help indicate what has gone wrong in
// the case of failure. Safe to call multiple times; will be fast on all calls
// after the first.
// Wraps a call to cuInit/hipInit with logging to help indicate what has gone
// wrong in the case of failure. Safe to call multiple times; will be fast on
// all calls after the first.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__INITIALIZE.html#group__CUDA__INITIALIZE_1g0a2f1517e1bd8502c7194c3a8c134bc3
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#initialization
static tsl::Status Init();

// Returns the device associated with the given context.
// device is an outparam owned by the caller, must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g4e84b109eba36cdaaade167f34ae881e
static tsl::StatusOr<GpuDeviceHandle> DeviceFromContext(GpuContext* context);

// Creates a new CUDA stream associated with the given context via
// cuStreamCreate.
// Creates a new CUDA/HIP stream associated with the given context via
// cuStreamCreate/hipStreamCreateWithFlags.
// stream is an outparam owned by the caller, must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1ga581f0c5833e21ded8b5a56594e243f4
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management
static bool CreateStream(GpuContext* context, GpuStreamHandle* stream,
int priority = 0);

// Destroys a CUDA stream associated with the given context.
// Destroys a CUDA/HIP stream associated with the given context.
// stream is owned by the caller, must not be null, and *stream is set to null
// if the stream is successfully destroyed.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g244c8833de4596bcd31a06cdf21ee758
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management
static void DestroyStream(GpuContext* context, GpuStreamHandle* stream);

// CUDA events can explicitly disable event TSC retrieval for some presumed
// performance improvement if timing is unnecessary.
// CUDA/HIP events can explicitly disable event TSC retrieval for some
// presumed performance improvement if timing is unnecessary.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types
enum class EventFlags { kDefault, kDisableTiming };

// Creates a new event associated with the given context.
// result is an outparam owned by the caller and must not be null.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g450687e75f3ff992fe01662a43d9d3db
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types
static tsl::Status InitEvent(GpuContext* context, GpuEventHandle* result,
EventFlags flags);

// Destroys *event and turns it into a nullptr. event may not be null, but
// *event may be, via cuEventDestroy
// *event may be, via cuEventDestroy/hipEventDestroy
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EVENT.html#group__CUDA__EVENT_1g593ec73a8ec5a5fc031311d3e4dca1ef
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#event-management
static tsl::Status DestroyEvent(GpuContext* context, GpuEventHandle* event);

// Allocates a GPU memory space of size bytes associated with the given
// context via cuMemAlloc.
// context via cuMemAlloc/hipMalloc.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gb82d2a09844a58dd9e744dc31e8aa467
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static void* DeviceAllocate(GpuContext* context, uint64_t bytes);

// Deallocates a GPU memory space of size bytes associated with the given
// context via cuMemFree.
// context via cuMemFree/hipFree.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g89b3f154e17cc89b6eea277dbdf5c93a
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static void DeviceDeallocate(GpuContext* context, void* location);

// Allocates a unified memory space of size bytes associated with the given
Expand All @@ -121,31 +130,38 @@ class GpuDriver {
static void UnifiedMemoryDeallocate(GpuContext* context, void* location);

// Allocates page-locked and CUDA-registered memory on the host via
// cuMemAllocHost.
// cuMemAllocHost/hipHostMalloc.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gdd8311286d2c2691605362c689bc64e0
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static void* HostAllocate(GpuContext* context, uint64_t bytes);

// Deallocates a location created by HostAllocate, via cuMemFreeHost.
// Deallocates a location created by HostAllocate, via
// cuMemFreeHost/hipHostFree.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g62e0fdbe181dab6b1c90fa1a51c7b92c
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static void HostDeallocate(GpuContext* context, void* location);

// Registers a memory region at location of size bytes via cuMemHostRegister.
// Registers a memory region at location of size bytes via
// cuMemHostRegister/hipHostRegister.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1gf0a9fe11544326dabd743b7aa6b54223
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
static bool HostRegister(GpuContext* context, void* location, uint64_t bytes);

// Unregisters a memory region that was previously registered at location via
// cuMemHostUnregister.
// cuMemHostUnregister/hipHostUnregister.
//
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__MEM.html#group__CUDA__MEM_1g63f450c8125359be87b7623b1c0b2a14
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#memory-management
//
// TODO(leary) verify an error will be returned if the location wasn't
// previously registered.
static bool HostUnregister(GpuContext* context, void* location);

// Queries the priority range and returns the corresponding integer value via
// cuCtxGetStreamPriorityRange
// cuCtxGetStreamPriorityRange/hipDeviceGetStreamPriorityRange
//
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g137920ab61a71be6ce67605b9f294091
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#context-management
static int GetGpuStreamPriority(
GpuContext* context, stream_executor::StreamPriority stream_priority);

Expand Down Expand Up @@ -211,7 +227,7 @@ class GpuDriver {
// which must not be null.
//
// N.B. these device handles do not have a corresponding destroy function in
// the CUDA driver API.
// the CUDA/HIP driver API.
static tsl::Status GetDevice(int device_ordinal, GpuDeviceHandle* device);

// Given a device handle, returns the name reported by the driver for the
Expand Down Expand Up @@ -257,19 +273,22 @@ class GpuDriver {
// Gets the preferred shared memory bank configuration for the specified
// CONTEXT (not function!), either default or four- or eight-byte bank size.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g17153a1b8b8c756f7ab8505686a4ad74
// https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html
static tsl::StatusOr<GpuSharedMemConfig> ContextGetSharedMemConfig(
GpuContext* context);

// Sets the preferred shared memory bank configuration for the specified
// CONTEXT (not function!), either default or four- or eight-byte bank size.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__CTX.html#group__CUDA__CTX_1g2574235fa643f8f251bf7bc28fac3692
// https://rocm.docs.amd.com/projects/HIP/en/latest/.doxygen/docBin/html/group___execution.html
static tsl::Status ContextSetSharedMemConfig(
GpuContext* context, GpuSharedMemConfig shared_mem_config);

// Launches a CUDA kernel via cuLaunchKernel.
// Launches a CUDA/ROCm kernel via cuLaunchKernel/hipModuleLaunchKernel.
// TODO(leary) describe the structure of kernel_params and extra in a readable
// way.
// http://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__EXEC.html#group__CUDA__EXEC_1gb8f3dc3031b40da29d5f9a7139e52e15
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#execution-control
static tsl::Status LaunchKernel(
GpuContext* context, absl::string_view kernel_name,
GpuFunctionHandle function, unsigned int grid_dim_x,
Expand All @@ -280,25 +299,30 @@ class GpuDriver {

// Creates a new GPU graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gd885f719186010727b75c3315f865fdf
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status CreateGraph(GpuGraphHandle* graph);

// Destroys GPU graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g718cfd9681f078693d4be2426fd689c8
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status DestroyGraph(GpuGraphHandle graph);

// Begins graph capture on a stream.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g767167da0bbf07157dc20b6c258a2143
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
enum class StreamCaptureMode { kGlobal, kThreadLocal, kRelaxed };
static tsl::Status StreamBeginCapture(GpuStreamHandle stream,
StreamCaptureMode mode);

// Ends capture on a stream, returning the captured graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g03dab8b2ba76b00718955177a929970c
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status StreamEndCapture(GpuStreamHandle stream,
GpuGraphHandle* graph);

// Graph instantiation flags.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g070bf5517d3a7915667c256eefce4956
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types
struct GraphInstantiateFlags {
// Automatically free memory allocated in a graph before relaunching.
bool auto_free_on_launch = false;
Expand All @@ -313,17 +337,20 @@ class GpuDriver {

// Creates an executable graph from a graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1gb53b435e178cccfa37ac87285d2c3fa1
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status GraphInstantiate(GpuGraphExecHandle* exec,
GpuGraphHandle graph,
const GraphInstantiateFlags& flags);

// Launches an executable graph in a stream.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g6b2dceb3901e71a390d2bd8b0491e471
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status GraphLaunch(GpuGraphExecHandle exec,
GpuStreamHandle stream);

// Graph update result.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__TYPES.html#group__CUDA__TYPES_1g8edc8969ff6ae00b7cd5d7292f812c3c
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#cuda-driver-data-types
enum class GraphExecUpdateResult {
kSuccess,
kError,
Expand All @@ -338,6 +365,7 @@ class GpuDriver {

// Graph update result info.
// https://docs.nvidia.com/cuda/cuda-driver-api/structCUgraphExecUpdateResultInfo__v1.html#structCUgraphExecUpdateResultInfo__v1
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
struct GraphExecUpdateResultInfo {
// TODO(ezhulenev): Add `errorFromNode` and `errorNode` members.
GraphExecUpdateResult result;
Expand All @@ -346,26 +374,31 @@ class GpuDriver {
// Check whether an executable graph can be updated with a graph and perform
// the update if possible.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g96efefc56df46927da7297f122adfb9f
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status GraphExecUpdate(GpuGraphExecHandle exec,
GpuGraphHandle graph,
GraphExecUpdateResultInfo* result);

// Destroys an executable graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1ga32ad4944cc5d408158207c978bc43a7
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status DestroyGraphExec(GpuGraphExecHandle exec);

// Write a DOT file describing graph structure.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g0fb0c4d319477a0a98da005fcb0dacc4
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status GraphDebugDotPrint(GpuGraphHandle graph, const char* path);

// Returns a stream's capture status.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__STREAM.html#group__CUDA__STREAM_1g37823c49206e3704ae23c7ad78560bca
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#stream-management
static tsl::StatusOr<bool> StreamIsCapturing(GpuStreamHandle stream);

// Creates a kernel execution node and adds it to a graph.
// https://docs.nvidia.com/cuda/cuda-driver-api/group__CUDA__GRAPH.html#group__CUDA__GRAPH_1g50d871e3bd06c1b835e52f2966ef366b
// https://rocm.docs.amd.com/projects/HIPIFY/en/latest/tables/CUDA_Driver_API_functions_supported_by_HIP.html#graph-management
static tsl::Status GraphAddKernelNode(
CUgraphNode* node, GpuGraphHandle graph,
GpuGraphNodeHandle* node, GpuGraphHandle graph,
absl::Span<GpuGraphNodeHandle> deps, absl::string_view kernel_name,
GpuFunctionHandle function, unsigned int grid_dim_x,
unsigned int grid_dim_y, unsigned int grid_dim_z,
Expand Down
32 changes: 32 additions & 0 deletions tensorflow/compiler/xla/stream_executor/rocm/rocm_driver.cc
Original file line number Diff line number Diff line change
Expand Up @@ -558,6 +558,38 @@ static std::string_view StreamCaptureModeToString(
return status == hipStreamCaptureStatusActive;
}

/* static */ tsl::Status GpuDriver::GraphAddKernelNode(
    hipGraphNode_t* node, hipGraph_t graph, absl::Span<hipGraphNode_t> deps,
    absl::string_view kernel_name, hipFunction_t function,
    unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
    unsigned int block_dim_x, unsigned int block_dim_y,
    unsigned int block_dim_z, unsigned int shared_mem_bytes,
    void** kernel_params, void** extra) {
  // Creates a kernel execution node in `graph` via hipGraphAddKernelNode and
  // returns its handle through `node` (an outparam owned by the caller).
  // `deps` lists the graph nodes that must complete before this node runs;
  // `kernel_name` is used for logging only. `kernel_params`/`extra` follow
  // the same convention as LaunchKernel.
  VLOG(2) << "Add kernel node to a graph: " << graph
          << "; kernel: " << kernel_name << "; gdx: " << grid_dim_x
          << " gdy: " << grid_dim_y << " gdz: " << grid_dim_z
          << " bdx: " << block_dim_x << " bdy: " << block_dim_y
          << " bdz: " << block_dim_z << "; shmem: " << shared_mem_bytes;

  // Value-initialize so that any hipKernelNodeParams fields not explicitly
  // assigned below (including ones added in newer HIP releases) are zeroed
  // rather than indeterminate when passed to the driver.
  hipKernelNodeParams params = {};
  params.func = function;
  params.gridDim.x = grid_dim_x;
  params.gridDim.y = grid_dim_y;
  params.gridDim.z = grid_dim_z;
  params.blockDim.x = block_dim_x;
  params.blockDim.y = block_dim_y;
  params.blockDim.z = block_dim_z;
  params.sharedMemBytes = shared_mem_bytes;
  params.kernelParams = kernel_params;
  params.extra = extra;

  RETURN_IF_ROCM_ERROR(
      hipGraphAddKernelNode(node, graph, deps.data(), deps.size(), &params),
      "Failed to add kernel node to a HIP graph");

  return ::tsl::OkStatus();
}

/* static */ tsl::Status GpuDriver::LaunchKernel(
GpuContext* context, absl::string_view kernel_name, hipFunction_t function,
unsigned int grid_dim_x, unsigned int grid_dim_y, unsigned int grid_dim_z,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -93,13 +93,13 @@ namespace wrap {
__macro(hipGetDeviceProperties) \
__macro(hipGetErrorString) \
__macro(hipGraphDebugDotPrint) \
__macro(hipGraphDebugDotFlagsVerbose) \
__macro(hipGraphDestroy) \
__macro(hipGraphExecDestroy) \
__macro(hipGraphExecUpdate) \
__macro(hipGraphInstantiate) \
__macro(hipGraphLaunch) \
__macro(hipGraphCreate) \
__macro(hipGraphAddKernelNode) \
__macro(hipHostFree) \
__macro(hipHostMalloc) \
__macro(hipHostRegister) \
Expand Down

0 comments on commit 89fac52

Please sign in to comment.