Revert "Require less alignment for attn bias (#114173) (#114837)"
This reverts commit 5965649.
atalman committed Dec 12, 2023
1 parent 448700d commit a8e7c98
Showing 8 changed files with 43 additions and 192 deletions.
38 changes: 24 additions & 14 deletions aten/src/ATen/native/transformers/attention.cpp
@@ -590,14 +590,9 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
// We apply this function to the top level SDPA so that
// if padding is done it will be tracked for backward automatically

-template<int alignment>
-bool aligned_tensor(const at::Tensor& tensor){
-  for(const auto i : c10::irange(tensor.dim() - 1)){
-    if(tensor.sym_stride(i) % alignment != 0){
-      return false;
-    }
-  }
-  return tensor.sym_stride(-1) == 1;
+template <int alignment>
+bool is_aligned(const SymInt& size){
+  return size % alignment == 0;
}

template <int alignment>
@@ -613,16 +608,31 @@ at::Tensor preprocess_mask(
    const at::Tensor& query,
    const at::Tensor& key,
    const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 8;
-  at::Tensor result_mask = mask;
-  if (!aligned_tensor<mem_eff_alignment>(mask)) {
-    result_mask = pad_bias<mem_eff_alignment>(mask);
-  }
-  return result_mask.expand_symint(
+  constexpr int mem_eff_alignment = 16;
+  // Expand to 4d case
+  at::Tensor attn_mask = mask.expand_symint(
      {query.sym_size(0),
       query.sym_size(1),
       query.sym_size(2),
       key.sym_size(2)});
+
+  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
+  // Apply pad_bias and store the result in attn_mask
+  if (!aligned_last_dim) {
+    return pad_bias<mem_eff_alignment>(attn_mask);
+  }
+  // Check and make the tensor contiguous if needed
+  auto needs_contig = [](const c10::SymInt& stride) {
+    return (stride % 16 != 0) || (stride == 0);
+  };
+  if (needs_contig(attn_mask.sym_stride(0)) ||
+      needs_contig(attn_mask.sym_stride(1)) ||
+      needs_contig(attn_mask.sym_stride(2)) ||
+      needs_contig(attn_mask.sym_stride(3))) {
+    return attn_mask.contiguous();
+  }
+
+  return attn_mask;
}

} // namespace
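
For orientation, here is a rough Python sketch of what the restored preprocess_mask path does for the mem-efficient backend. The function name and standalone layout are illustrative only; what it mirrors from the C++ above is the expand-to-4D step, the pad-and-slice when the last dimension is not a multiple of 16, and the contiguous() fallback when any stride is unaligned or zero.

import torch
import torch.nn.functional as F

ALIGNMENT = 16  # mirrors mem_eff_alignment in the restored code


def preprocess_mask_sketch(mask, query, key):
    # Expand the bias to the full 4-D shape [B, H, M, N].
    attn_mask = mask.expand(
        query.size(0), query.size(1), query.size(2), key.size(2)
    )
    # If the last dim is not a multiple of 16, pad it up and slice back:
    # the shape is unchanged but the row stride becomes aligned.
    last = attn_mask.size(-1)
    if last % ALIGNMENT != 0:
        return F.pad(attn_mask, (0, ALIGNMENT - last % ALIGNMENT))[..., :last]
    # Otherwise require every stride to be a multiple of 16 and non-zero
    # (broadcast dims have stride 0); fall back to a contiguous copy if not.
    if any(s % ALIGNMENT != 0 or s == 0 for s in attn_mask.stride()):
        return attn_mask.contiguous()
    return attn_mask


q = torch.randn(8, 8, 16, 16)
k = torch.randn(8, 8, 15, 16)
bias = torch.randn(1, 1, 16, 15)        # last dim 15 -> takes the padding branch
out = preprocess_mask_sketch(bias, q, k)
print(out.shape, out.stride())          # torch.Size([8, 8, 16, 15]) (2048, 256, 16, 1)
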
16 changes: 5 additions & 11 deletions aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h

@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
"value is not correctly aligned (strideH)");
TORCH_CHECK(
p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
"query is not correctly aligned (strideB).");
"query is not correctly aligned (strideB)");
TORCH_CHECK(
p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
"key is not correctly aligned (strideB)");
@@ -1216,19 +1216,13 @@
    if (p.bias_ptr) {
      TORCH_CHECK(
          p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideB)");
      TORCH_CHECK(
          p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideH) ."
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideH)");
      TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
-          "multiple of ", kMinimumAlignment, ".");
+          p.bias_strideM % kMinimumAlignment == 0,
+          "attn_bias is not correctly aligned (strideM)");
    }
    if (p.grad_bias_ptr) {
      TORCH_CHECK(
14 changes: 4 additions & 10 deletions aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h

@@ -578,19 +578,13 @@ struct AttentionKernel {
      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
      TORCH_CHECK(
          p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideB)");
      TORCH_CHECK(
          p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideH). "
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideH)");
      TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
    }
    TORCH_CHECK(
        p.q_strideM % kAlignmentQ == 0,
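
In user-level terms, the checks above reject an attention bias whose batch, head, or row strides are not multiples of the kernel alignment. A hedged Python rendition follows; the helper function and the 16-element alignment are illustrative (16 is simply what the ATen-level padding in preprocess_mask targets, while the actual kAlignmentQ/kMinimumAlignment values depend on dtype and kernel configuration).

import torch
import torch.nn.functional as F

ALIGNMENT = 16  # illustrative; see mem_eff_alignment in attention.cpp above


def bias_strides_ok(bias, num_batches, num_heads):
    # Mirrors the shape of the TORCH_CHECKs: strideB/strideH only matter when
    # the corresponding dimension is larger than 1, strideM always does.
    stride_b, stride_h, stride_m, _ = bias.stride()
    return (
        (num_batches <= 1 or stride_b % ALIGNMENT == 0)
        and (num_heads <= 1 or stride_h % ALIGNMENT == 0)
        and stride_m % ALIGNMENT == 0
    )


bias = torch.randn(8, 8, 16, 15)
print(bias.stride(), bias_strides_ok(bias, 8, 8))      # (1920, 240, 15, 1) False
padded = F.pad(bias, (0, 1))[..., :15]                 # pad-and-slice trick
print(padded.stride(), bias_strides_ok(padded, 8, 8))  # (2048, 256, 16, 1) True
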
46 changes: 0 additions & 46 deletions test/inductor/test_torchinductor.py
@@ -6441,52 +6441,6 @@ def fn_or(x, y):
            (torch.randn(32), torch.randn(32)),
        )

-    @requires_cuda()
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FUSED_SDPA,
-        "Does not support mem_eff_attention",
-    )
-    @skipIfRocm
-    def test_sdpa_unaligned_mask(self):
-        def foo(
-            arg0_1: "f32[8, 8, 16, 16]",
-            arg1_1: "f32[8, 8, 15, 16]",
-            arg2_1: "f32[8, 8, 15, 16]",
-            arg3_1: "f32[1, 1, 16, 15]",
-        ):
-            constant_pad_nd: "f32[1, 1, 16, 16]" = (
-                torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
-            )
-            arg3_1 = None
-            slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
-                constant_pad_nd, -1, 0, 15
-            )
-            constant_pad_nd = None
-            expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
-                slice_1, [8, 8, 16, 15]
-            )
-            slice_1 = None
-            _scaled_dot_product_efficient_attention = (
-                torch.ops.aten._scaled_dot_product_efficient_attention.default(
-                    arg0_1, arg1_1, arg2_1, expand, False
-                )
-            )
-            arg0_1 = arg1_1 = arg2_1 = expand = None
-            getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
-            _scaled_dot_product_efficient_attention = None
-            return (getitem,)
-
-        query = torch.rand(8, 8, 16, 16, device="cuda")
-        key = torch.rand(8, 8, 15, 16, device="cuda")
-        value = torch.rand(8, 8, 15, 16, device="cuda")
-        bias = torch.rand(1, 1, 16, 15, device="cuda")
-        self.common(
-            foo,
-            (query, key, value, bias),
-            atol=0.02,
-            rtol=1e4,
-        )
-
    @skipIfRocm
    def test_conv_with_as_strided(self):
        class Model(nn.Module):
1 change: 0 additions & 1 deletion test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -259,7 +259,6 @@ def run(*ex, **kwargs):
"test_zero_dim_reductions_dynamic_shapes": TestFailure(
("cpu", "cuda"), is_skip=True
),
"test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
#
# The following tests do not support dynamic shapes yet:
#
19 changes: 1 addition & 18 deletions test/test_transformers.py
@@ -1760,6 +1760,7 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
        out = F.scaled_dot_product_attention(query, key, value, mask)
        out.sum().backward()

+
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
    @parametrize("type", ["dense", "nested"])
    @parametrize("is_contiguous", [True, False])
@@ -1800,24 +1801,6 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_

        self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)

-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
-    def test_mem_eff_attention_non_contig_mask_bug(self, device):
-        dtype = torch.float32
-        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
-        batch, num_heads, head_dim = 1, 16, 128
-        seq_len_q, seq_len_kv = 1, 16
-        query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
-        kv_shape = (batch, seq_len_kv, head_dim)
-        key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
-        key = key.expand(-1, num_heads, -1, -1)
-        value = value.expand(-1, num_heads, -1, -1)
-        mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-            out = F.scaled_dot_product_attention(query, key, value, mask)
-            out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
-        max_diff = (out - out_no_mask).abs().mean()
-        assert max_diff.item() < 1e-9
-
    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
    @parametrize("type", ["dense", "nested"])
    @parametrize("is_contiguous", [True, False])
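
As context for the removed regression test above: with batch 1, 16 heads, seq_len_q 1, and seq_len_kv 16, the mask is broadcast to shape (1, 16, 1, 16) inside SDPA, so it carries a zero stride along the head dimension. That is exactly the case the restored needs_contig check in preprocess_mask handles by falling back to .contiguous(). A small standalone illustration (a sketch, not the test itself):

import torch

# An expanded (broadcast) mask has stride 0 along the expanded dimension, so it
# is neither contiguous nor stride-aligned; the restored C++ path copies it.
mask = torch.ones(1, 1, 1, 16, dtype=torch.bool).expand(1, 16, 1, 16)
print(mask.stride())                 # (16, 0, 16, 1) -> zero stride in the head dim
print(mask.is_contiguous())          # False
print(mask.contiguous().stride())    # (256, 16, 16, 1)
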
89 changes: 4 additions & 85 deletions torch/_inductor/lowering.py
@@ -1866,91 +1866,10 @@ def apply_constraint(arg, fx_arg):
make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
make_fallback(aten.grid_sampler_2d_backward, require_dense)
make_fallback(aten.randperm)
-
-
-def sdpa_constraint(fx_node, *args, **kwargs):
-    # sdpa requires dense last dimension
-    def apply_constraint(arg, fx_arg):
-        if not isinstance(arg, ir.IRNode):
-            return arg
-
-        meta_val = fx_arg.meta["val"]
-        if not meta_val.is_cuda:
-            return arg
-
-        stride_order = ir.get_stride_order(meta_val.stride())
-        if stride_order and stride_order[-1] != 0:
-            # contiguous stride order
-            stride_order = list(reversed(range(len(arg.get_size()))))
-
-        # This is the minimum alignment required by SDPA kernels for attention_bias.
-        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
-        ALIGNMENT = 8
-
-        is_backward = fx_node.target in (
-            aten._scaled_dot_product_efficient_attention_backward.default,
-            aten._scaled_dot_product_flash_attention_backward.default,
-        )
-
-        def is_aligned(x):
-            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
-
-        assert isinstance(arg, TensorBox)
-
-        # This correctly handles the forward case:
-        if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
-            if not is_aligned(arg):
-                # input is padded, requiring_stride_order will unwrap the view and unpad.
-                # Would be nice to be able to require certain padding from inductor ir, nyi
-                if is_aligned(arg.unwrap_view()):
-                    return arg
-
-        def is_aligned_backward(x):
-            aligned_strides = all(
-                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
-                for i in range(len(x.get_stride()) - 1)
-            )
-            return (
-                V.graph.sizevars.size_hint(x.get_stride()[-1])
-            ) == 1 and aligned_strides
-
-        if (
-            isinstance(arg.data, ir.StorageBox)
-            and arg.data.is_input_buffer()
-            and is_backward
-        ):
-            if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
-                return arg
-
-        return ir.ExternKernel.require_stride_order(arg, stride_order)
-
-    args = tuple(
-        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
-    )
-    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
-    return args, kwargs
-
-
-make_fallback(
-    aten._scaled_dot_product_efficient_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_efficient_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
+make_fallback(aten._scaled_dot_product_efficient_attention)
+make_fallback(aten._scaled_dot_product_efficient_attention_backward)
+make_fallback(aten._scaled_dot_product_flash_attention)
+make_fallback(aten._scaled_dot_product_flash_attention_backward)
make_fallback(aten.sort)
make_fallback(aten.sort.stable)
make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
12 changes: 5 additions & 7 deletions torch/_meta_registrations.py
@@ -4985,14 +4985,12 @@ def meta__scaled_dot_product_efficient_backward(
    )
    grad_bias = None
    if attn_bias is not None and grad_input_mask[3]:
-        lastDim = attn_bias.size(-1)
-        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
-        new_sizes = list(attn_bias.size())
-        new_sizes[-1] = lastDimAligned
-        grad_bias = torch.empty(
-            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
+        grad_bias = torch.empty_strided(
+            attn_bias.size(),
+            attn_bias.stride(),
+            dtype=attn_bias.dtype,
+            device=attn_bias.device,
        )
-        grad_bias = grad_bias[..., :lastDim]

    return grad_q, grad_k, grad_v, grad_bias

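To make the meta-function change concrete, here is a hedged sketch contrasting the removed allocation (round the last dim up to a multiple of 16, allocate, then hand back a narrowed view) with the restored empty_strided allocation that simply matches attn_bias's own layout. The shapes are made up for illustration.

import torch

attn_bias = torch.randn(1, 1, 16, 15)

# Removed behaviour (illustrative): pad the last dimension up to a multiple of
# 16, allocate the padded buffer, then return a narrowed view of it.
last_dim = attn_bias.size(-1)                       # 15
last_dim_aligned = (
    last_dim if last_dim % 16 == 0 else last_dim + 16 - last_dim % 16
)                                                   # 16
padded = torch.empty(
    *attn_bias.shape[:-1], last_dim_aligned,
    dtype=attn_bias.dtype, device=attn_bias.device,
)
grad_bias_padded = padded[..., :last_dim]
print(grad_bias_padded.shape, grad_bias_padded.stride())  # torch.Size([1, 1, 16, 15]) (256, 256, 16, 1)

# Restored behaviour: grad_bias matches attn_bias's size and strides exactly.
grad_bias = torch.empty_strided(
    attn_bias.size(), attn_bias.stride(),
    dtype=attn_bias.dtype, device=attn_bias.device,
)
print(grad_bias.shape, grad_bias.stride())                 # torch.Size([1, 1, 16, 15]) (240, 240, 15, 1)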
