[inductor][cpp] support bf16/fp16 gemm template epilogue fusion
[ghstack-poisoned]
jgong5 committed May 17, 2024
1 parent 91bf952 commit 2547656
Showing 9 changed files with 291 additions and 42 deletions.
77 changes: 72 additions & 5 deletions test/inductor/test_cpu_select_algorithm.py
@@ -160,7 +160,7 @@ def forward(self, x):
"div",
),
)
@dtypes(torch.float)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_pointwise(self, bias, epilogue, dtype):
batch_size = 384
in_features = 196
@@ -204,16 +204,25 @@ def forward(self, x):
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, epilogue=epilogue, other=u).to(dtype=dtype).eval()
self.common(mod, (v,))
atol, rtol = 1e-4, 1e-4
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
if dtype == torch.half or dtype == torch.bfloat16:
# For half and bfloat16, the epilogue fusion is part of the template itself,
# not fused via the scheduler.
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 0)
else:
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bias", (True, False))
@dtypes(torch.float)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_transpose(self, bias, dtype):
batch_size = 384
in_features = 196
@@ -231,7 +240,65 @@ def forward(self, x, y):
mod = M(bias=bias).to(dtype=dtype).eval()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(out_features, batch_size).to(dtype=dtype)
self.common(mod, (v, u))
atol, rtol = 1e-4, 1e-4
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(mod, (v, u), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)

@inductor_config.patch({"freezing": True})
@patches
@torch.no_grad
@unittest.skipIf(not TEST_MKL, "Test requires MKL")
@parametrize("bias", (True, False))
@parametrize(
"unary",
("relu",),
)
@parametrize(
"binary",
(
"add",
"sub",
"mul",
"div",
),
)
@dtypes(torch.float, torch.bfloat16, torch.half)
def test_linear_with_unary_binary(self, bias, unary, binary, dtype):
batch_size = 384
in_features = 196
out_features = 384

class M(torch.nn.Module):
def __init__(self, bias, unary, binary, other):
super().__init__()
self.linear = torch.nn.Linear(in_features, out_features, bias)
if unary == "relu":
self.unary = torch.nn.ReLU()
if binary == "add":
self.binary = lambda x: x + other
elif binary == "sub":
self.binary = lambda x: x - other
elif binary == "mul":
self.binary = lambda x: x * other
elif binary == "div":
self.binary = lambda x: x / other

def forward(self, x):
return self.binary(self.unary(self.linear(x)))

counters.clear()
v = torch.randn(batch_size, in_features).to(dtype=dtype)
u = torch.randn(batch_size, out_features).to(dtype=dtype)
mod = M(bias=bias, unary=unary, binary=binary, other=u).to(dtype=dtype).eval()
atol, rtol = 1e-4, 1e-4
if dtype == torch.half or dtype == torch.bfloat16:
atol, rtol = 1e-2, 1e-2
with patch.object(select_algorithm, "VERIFY", dict(atol=atol, rtol=rtol)):
self.common(mod, (v,), atol=atol, rtol=rtol)
self.assertEqual(counters["inductor"]["select_algorithm_autotune"], 1)
self.assertEqual(counters["inductor"]["cpp_epilogue_fusion_counter"], 1)
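Outside the test harness, the bf16/fp16 template path these tests cover can be exercised with a small sketch like the one below. The inductor config knobs (freezing, max_autotune, max_autotune_gemm_backends) are assumptions about how the tests' patches enable the CPP GEMM template; the rest is public torch.compile API.

import torch
import torch._inductor.config as inductor_config

# Assumed configuration for routing the linear through the CPP GEMM template
# with epilogue fusion (the unit tests patch equivalent settings).
inductor_config.freezing = True
inductor_config.max_autotune = True
inductor_config.max_autotune_gemm_backends = "CPP,ATEN"

class M(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(196, 384)

    def forward(self, x):
        # linear + pointwise epilogue; for bf16/fp16 the epilogue is fused
        # inside the GEMM template rather than by the scheduler.
        return torch.relu(self.linear(x))

with torch.no_grad():
    mod = M().to(torch.bfloat16).eval()
    x = torch.randn(384, 196, dtype=torch.bfloat16)
    y = torch.compile(mod)(x)

The updated tests themselves can be run directly, e.g. python test/inductor/test_cpu_select_algorithm.py -k test_linear_with_unary_binary (an MKL-enabled CPU build is required).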

4 changes: 4 additions & 0 deletions torch/_inductor/codegen/common.py
@@ -1731,6 +1731,10 @@ def rename_indexing(self, index) -> sympy.Expr:
def create_cse_var(self, *args, **kwargs):
return CSEVariable(*args, **kwargs)

def make_inplace(self, input_name, output_name):
self.args.make_inplace(input_name, output_name)
self.inplace_update_buffers[output_name] = input_name


@dataclasses.dataclass
class OptimizationContext:
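The new make_inplace helper records an aliasing between an input and an output buffer on both the kernel args and the kernel itself; the GEMM template relies on this so that the fused epilogue's output reuses the template buffer's storage instead of getting its own allocation. A standalone toy sketch of that bookkeeping (ToyKernelArgs/ToyKernel are illustrative stand-ins, not the real inductor classes):

class ToyKernelArgs:
    def __init__(self):
        self.inplace = {}  # output buffer name -> input buffer name it aliases

    def make_inplace(self, input_name, output_name):
        self.inplace[output_name] = input_name

class ToyKernel:
    def __init__(self):
        self.args = ToyKernelArgs()
        self.inplace_update_buffers = {}

    def make_inplace(self, input_name, output_name):
        # Mirrors Kernel.make_inplace above: record the alias on the kernel
        # args and remember it on the kernel for later codegen decisions.
        self.args.make_inplace(input_name, output_name)
        self.inplace_update_buffers[output_name] = input_name

# e.g. alias the fused epilogue output onto the GEMM template buffer
k = ToyKernel()
k.make_inplace("buf_gemm_template", "buf_epilogue_out")
assert k.inplace_update_buffers["buf_epilogue_out"] == "buf_gemm_template"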
15 changes: 7 additions & 8 deletions torch/_inductor/codegen/cpp.py
@@ -2966,7 +2966,7 @@ def is_lowp_fp_scheduler(self, scheduler_node: SchedulerNode):
scheduler_node._lowp_fp_type = _lowp_fp_type # type: ignore[attr-defined]
return True

def legalize_lowp_fp_dtype(self, nodes):
def legalize_lowp_fp_dtype_loop_body(self, loop_body: ir.LoopBody):
def add_to_dtype(sub_graph: torch.fx.Graph):
def is_lowp_fp_load(node: torch.fx.Node):
if node.target not in ["load"]:
@@ -3104,11 +3104,11 @@ def _used_by_to(to_node: torch.fx.Node):

eliminate_to_dtype(sub_graph)

def _legalize_lowp_fp(loop_body: ir.LoopBody):
sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
for sub_block in sub_blocks:
add_to_dtype(sub_block.graph)
sub_blocks = [loop_body.root_block] + list(loop_body.subblocks.values())
for sub_block in sub_blocks:
add_to_dtype(sub_block.graph)

def legalize_lowp_fp_dtype(self, nodes):
if all(
isinstance(_node, SchedulerNode) and self.is_lowp_fp_scheduler(_node)
for _node in nodes
@@ -3145,7 +3145,7 @@ def is_memory_copy_scheduler_node(node: SchedulerNode):
should_legalize = not is_memory_copy_scheduler_node(node)
if should_legalize:
body: ir.LoopBody = node._body
_legalize_lowp_fp(body)
self.legalize_lowp_fp_dtype_loop_body(body)

def codegen_functions(self, fn_list, var_sizes_list, vec_dtype=torch.float):
# TODO(jgong5): remove vec_dtype arg with alternative tiling factors for various dtypes
@@ -3310,8 +3310,8 @@ def select_tiling(dtype: torch.dtype = torch.float):
inner_tail_loop.set_kernel(vec_kernel)

def codegen_loop_bodies(self, loop_bodies, var_sizes_list):
# TODO(jgong5): support lowp legalization
for body in loop_bodies:
self.legalize_lowp_fp_dtype_loop_body(body)
DataTypePropagation.propagate_loopbody(body)
self.codegen_functions(loop_bodies, var_sizes_list)

@@ -3714,7 +3714,6 @@ def codegen_template(
kernel, render = ctb.make_kernel_render(ctb, epilogue_nodes=epilogue_ir_nodes)
with kernel:
for node in [template_node, *epilogue_nodes]:
node.decide_inplace_update()
node.mark_run()
src_code = render()

52 changes: 43 additions & 9 deletions torch/_inductor/codegen/cpp_gemm_template.py
@@ -1,4 +1,4 @@
from typing import cast, List, Optional
from typing import Callable, cast, List, Optional

import torch
import torch.utils
@@ -20,7 +20,7 @@
{{micro_gemm.codegen_define(kernel)}}
extern "C"
{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y})}}
{{kernel.def_kernel(inputs={"X": X, "W": W, "inp": inp}, outputs={"Y": Y}, aliases=buffer_aliases)}}
{
{{kernel.maybe_codegen_profile()}}
constexpr int64_t num_threads = {{num_threads}};
@@ -90,8 +90,8 @@
const int64_t n_start = nc * N0;
const int64_t n_size = N0;
{%- if use_local_acc %}
{{ kernel.define_buffer("acc_local_buf", ["m_end - m_start", "N0"]) }}
{%- set acc = kernel.local_buffers["acc_local_buf"] %}
{{ kernel.define_buffer(acc_buf_name, ["m_end - m_start", "N0"]) }}
{%- set acc = kernel.local_buffers[acc_buf_name] %}
{%- else %}
{%- set acc = kernel.slice_nd(GemmOut, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{%- endif %}
@@ -126,7 +126,7 @@
{%- endif %}
{%- set tile_Y = kernel.slice_nd(Y_maybe_transposed, [("m_start", "m_end"), ("n_start", "n_start + N0")]) %}
{{ kernel.store_output(
tile_Y, acc, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
tile_Y, acc, GemmOut, epilogue_nodes, offsets=("m_start", "n_start"), reindexer=reindexer
)|indent(16, false)
}}
}
@@ -145,9 +145,12 @@ def __init__(
register_blocking: GemmBlocking,
beta=1,
alpha=1,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
assert layout.dtype in [torch.float, torch.bfloat16, torch.half]
super().__init__("packed_gemm", input_nodes, layout)
super().__init__(
"packed_gemm", input_nodes, layout, epilogue_creator=epilogue_creator
)
self.beta = beta
self.alpha = alpha
self.num_threads = num_threads
@@ -219,6 +222,7 @@ def add_choices(
alpha=1,
trans_w=False,
input_indices=None,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
if input_indices is None:
input_indices = list(range(len(input_nodes)))
@@ -372,6 +376,7 @@ def postprocessor(output):
register_blocking=micro_gemm.register_blocking,
beta=beta,
alpha=alpha,
epilogue_creator=epilogue_creator,
)
template.maybe_append_choice(choices)
return template
@@ -395,11 +400,36 @@ def render( # type: ignore[override]
Y = template_buffer_node

template_buffer = Y
gemm_output_buffer = template_buffer

epilogues: List[ir.IRNode] = []
if self.epilogue_creator is not None:
# We assume the epilogues are computed in fp32 before being stored back to Y,
# so we set the fp32 data type for both the input and the output of the epilogue.
# In the codegen, they will be replaced either with Y (for fp32 output) or
# with acc (for bf16/fp16 output); in either case, they are fp32.
gemm_layout = ir.FixedLayout(
template_buffer.layout.device,
torch.float,
template_buffer.layout.size,
template_buffer.layout.stride,
)
gemm_output_name = "GemmOut"
gemm_output_buffer = ir.Buffer(gemm_output_name, gemm_layout)
epilogues.append(
ir.ComputedBuffer(
name=template_buffer.get_name(),
layout=gemm_layout if epilogue_nodes else template_buffer.layout,
data=self.epilogue_creator(gemm_output_buffer),
)
)

Y_is_transposed = False
use_local_acc = self.layout.dtype != torch.float
acc_buf_name = "local_acc_buf"
if epilogue_nodes:
epilogues.extend(epilogue_nodes)
Y = cast(ir.Buffer, epilogue_nodes[-1])
assert Y.get_name() in V.kernel.inplace_update_buffers
if Y.get_stride() == list(reversed(template_buffer.get_stride())):
Y_is_transposed = True

@@ -421,16 +451,20 @@
W=W,
inp=inp,
Y=Y,
GemmOut=template_buffer,
GemmOut=gemm_output_buffer,
buffer_aliases=[(gemm_output_buffer, Y)]
if gemm_output_buffer is not Y
else None,
beta=self.beta,
alpha=self.alpha,
num_threads=self.num_threads,
micro_gemm=micro_gemm,
is_dynamic_M=self.is_dynamic_M,
template=self,
kernel=kernel,
epilogue_nodes=epilogue_nodes,
epilogue_nodes=epilogues,
reindexer=(lambda x: list(reversed(x))) if Y_is_transposed else None,
use_local_acc=use_local_acc,
acc_buf_name=acc_buf_name,
)
return self._template_from_string(GEMM_TEMPLATE).render(**options)
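A minimal numeric sketch of the dtype convention the render logic above sets up for bf16/fp16 outputs (eager ops standing in for the template's local accumulation buffer and fused epilogue; this is not the generated C++):

import torch

x = torch.randn(8, 16, dtype=torch.bfloat16)
w = torch.randn(32, 16, dtype=torch.bfloat16)

# GemmOut / local_acc_buf: for bf16/fp16 the GEMM accumulates in fp32.
acc = x.to(torch.float32) @ w.to(torch.float32).t()

# Epilogues (from epilogue_creator and any fused scheduler nodes) are applied
# to the fp32 accumulator ...
acc = torch.relu(acc)

# ... and only the final store to Y casts back to the output dtype.
y = acc.to(torch.bfloat16)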
4 changes: 3 additions & 1 deletion torch/_inductor/codegen/cpp_template.py
@@ -3,7 +3,7 @@
import logging

import sys
from typing import List, Optional
from typing import Callable, List, Optional
from unittest.mock import patch

import sympy
@@ -26,11 +26,13 @@ def __init__(
name: str,
input_nodes,
layout: ir.Layout,
epilogue_creator: Optional[Callable[[ir.Buffer], ir.Pointwise]] = None,
):
super().__init__(name)
self.input_nodes = input_nodes
self.output_node: ir.Buffer = ir.Buffer("buf_out", layout)
self.layout = layout
self.epilogue_creator = epilogue_creator

def generate(self, **kwargs):
kernel_name = f"cpp_{self.name}"
