diff --git a/aten/src/ATen/native/transformers/attention.cpp b/aten/src/ATen/native/transformers/attention.cpp
index df40dec2d3d8..988d79fe2dba 100644
--- a/aten/src/ATen/native/transformers/attention.cpp
+++ b/aten/src/ATen/native/transformers/attention.cpp
@@ -590,14 +590,9 @@ c10::optional<Tensor> convert_boolean_attn_mask(const c10::optional<Tensor>& att
 
 // We apply this function to the top level SDPA so that
 // if padding is done it will be tracked for backward automatically
-template <int alignment>
-bool aligned_tensor(const at::Tensor& tensor){
-  for(const auto i : c10::irange(tensor.dim() - 1)){
-    if(tensor.sym_stride(i) % alignment != 0){
-      return false;
-    }
-  }
-  return tensor.sym_stride(-1) == 1;
+template <int alignment>
+bool is_aligned(const SymInt& size){
+  return size % alignment == 0;
 }
 
 template <int alignment>
@@ -613,16 +608,31 @@ at::Tensor preprocess_mask(
     const at::Tensor& query,
     const at::Tensor& key,
     const at::Tensor& value) {
-  constexpr int mem_eff_alignment = 8;
-  at::Tensor result_mask = mask;
-  if (!aligned_tensor<mem_eff_alignment>(mask)) {
-    result_mask = pad_bias<mem_eff_alignment>(mask);
-  }
-  return result_mask.expand_symint(
+  constexpr int mem_eff_alignment = 16;
+  // Expand to 4d case
+  at::Tensor attn_mask = mask.expand_symint(
       {query.sym_size(0),
        query.sym_size(1),
        query.sym_size(2),
        key.sym_size(2)});
+
+  bool aligned_last_dim = is_aligned<mem_eff_alignment>(attn_mask.sym_size(-1));
+  // Apply pad_bias and store the result in attn_mask
+  if (!aligned_last_dim) {
+    return pad_bias<mem_eff_alignment>(attn_mask);
+  }
+  // Check and make the tensor contiguous if needed
+  auto needs_contig = [](const c10::SymInt& stride) {
+    return (stride % 16 != 0) || (stride == 0);
+  };
+  if (needs_contig(attn_mask.sym_stride(0)) ||
+      needs_contig(attn_mask.sym_stride(1)) ||
+      needs_contig(attn_mask.sym_stride(2)) ||
+      needs_contig(attn_mask.sym_stride(3))) {
+    return attn_mask.contiguous();
+  }
+
+  return attn_mask;
 }
 
 } // namespace
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
index ec5a4a8a6ef5..0a5bb1db0433 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_backward.h
@@ -1197,7 +1197,7 @@ struct AttentionBackwardKernel {
         "value is not correctly aligned (strideH)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0,
-        "query is not correctly aligned (strideB).");
+        "query is not correctly aligned (strideB)");
     TORCH_CHECK(
         p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0,
         "key is not correctly aligned (strideB)");
@@ -1216,19 +1216,13 @@ struct AttentionBackwardKernel {
     if (p.bias_ptr) {
       TORCH_CHECK(
           p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride(0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideB)");
       TORCH_CHECK(
           p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideH) ."
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kMinimumAlignment, ".");
+          "attn_bias is not correctly aligned (strideH)");
       TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kMinimumAlignment == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a ",
-          "multiple of ", kMinimumAlignment, ".");
+          p.bias_strideM % kMinimumAlignment == 0,
+          "attn_bias is not correctly aligned (strideM)");
     }
     if (p.grad_bias_ptr) {
       TORCH_CHECK(
diff --git a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h
index 2e81480086d9..3a8189af09c4 100644
--- a/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h
+++ b/aten/src/ATen/native/transformers/cuda/mem_eff_attention/kernel_forward.h
@@ -578,19 +578,13 @@ struct AttentionKernel {
      CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ);
      TORCH_CHECK(
          p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideB). ",
-          "attn_bias.stride( 0) = ", p.bias_strideB, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideB)");
      TORCH_CHECK(
          p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideH). "
-          "attn_bias.stride(1) = ", p.bias_strideH, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          "attn_bias is not correctly aligned (strideH)");
      TORCH_CHECK(
-          p.num_queries <= 1 || p.bias_strideM % kAlignmentQ == 0,
-          "attn_bias is not correctly aligned (strideM). "
-          "attn_bias.stride(2) = ", p.bias_strideM, ", and should be a "
-          "multiple of ", kAlignmentQ, ".");
+          p.bias_strideM % kAlignmentQ == 0,
+          "attn_bias is not correctly aligned");
     }
     TORCH_CHECK(
         p.q_strideM % kAlignmentQ == 0,
diff --git a/test/inductor/test_torchinductor.py b/test/inductor/test_torchinductor.py
index f326a3d2b5c3..e6f1ea5a0561 100644
--- a/test/inductor/test_torchinductor.py
+++ b/test/inductor/test_torchinductor.py
@@ -6441,52 +6441,6 @@ def fn_or(x, y):
             (torch.randn(32), torch.randn(32)),
         )
 
-    @requires_cuda()
-    @unittest.skipIf(
-        not PLATFORM_SUPPORTS_FUSED_SDPA,
-        "Does not support mem_eff_attention",
-    )
-    @skipIfRocm
-    def test_sdpa_unaligned_mask(self):
-        def foo(
-            arg0_1: "f32[8, 8, 16, 16]",
-            arg1_1: "f32[8, 8, 15, 16]",
-            arg2_1: "f32[8, 8, 15, 16]",
-            arg3_1: "f32[1, 1, 16, 15]",
-        ):
-            constant_pad_nd: "f32[1, 1, 16, 16]" = (
-                torch.ops.aten.constant_pad_nd.default(arg3_1, [0, 1], 0.0)
-            )
-            arg3_1 = None
-            slice_1: "f32[1, 1, 16, 15]" = torch.ops.aten.slice.Tensor(
-                constant_pad_nd, -1, 0, 15
-            )
-            constant_pad_nd = None
-            expand: "f32[8, 8, 16, 15]" = torch.ops.aten.expand.default(
-                slice_1, [8, 8, 16, 15]
-            )
-            slice_1 = None
-            _scaled_dot_product_efficient_attention = (
-                torch.ops.aten._scaled_dot_product_efficient_attention.default(
-                    arg0_1, arg1_1, arg2_1, expand, False
-                )
-            )
-            arg0_1 = arg1_1 = arg2_1 = expand = None
-            getitem: "f32[8, 8, 16, 16]" = _scaled_dot_product_efficient_attention[0]
-            _scaled_dot_product_efficient_attention = None
-            return (getitem,)
-
-        query = torch.rand(8, 8, 16, 16, device="cuda")
-        key = torch.rand(8, 8, 15, 16, device="cuda")
-        value = torch.rand(8, 8, 15, 16, device="cuda")
-        bias = torch.rand(1, 1, 16, 15, device="cuda")
-        self.common(
-            foo,
-            (query, key, value, bias),
-            atol=0.02,
-            rtol=1e4,
-        )
-
     @skipIfRocm
     def test_conv_with_as_strided(self):
         class Model(nn.Module):
diff --git a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
index e6d3dd27b996..21ff02aced7c 100644
--- a/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
+++ b/test/inductor/test_torchinductor_codegen_dynamic_shapes.py
@@ -259,7 +259,6 @@ def run(*ex, **kwargs):
     "test_zero_dim_reductions_dynamic_shapes": TestFailure(
         ("cpu", "cuda"), is_skip=True
     ),
-    "test_sdpa_unaligned_mask_dynamic_shapes": TestFailure(("cpu",), is_skip=True),
     #
     # The following tests do not support dynamic shapes yet:
     #
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 18ed523558e2..1314755450fe 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -1760,6 +1760,7 @@ def test_mem_eff_attention_long_sequence_mask(self, device, dtype):
         out = F.scaled_dot_product_attention(query, key, value, mask)
         out.sum().backward()
 
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
@@ -1800,24 +1801,6 @@ def test_scaled_dot_product_attention_fused_kernels(self, device, type: str, is_
 
         self.assertEqual(actual[0].contiguous(), math_ref[0].contiguous(), atol=1e-3, rtol=1e-2)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
-    def test_mem_eff_attention_non_contig_mask_bug(self, device):
-        dtype = torch.float32
-        make_tensor = partial(torch.rand, device=device, dtype=dtype, requires_grad=True)
-        batch, num_heads, head_dim = 1, 16, 128
-        seq_len_q, seq_len_kv = 1, 16
-        query = make_tensor(batch, seq_len_q, num_heads * head_dim).view(batch, seq_len_q, num_heads, head_dim).transpose(1, 2)
-        kv_shape = (batch, seq_len_kv, head_dim)
-        key, value = make_tensor(kv_shape).unsqueeze(1), make_tensor(kv_shape).unsqueeze(1)
-        key = key.expand(-1, num_heads, -1, -1)
-        value = value.expand(-1, num_heads, -1, -1)
-        mask = torch.ones((1, 1, seq_len_q, seq_len_kv), device=device, dtype=torch.bool)
-        with sdp_kernel(**backend_map[SDPBackend.EFFICIENT_ATTENTION]):
-            out = F.scaled_dot_product_attention(query, key, value, mask)
-            out_no_mask = F.scaled_dot_product_attention(query, key, value, None)
-        max_diff = (out - out_no_mask).abs().mean()
-        assert max_diff.item() < 1e-9
-
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Fused SDPA was not built for this system")
     @parametrize("type", ["dense", "nested"])
     @parametrize("is_contiguous", [True, False])
diff --git a/torch/_inductor/lowering.py b/torch/_inductor/lowering.py
index a8d58cda0d7e..13f8a0dfc1a8 100644
--- a/torch/_inductor/lowering.py
+++ b/torch/_inductor/lowering.py
@@ -1866,91 +1866,10 @@ def apply_constraint(arg, fx_arg):
 make_fallback(aten._fused_moving_avg_obs_fq_helper_functional)
 make_fallback(aten.grid_sampler_2d_backward, require_dense)
 make_fallback(aten.randperm)
-
-
-def sdpa_constraint(fx_node, *args, **kwargs):
-    # sdpa requires dense last dimension
-    def apply_constraint(arg, fx_arg):
-        if not isinstance(arg, ir.IRNode):
-            return arg
-
-        meta_val = fx_arg.meta["val"]
-        if not meta_val.is_cuda:
-            return arg
-
-        stride_order = ir.get_stride_order(meta_val.stride())
-        if stride_order and stride_order[-1] != 0:
-            # contiguous stride order
-            stride_order = list(reversed(range(len(arg.get_size()))))
-
-        # This is the minimum alignment required by SDPA kernels for attention_bias.
-        # This value can be found in pytorch/aten/src/ATen/native/transformers/attention.cpp preprocess_mask
-        ALIGNMENT = 8
-
-        is_backward = fx_node.target in (
-            aten._scaled_dot_product_efficient_attention_backward.default,
-            aten._scaled_dot_product_flash_attention_backward.default,
-        )
-
-        def is_aligned(x):
-            return (V.graph.sizevars.size_hint(x.get_size()[-1]) % ALIGNMENT) == 0
-
-        assert isinstance(arg, TensorBox)
-
-        # This correctly handles the forward case:
-        if isinstance(arg.data, (ir.SliceView, ir.ExpandView)):
-            if not is_aligned(arg):
-                # input is padded, requiring_stride_order will unwrap the view and unpad.
-                # Would be nice to be able to require certain padding from inductor ir, nyi
-                if is_aligned(arg.unwrap_view()):
-                    return arg
-
-        def is_aligned_backward(x):
-            aligned_strides = all(
-                (V.graph.sizevars.size_hint(x.get_stride()[i]) % ALIGNMENT) == 0
-                for i in range(len(x.get_stride()) - 1)
-            )
-            return (
-                V.graph.sizevars.size_hint(x.get_stride()[-1])
-            ) == 1 and aligned_strides
-
-        if (
-            isinstance(arg.data, ir.StorageBox)
-            and arg.data.is_input_buffer()
-            and is_backward
-        ):
-            if len(arg.data.get_size()) == 4 and is_aligned_backward(arg):
-                return arg
-
-        return ir.ExternKernel.require_stride_order(arg, stride_order)
-
-    args = tuple(
-        apply_constraint(arg, fx_arg) for arg, fx_arg in zip(args, fx_node.args)
-    )
-    kwargs = {k: apply_constraint(v, fx_node.kwargs[k]) for k, v in kwargs.items()}
-    return args, kwargs
-
-
-make_fallback(
-    aten._scaled_dot_product_efficient_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_efficient_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention,
-    sdpa_constraint,
-    warn=False,
-)
-make_fallback(
-    aten._scaled_dot_product_flash_attention_backward,
-    sdpa_constraint,
-    warn=False,
-)
+make_fallback(aten._scaled_dot_product_efficient_attention)
+make_fallback(aten._scaled_dot_product_efficient_attention_backward)
+make_fallback(aten._scaled_dot_product_flash_attention)
+make_fallback(aten._scaled_dot_product_flash_attention_backward)
 make_fallback(aten.sort)
 make_fallback(aten.sort.stable)
 make_fallback(aten._sparse_coo_tensor_with_dims_and_tensors)
diff --git a/torch/_meta_registrations.py b/torch/_meta_registrations.py
index 0dee68f46c24..56593916a4a1 100644
--- a/torch/_meta_registrations.py
+++ b/torch/_meta_registrations.py
@@ -4985,14 +4985,12 @@ def meta__scaled_dot_product_efficient_backward(
     )
     grad_bias = None
     if attn_bias is not None and grad_input_mask[3]:
-        lastDim = attn_bias.size(-1)
-        lastDimAligned = lastDim if lastDim % 16 == 0 else lastDim + 16 - lastDim % 16
-        new_sizes = list(attn_bias.size())
-        new_sizes[-1] = lastDimAligned
-        grad_bias = torch.empty(
-            new_sizes, dtype=attn_bias.dtype, device=attn_bias.device
+        grad_bias = torch.empty_strided(
+            attn_bias.size(),
+            attn_bias.stride(),
+            dtype=attn_bias.dtype,
+            device=attn_bias.device,
         )
-        grad_bias = grad_bias[..., :lastDim]
 
     return grad_q, grad_k, grad_v, grad_bias
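
Usage sketch (not part of the patch): the snippet below is a minimal, hypothetical repro of the situation the change targets, assuming a CUDA build with the memory-efficient SDPA kernels available. The shapes mirror the removed inductor test: the attention bias has a last dimension of 15, which is not a multiple of the new 16-element alignment, so the C++ preprocess_mask pads it (or makes it contiguous when a stride is zero or unaligned) before the fused kernel runs, instead of relying on the removed inductor-side sdpa_constraint.

# Hypothetical example, not part of the diff above.
import torch
import torch.nn.functional as F

query = torch.rand(8, 8, 16, 16, device="cuda")
key = torch.rand(8, 8, 15, 16, device="cuda")
value = torch.rand(8, 8, 15, 16, device="cuda")
# Last dim 15 is unaligned; preprocess_mask pads/copies the bias as needed.
bias = torch.rand(1, 1, 16, 15, device="cuda")

# May dispatch to the memory-efficient backend when it is available.
out = F.scaled_dot_product_attention(query, key, value, attn_mask=bias)
print(out.shape)  # torch.Size([8, 8, 16, 16])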