diff --git a/aten/src/ATen/native/transformers/cuda/sdp_utils.h b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
index 14ea9875c79b..410c314e1397 100644
--- a/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+++ b/aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -379,16 +379,33 @@ inline bool check_gpu_sm50_or_greater(sdp_params params, bool debug) {
   return true;
 }
 
-inline bool check_gpu_sm86_head_dim_128(sdp_params params, bool debug) {
+inline bool check_head_dim_gt64_and_sm_ge86(sdp_params params, bool debug) {
   // Memory Efficient Attention is throwing a cuda illegal memory error
-  // on sm86 when head_dim is 128.
+  // on sm86 or newer when head_dim is greater than 64.
   auto dprops = at::cuda::getCurrentDeviceProperties();
-  bool is_sm86 = (dprops->major == 8) && (dprops->minor == 6);
-  if (is_sm86 && (params.query.size(-1) == 128)) {
+  bool is_sm86_or_newer = (dprops->major == 8) && (dprops->minor >= 6);
+  // Categorically disable sm90 as well. Will want to fix this once we have H100s available for testing.
+  is_sm86_or_newer = is_sm86_or_newer || (dprops->major > 8);
+  if (is_sm86_or_newer && (params.query.sym_size(-1) > 64)) {
     if (debug) {
       TORCH_WARN(
-          "Memory Efficient Attention does not currently support head_dim == 128 on sm86",
-          "because it is throwing a cuda illegal memory error on sm86 when head_dim is 128.");
+          "Memory Efficient Attention does not currently support head_dim greater than 64 on sm86 or newer");
+    }
+    return false;
+  }
+  return true;
+}
+
+inline bool check_requires_grad_and_head_dim_gt64_and_sm_ge86(
+    sdp_params params,
+    bool debug) {
+  // Flash Attention will raise an error in the backward pass if the head_dim
+  // size is greater than 64 and the device is sm86 or newer.
+  if (!check_requires_grad(params, false) &&
+      !check_head_dim_gt64_and_sm_ge86(params, false)) {
+    if (debug) {
+      TORCH_WARN(
+          "Flash attention currently doesn't support training with head_dim greater than 64 on sm86 or newer.");
     }
     return false;
   }
@@ -422,13 +439,14 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
   return false;
 #endif
   // Define gate functions that determine if a flash kernel can be ran
-  constexpr std::array<bool (*)(sdp_params, bool), 8> constraints {{
+  constexpr std::array<bool (*)(sdp_params, bool), 9> constraints {{
       check_runtime_disabled_flash,
       check_tensor_shapes,
       check_equal_batch_size_and_num_heads,
       check_for_attn_mask,
       check_head_dim_size,
       check_gpu_sm75_or_greater,
+      check_requires_grad_and_head_dim_gt64_and_sm_ge86,
       check_for_nested_inputs,
       check_for_seq_len_1_nested_tensor}};
   for (auto& constraint : constraints) {
@@ -465,7 +483,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
       check_equal_batch_size_and_num_heads,
       check_for_attn_mask,
       check_head_dim_size_mem_efficient,
-      check_gpu_sm86_head_dim_128,
+      check_head_dim_gt64_and_sm_ge86,
       check_for_seq_len_1_nested_tensor,
       check_for_non_zero_dropout,
       check_use_deterministic_algorithms}};
diff --git a/test/test_transformers.py b/test/test_transformers.py
index 28416b4fde00..6e2207c93540 100644
--- a/test/test_transformers.py
+++ b/test/test_transformers.py
@@ -56,7 +56,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
 
 default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
 default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}
 
-isSM86Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 6)
+isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]
 
 def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
@@ -1645,18 +1645,47 @@ def test_sdp_choice_with_determinism(self, warn_only):
         assert torch._fused_sdp_choice(query, key, value) == (
             SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)
 
-    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
-    def test_memory_efficeint_sm86_failure(self):
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86or89Device,
+                     "Does not support fused SDPA or not SM86+ hardware")
+    @parametrize("head_dim", [72, 96, 128])
+    def test_memory_efficient_sm86_plus_failure(self, head_dim: int):
         device = 'cuda'
         dtype = torch.float16
         make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
-        # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
-        size = (2, 2, 4, 128)
+        # See check_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+        size = (2, 2, 4, head_dim)
         q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
         with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
             self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
                 q, k, v, None, 0.0, False))
 
+    @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86or89Device,
+                     "Does not support fused SDPA or not SM86+ hardware")
+    @parametrize("head_dim", [72, 96, 128])
+    def test_flash_backward_failure_sm86plus(self, head_dim: int):
+        device = 'cuda'
+        dtype = torch.float16
+        make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
+        # See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+        size = (2, 2, 4, head_dim)
+        q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
+
+        with sdp_kernel(enable_mem_efficient=False, enable_flash=False, enable_math=True):
+            math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
+
+        with sdp_kernel(enable_mem_efficient=False, enable_flash=True, enable_math=False):
+            # Should not fail because inputs don't require grad
+            flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)
+
+            self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)
+
+            # Should fail because inputs require grad
+            q = make_tensor(size, requires_grad=True)
+            k = make_tensor(size, requires_grad=True)
+            v = make_tensor(size, requires_grad=True)
+            self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
+                q, k, v, None, 0.0, False))
+
     @unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
     def test_dispatch_fails_no_backend(self):
         dtype = torch.float16
@@ -1827,7 +1856,7 @@ def func():
     @parametrize("batch_size", [1, 8])
     @parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048])
     @parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048])
-    @parametrize("head_dim", [8, 16, 32, 64, 128])
+    @parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128])
     @parametrize("is_causal", [True, False])
     @parametrize("dropout_p", [0.0])  # mem_efficient_attention does not support dropout
     @parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@@ -1854,8 +1883,8 @@ def test_mem_efficient_attention_vs_math_ref_grads(self, batch_size: int, seq_le
 
         # Create real output
         with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
-            # See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
-            if isSM86Device and head_dim == 128:
+            # See check_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
+            if isSM86or89Device and head_dim in range(65, 129):
                 self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(
                     query, key, value, dropout_p=dropout_p, is_causal=is_causal))
                 return
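
The sketch below is illustrative only and not part of this patch. It assumes an sm86 or sm89 device (e.g. an RTX 30xx/40xx card) and a build with the fused SDPA kernels compiled in, and shows the dispatch behavior the new gates enforce: with head_dim greater than 64, memory-efficient attention is rejected outright and flash attention is rejected once the inputs require grad, so a caller that pins the fused kernels has to fall back to the math kernel.

import torch
import torch.nn.functional as F
from torch.backends.cuda import sdp_kernel

# Illustrative only: head_dim = 128 > 64 is the case the new checks reject
# on sm86/sm89 when the inputs require grad.
shape = (2, 2, 4, 128)
q = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
k = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)
v = torch.randn(shape, device="cuda", dtype=torch.float16, requires_grad=True)

try:
    # Only the fused kernels enabled: mem-efficient is gated out by
    # check_head_dim_gt64_and_sm_ge86, flash by
    # check_requires_grad_and_head_dim_gt64_and_sm_ge86, so no kernel is
    # available and scaled_dot_product_attention raises a RuntimeError.
    with sdp_kernel(enable_flash=True, enable_mem_efficient=True, enable_math=False):
        out = F.scaled_dot_product_attention(q, k, v)
except RuntimeError:
    # Fall back to the math kernel, which has no head_dim restriction.
    with sdp_kernel(enable_flash=False, enable_mem_efficient=False, enable_math=True):
        out = F.scaled_dot_product_attention(q, k, v)

On devices below sm86, or with head_dim <= 64, the first branch succeeds and the fallback never runs.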