OSS: Capture triton kernel in ET (#124775)
This diff captures the Triton kernels generated by Inductor in the execution trace (ET): each kernel launch is recorded as an ET node carrying kernel_backend and kernel_file attributes plus its launch grid and stream, and the generated kernel source files are copied next to the trace when the observer is unregistered.

Pull Request resolved: #124775
Approved by: https://github.com/briancoutinho, https://github.com/aaronenyeshi
shengfukevin authored and pytorchmergebot committed Apr 27, 2024
1 parent 8246f42 commit f0a5a0d
Showing 4 changed files with 122 additions and 39 deletions.
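For orientation, here is a minimal capture sketch mirroring the updated test below; the function body, tensor shapes, and schedule values are illustrative, and it assumes a CUDA device with Triton available.

    import tempfile

    import torch
    from torch import _dynamo as torchdynamo
    from torch.profiler import ExecutionTraceObserver, profile, record_function


    @torchdynamo.optimize("inductor")
    def fn(a, b, c):
        # Inductor compiles this into one or more Triton kernels.
        return (torch.nn.functional.linear(a, b) + c).cos()


    a, b, c = (torch.randn(4, 4, device="cuda") for _ in range(3))
    fn(a, b, c)  # warm up so the Triton kernels are generated

    fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
    fp.close()

    et = ExecutionTraceObserver().register_callback(fp.name)  # returns the observer
    with profile(
        activities=torch.profiler.supported_activities(),
        record_shapes=True,
        schedule=torch.profiler.schedule(
            skip_first=3, wait=1, warmup=1, active=2, repeat=1
        ),
        execution_trace_observer=et,
    ) as p:
        for idx in range(10):
            with record_function(f"## LOOP {idx} ##"):
                fn(a, b, c)
            p.step()

Per the profiler.py change below, et.unregister_callback() additionally copies the generated kernel files into a <trace-name>_resources directory next to the trace.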
71 changes: 38 additions & 33 deletions test/profiler/test_execution_trace.py
@@ -21,6 +21,7 @@
 
 import torch
 import torch.nn as nn
+from torch import _dynamo as torchdynamo
 from torch.autograd import (
     _record_function_with_args_enter,
     _record_function_with_args_exit,
@@ -198,6 +199,7 @@ def test_execution_trace_alone(self):
         expected_loop_events = 0
 
         et = ExecutionTraceObserver().register_callback(fp.name)
+
         et.start()
         for idx in range(5):
             expected_loop_events += 1
@@ -231,50 +233,56 @@ def test_execution_trace_alone(self):
         )
     @unittest.skipIf(not TEST_CUDA or not has_triton(), "need CUDA and triton to run")
     def test_execution_trace_with_pt2(self):
-        class ConvAndRelu(nn.Module):
-            def __init__(self) -> None:
-                super().__init__()
-                self.linear = nn.Linear(4096, 4096)
-                self.relu = nn.ReLU(inplace=True)
-
-            def forward(self, x: torch.Tensor) -> torch.Tensor:
-                x = self.linear(x)
-                x = self.relu(x)
-                return x
+        @torchdynamo.optimize("inductor")
+        def fn(a, b, c):
+            x = torch.nn.functional.linear(a, b)
+            x = x + c
+            return x.cos()
+
+        a, b, c = (torch.randn(4, 4, requires_grad=True).to("cuda") for _ in range(3))
+
+        inputs = [a, b, c]
+        with torch._inductor.config.patch(compile_threads=1):
+            fn(*inputs)
 
         # Create a temp file to save execution trace data.
         fp = tempfile.NamedTemporaryFile("w+t", suffix="_et.json", delete=False)
         fp.close()
 
-        with torch._inductor.config.patch(compile_threads=1):
-            test_module = torch.compile(ConvAndRelu())
-
-            x = torch.rand(128, 4096)
-            et = ExecutionTraceObserver().register_callback(fp.name)
-            et.start()
-            test_module.forward(x)
-            et.stop()
-
-        assert fp.name == et.get_output_file_path()
-        et.unregister_callback()
+        with profile(
+            activities=torch.profiler.supported_activities(),
+            record_shapes=True,
+            schedule=torch.profiler.schedule(
+                skip_first=3, wait=1, warmup=1, active=2, repeat=1
+            ),
+            execution_trace_observer=(
+                ExecutionTraceObserver().register_callback(fp.name)
+            ),
+        ) as p:
+            for idx in range(10):
+                with record_function(f"## LOOP {idx} ##"):
+                    fn(*inputs)
+                p.step()
+
         nodes = self.get_execution_trace_root(fp.name)
 
-        found_root_node = False
+        found_captured_triton_kernel_node = False
         for n in nodes:
             assert "name" in n
-            if "[pytorch|profiler|execution_trace|process]" in n["name"]:
-                found_root_node = True
-
-        assert found_root_node
+            if "triton_" in n["name"]:
+                for attr in n["attrs"]:
+                    if attr["name"] == "kernel_file" and attr["value"] != "":
+                        found_captured_triton_kernel_node = True
+                        assert len(n["inputs"]["values"]) > 0
+                        assert len(n["outputs"]["values"]) == 0
+        assert found_captured_triton_kernel_node
 
     def test_execution_trace_start_stop(self):
         use_cuda = torch.profiler.ProfilerActivity.CUDA in supported_activities()
         # Create a temp file to save execution trace data.
         fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
         fp.close()
         expected_loop_events = 0
-        et = ExecutionTraceObserver()
-        et.register_callback(fp.name)
+        et = ExecutionTraceObserver().register_callback(fp.name)
         for idx in range(10):
             if idx == 3:
                 et.start()
@@ -314,8 +322,7 @@ def test_execution_trace_repeat_in_loop(self):
             fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
             fp.close()
             output_files.append(fp.name)
-            et = ExecutionTraceObserver()
-            et.register_callback(fp.name)
+            et = ExecutionTraceObserver().register_callback(fp.name)
             et.start()
             with record_function(f"## LOOP {idx} ##"):
                 self.payload(use_cuda=use_cuda)
@@ -340,8 +347,7 @@ def test_execution_trace_repeat_in_loop(self):
     def test_execution_trace_no_capture(self):
         fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
         fp.close()
-        et = ExecutionTraceObserver()
-        et.register_callback(fp.name)
+        et = ExecutionTraceObserver().register_callback(fp.name)
 
         assert fp.name == et.get_output_file_path()
         et.unregister_callback()
Expand All @@ -357,8 +363,7 @@ def test_execution_trace_nested_tensor(self):
fp = tempfile.NamedTemporaryFile("w+t", suffix=".et.json", delete=False)
fp.close()

et = ExecutionTraceObserver()
observer = et.register_callback(fp.name)
observer = ExecutionTraceObserver().register_callback(fp.name)

def fn(nt):
return nt.sin().cos()
Expand Down
2 changes: 1 addition & 1 deletion torch/_inductor/runtime/triton_heuristics.py
@@ -805,7 +805,7 @@ def run(self, *args, grid, stream, **kwargs):
                 args,
                 {
                     "kernel_file": self.filename,
-                    "kernel_type": "triton",
+                    "kernel_backend": "triton",
                     "grid": grid_info,
                     "stream": stream,
                 },
59 changes: 54 additions & 5 deletions torch/csrc/profiler/standalone/execution_trace_observer.cpp
@@ -236,6 +236,8 @@ const ExecutionTraceObserver::ID root_id{1};
 
 struct FunctionCallContext : public ObserverContext {
   std::string name;
+  std::string kernel_backend;
+  std::string kernel_file;
   ExecutionTraceObserver::ID op_id{uninitialized_id};
   ExecutionTraceObserver::ID parent_id{uninitialized_id};
   ExecutionTraceObserver::ID fw_parent_id{uninitialized_id};
@@ -273,14 +275,16 @@ static void writeJsonNode(
     const std::string& outputs = "[]",
     const std::string& output_shapes = "[]",
     const std::string& output_types = "[]",
-    const std::string& operator_schema = "") {
+    const std::string& operator_schema = "",
+    const std::string& kernel_backend = "",
+    const std::string& kernel_file = "") {
   out << fmt::format(
       R"JSON(
     {{
       "id": {}, "name": "{}", "ctrl_deps": {},
       "inputs": {{"values": {}, "shapes": {}, "types": {}}},
       "outputs": {{"values": {}, "shapes": {}, "types": {}}},
-      "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}}, {{"name": "fw_parent", "type": "uint64", "value": {}}}, {{"name": "seq_id", "type": "int64", "value": {}}}, {{"name": "scope", "type": "uint64", "value": {}}}, {{"name": "tid", "type": "uint64", "value": {}}}, {{"name": "fw_tid", "type": "uint64", "value": {}}}, {{"name": "op_schema", "type": "string", "value": "{}"}}]
+      "attrs": [{{"name": "rf_id", "type": "uint64", "value": {}}},{{"name": "fw_parent", "type": "uint64", "value": {}}},{{"name": "seq_id", "type": "int64", "value": {}}},{{"name": "scope", "type": "uint64", "value": {}}},{{"name": "tid", "type": "uint64", "value": {}}},{{"name": "fw_tid", "type": "uint64", "value": {}}},{{"name": "op_schema", "type": "string", "value": "{}"}},{{"name": "kernel_backend", "type": "string", "value": "{}"}},{{"name": "kernel_file", "type": "string", "value": "{}"}}]
     }})JSON",
       id,
       name,
@@ -297,7 +301,9 @@ static void writeJsonNode(
       scope,
       tid,
       fw_tid,
-      operator_schema);
+      operator_schema,
+      kernel_backend,
+      kernel_file);
 }
 
 inline std::string timeString(const std::time_t timepoint) {
@@ -326,7 +332,7 @@ static bool initExecutionTraceStart(ExecutionTraceObserver& ob) {
 
   ob.out << fmt::format(
       R"JSON({{
-  "schema": "1.0.3-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {},
+  "schema": "1.0.4-chakra.0.0.4", "pid": {}, "time": "{}", "start_ts": {},
   "nodes": [)JSON",
       ob.pid,
       ob.record_time,
@@ -442,6 +448,44 @@ inline void appendValueInfo(
   shapes.push_back(getValueShape(val));
 }
 
+inline void handleKernelBackendInfo(
+    FunctionCallContext& fc,
+    const RecordFunction& fn) {
+  // Triton-kernel-related information is passed via kwinputs.
+  const auto& kwinputs = fn.kwinputs();
+  if (kwinputs.find("kernel_backend") != kwinputs.end()) {
+    fc.kernel_backend = kwinputs.at("kernel_backend").toStringRef();
+    if (fc.kernel_backend == "triton") {
+      TORCH_INTERNAL_ASSERT(
+          kwinputs.find("kernel_file") != kwinputs.end(),
+          "kernel file is missing in triton kernel");
+      fc.kernel_file = kwinputs.at("kernel_file").toStringRef();
+      // Strip the directory part of the kernel file path.
+      if (fc.kernel_file.find_last_of('/') != std::string::npos)
+        fc.kernel_file =
+            fc.kernel_file.substr(fc.kernel_file.find_last_of('/') + 1);
+
+      // Get grid information.
+      TORCH_INTERNAL_ASSERT(
+          kwinputs.find("grid") != kwinputs.end(),
+          "grid is missing in triton kernel");
+      fc.input_values.emplace_back(
+          "\"" + kwinputs.at("grid").toStringRef() + "\"");
+      fc.input_types.emplace_back("\"String\"");
+      fc.input_shapes.emplace_back("[]");
+
+      // Get stream information.
+      TORCH_INTERNAL_ASSERT(
+          kwinputs.find("stream") != kwinputs.end(),
+          "stream is missing in triton kernel");
+      fc.input_values.emplace_back(
+          std::to_string(kwinputs.at("stream").toInt()));
+      fc.input_types.emplace_back("\"Int\"");
+      fc.input_shapes.emplace_back("[]");
+    }
+  }
+}
+
 static void recordOperatorStart(
     ExecutionTraceObserver& ob,
     FunctionCallContext& fc,
@@ -491,6 +535,9 @@ static void recordOperatorStart(
       appendValueInfo(
           ob, inputs[i], fc.input_values, fc.input_types, fc.input_shapes);
     }
+
+    handleKernelBackendInfo(fc, fn);
+
     fc.parent_id = ob.op_stack[tid].top();
     // get parent id from the forward stack, this can be different for
     // autograd ops, which may execute on a different thread than the original
@@ -615,7 +662,9 @@ static void onFunctionExit(const RecordFunction& fn, ObserverContext* ctx_ptr) {
         vectorToString(output_values),
         vectorToString(output_shapes),
         vectorToString(output_types),
-        op_schema_str);
+        op_schema_str,
+        fc.kernel_backend,
+        fc.kernel_file);
     ob->out << ",";
   } catch (const std::exception& e) {
     LOG(WARNING) << "Exception in execution trace observer: [" << fc.name
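To show what the observer now emits, here is a hedged sketch that reads back a finished trace and pulls out the captured Triton kernel nodes; my_et.json is an illustrative path, and the attribute names match the fields written by writeJsonNode above.

    import json

    # Illustrative path: the file that was passed to register_callback().
    with open("my_et.json") as f:
        et = json.load(f)

    for node in et["nodes"]:
        attrs = {a["name"]: a["value"] for a in node.get("attrs", [])}
        if attrs.get("kernel_backend") == "triton":
            # handleKernelBackendInfo() appends the launch grid (a string)
            # and the stream (an int) to the node's inputs.
            print(node["name"], attrs["kernel_file"], node["inputs"]["values"])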
29 changes: 29 additions & 0 deletions torch/profiler/profiler.py
@@ -1,6 +1,7 @@
 import gzip
 import json
 import os
+import shutil
 import tempfile
 from abc import ABC, abstractmethod
 from enum import Enum
@@ -792,8 +793,36 @@ def unregister_callback(self):
         """
         Removes ET observer from record function callbacks.
         """
+
+        def _save_triton_kernels():
+            # Save the generated Triton kernel files alongside the trace output.
+            from torch._inductor.codecache import PyCodeCache as PyCodeCache
+
+            kernel_files = [
+                v.__file__
+                for v in PyCodeCache.cache.values()
+                if getattr(v, "__file__", None) is not None
+            ]
+            work_dir, file_name = os.path.split(self._output_file_path)
+            resource_dir = os.path.join(
+                work_dir, os.path.splitext(file_name)[0] + "_resources"
+            )
+            if not os.path.exists(resource_dir):
+                os.mkdir(resource_dir)
+
+            for kernel_file in kernel_files:
+                if kernel_file is None:
+                    continue
+                path, name = os.path.split(kernel_file)
+                dst = os.path.join(resource_dir, name)
+                shutil.copyfile(kernel_file, dst)
+
         if self._registered:
             self.stop()
+            try:
+                _save_triton_kernels()
+            except Exception as e:
+                warn(f"Execution trace failed to save kernels: {e}")
             _remove_execution_trace_observer()
             self._registered = False

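As a usage note, the resource directory is derived from the trace path; a sketch of the naming, with an illustrative path:

    import os

    output_file_path = "/tmp/my_et.json"  # illustrative: the path given to register_callback()
    work_dir, file_name = os.path.split(output_file_path)
    resource_dir = os.path.join(work_dir, os.path.splitext(file_name)[0] + "_resources")
    print(resource_dir)  # -> /tmp/my_et_resources (generated kernel files are copied here)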
