[2.0.1] Disable SDPA FlashAttention backward and mem eff attention on sm86+ for head_dim above 64 (#99736)

* Disable SDPA FlashAttention backward and mem eff attention on sm86+ for head_dim above 64 (#99105)

Expand the sdp_utils.h checks to disable FlashAttention when autograd is required, and memory-efficient attention altogether, whenever both of the following hold (as sketched below, after the commit header):
- head_dim > 64
- sm86 or newer

Previously these kernels were only disabled on sm86, and only for head_dim equal to 128.

Pull Request resolved: #99105
Approved by: https://github.com/malfet

* Remove master-only test

---------

Co-authored-by: albanD <desmaison.alban@gmail.com>
cpuhrsch and albanD committed Apr 24, 2023
1 parent 9e8bd61 commit e9ebda2
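For context, a minimal sketch (not part of the commit) of the condition the new checks enforce. It uses only public PyTorch APIs; the helper name sdpa_head_dim_limited is ours, purely for illustration:

import torch

def sdpa_head_dim_limited(head_dim: int) -> bool:
    """Return True if this device falls under the new restriction:
    sm86 or newer with head_dim > 64, where memory-efficient attention
    is disabled and FlashAttention is forward-only (no autograd)."""
    if not torch.cuda.is_available():
        return False
    major, minor = torch.cuda.get_device_capability()
    is_sm86_or_newer = (major == 8 and minor >= 6) or major > 8
    return is_sm86_or_newer and head_dim > 64

# e.g. True on an RTX 3090 (sm86) or RTX 4090 (sm89) when head_dim == 128
print(sdpa_head_dim_limited(128))

On such devices, scaled_dot_product_attention is expected to fall back to the math backend for these shapes unless a fused backend is forced explicitly.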
Showing 2 changed files with 63 additions and 16 deletions.
34 changes: 26 additions & 8 deletions aten/src/ATen/native/transformers/cuda/sdp_utils.h
@@ -379,16 +379,33 @@ inline bool check_gpu_sm50_or_greater(sdp_params params, bool debug) {
return true;
}

inline bool check_gpu_sm86_head_dim_128(sdp_params params, bool debug) {
inline bool check_head_dim_gt64_and_sm_ge86(sdp_params params, bool debug) {
// Memory Efficient Attention is throwing a cuda illegal memory error
// on sm86 when head_dim is 128.
// on sm86 or newer when head_dim is greater than 64.
auto dprops = at::cuda::getCurrentDeviceProperties();
bool is_sm86 = (dprops->major == 8) && (dprops->minor == 6);
if (is_sm86 && (params.query.size(-1) == 128)) {
bool is_sm86_or_newer = (dprops->major == 8) && (dprops->minor >= 6);
// Categorically disable sm90 as well. Will want to fix this once we have H100s available for testing.
is_sm86_or_newer = is_sm86_or_newer || (dprops->major > 8);
if (is_sm86_or_newer && (params.query.sym_size(-1) > 64)) {
if (debug) {
TORCH_WARN(
"Memory Efficient Attention does not currently support head_dim == 128 on sm86",
"because it is throwing a cuda illegal memory error on sm86 when head_dim is 128.");
"Memory Efficient Attention does not currently support head_dim greater than 64 on sm86 or newer");
}
return false;
}
return true;
}

inline bool check_requires_grad_and_head_dim_gt64_and_sm_ge86(
sdp_params params,
bool debug) {
// Flash Attention will raise an error in the backward pass if the head_dim
// size is greater than 64 And the device is sm86 or newer.
if (!check_requires_grad(params, false) &&
!check_head_dim_gt64_and_sm_ge86(params, false)) {
if (debug) {
TORCH_WARN(
"Flash attention currently doesn't support training with head_dim greater than 64 on sm86 or newer.");
}
return false;
}
@@ -422,13 +439,14 @@ inline bool use_flash_attention(sdp_params params, bool debug) {
return false;
#endif
// Define gate functions that determine if a flash kernel can be ran
constexpr std::array<bool(*)(sdp_params, bool), 8> constraints {{
constexpr std::array<bool(*)(sdp_params, bool), 9> constraints {{
check_runtime_disabled_flash,
check_tensor_shapes,
check_equal_batch_size_and_num_heads,
check_for_attn_mask,
check_head_dim_size,
check_gpu_sm75_or_greater,
check_requires_grad_and_head_dim_gt64_and_sm_ge86,
check_for_nested_inputs,
check_for_seq_len_1_nested_tensor}};
for (auto& constraint : constraints) {
@@ -465,7 +483,7 @@ inline bool use_mem_efficient_attention(sdp_params params, bool debug) {
check_equal_batch_size_and_num_heads,
check_for_attn_mask,
check_head_dim_size_mem_efficient,
check_gpu_sm86_head_dim_128,
check_head_dim_gt64_and_sm_ge86,
check_for_seq_len_1_nested_tensor,
check_for_non_zero_dropout,
check_use_deterministic_algorithms}};
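Both use_flash_attention and use_mem_efficient_attention above follow the same pattern: a fixed array of predicate "gate" functions, any one of which can veto the fused kernel. A rough Python analogue of that gating logic, including the two new checks — the SdpParams type and the can_use_kernel helper are illustrative stand-ins, not the real PyTorch internals:

from dataclasses import dataclass
from typing import Callable, Sequence, Tuple

@dataclass
class SdpParams:
    head_dim: int          # query.size(-1)
    requires_grad: bool    # True if any input needs a backward pass
    sm: Tuple[int, int]    # CUDA compute capability, e.g. (8, 6)

def check_head_dim_gt64_and_sm_ge86(p: SdpParams, debug: bool = False) -> bool:
    # Mirrors the renamed check: memory-efficient attention is rejected on
    # sm86 and newer (sm90 included, for now) when head_dim exceeds 64.
    is_sm86_or_newer = (p.sm[0] == 8 and p.sm[1] >= 6) or p.sm[0] > 8
    if is_sm86_or_newer and p.head_dim > 64:
        if debug:
            print("head_dim > 64 is not supported on sm86 or newer")
        return False
    return True

def check_requires_grad_and_head_dim_gt64_and_sm_ge86(p: SdpParams, debug: bool = False) -> bool:
    # FlashAttention only loses its backward pass in this regime, so the gate
    # fails only when gradients are required as well.
    if p.requires_grad and not check_head_dim_gt64_and_sm_ge86(p, False):
        if debug:
            print("flash backward unsupported: head_dim > 64 on sm86 or newer")
        return False
    return True

def can_use_kernel(p: SdpParams,
                   gates: Sequence[Callable[[SdpParams, bool], bool]],
                   debug: bool = False) -> bool:
    # Every gate must pass; the first failing gate vetoes the kernel.
    return all(gate(p, debug) for gate in gates)

p = SdpParams(head_dim=128, requires_grad=True, sm=(8, 6))
print(can_use_kernel(p, [check_requires_grad_and_head_dim_gt64_and_sm_ge86]))  # False
print(can_use_kernel(p, [check_head_dim_gt64_and_sm_ge86]))                    # False

When a fused kernel is vetoed this way, scaled_dot_product_attention falls back to the math backend; if the math backend has also been disabled via sdp_kernel, the call raises a RuntimeError, which is exactly what the tests below assert.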
45 changes: 37 additions & 8 deletions test/test_transformers.py
@@ -56,7 +56,7 @@ def use_deterministic_algorithims(mode: bool, warn_only: bool):
default_atol = {torch.float16: 1e-3, torch.bfloat16: 1e-3, torch.float32: 1e-5}
default_rtol = {torch.float16: 1e-3, torch.bfloat16: 1.6e-2, torch.float32: 1.3e-6}

isSM86Device = torch.cuda.is_available() and torch.cuda.get_device_capability() == (8, 6)
isSM86or89Device = torch.cuda.is_available() and torch.cuda.get_device_capability() in [(8, 6), (8, 9)]


def get_rtol(true_value: torch.Tensor, computed_value: torch.Tensor) -> float:
@@ -1645,18 +1645,47 @@ def test_sdp_choice_with_determinism(self, warn_only):
assert torch._fused_sdp_choice(query, key, value) == (
SDPBackend.EFFICIENT_ATTENTION if warn_only else SDPBackend.MATH)

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86Device, "CUDA unavailable")
def test_memory_efficeint_sm86_failure(self):
@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86or89Device,
"Does not support fused SDPA or not SM86+ hardware")
@parametrize("head_dim", [72, 96, 128])
def test_memory_efficient_sm86_plus_failure(self, head_dim: int):
device = 'cuda'
dtype = torch.float16
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
# See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, 128)
# See check_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, head_dim)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA or not isSM86or89Device,
"Does not support fused SDPA or not SM86+ hardware")
@parametrize("head_dim", [72, 96, 128])
def test_flash_backward_failure_sm86plus(self, head_dim: int):
device = 'cuda'
dtype = torch.float16
make_tensor = partial(self.rand_tensor, type="dense", device=device, dtype=dtype)
# See check_requires_grad_and_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
size = (2, 2, 4, head_dim)
q, k, v = make_tensor(size), make_tensor(size), make_tensor(size)

with sdp_kernel(enable_mem_efficient=False, enable_flash=False, enable_math=True):
math_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

with sdp_kernel(enable_mem_efficient=False, enable_flash=True, enable_math=False):
# Should not fail because inputs don't require grad
flash_ref = torch.nn.functional.scaled_dot_product_attention(q, k, v, None, 0.0, False)

self.assertEqual(math_ref, flash_ref, atol=1e-3, rtol=1e-3)

# Should fail because inputs require grad
q = make_tensor(size, requires_grad=True)
k = make_tensor(size, requires_grad=True)
v = make_tensor(size, requires_grad=True)
self.assertRaises(RuntimeError, lambda: torch.nn.functional.scaled_dot_product_attention(
q, k, v, None, 0.0, False))

@unittest.skipIf(not PLATFORM_SUPPORTS_FUSED_SDPA, "Does not support fused scaled dot product attention")
def test_dispatch_fails_no_backend(self):
dtype = torch.float16
@@ -1827,7 +1856,7 @@ def func():
@parametrize("batch_size", [1, 8])
@parametrize("seq_len_q", [4, 8, 64, 128, 256, 512, 1024, 2048])
@parametrize("seq_len_k", [4, 8, 64, 128, 256, 512, 1024, 2048])
@parametrize("head_dim", [8, 16, 32, 64, 128])
@parametrize("head_dim", [8, 16, 32, 64, 72, 96, 128])
@parametrize("is_causal", [True, False])
@parametrize("dropout_p", [0.0]) # mem_efficient_attention does not support dropout
@parametrize("dtype", [torch.float16, torch.bfloat16, torch.float32])
@@ -1854,8 +1883,8 @@ def test_mem_efficient_attention_vs_math_ref_grads(self, batch_size: int, seq_le

# Create real output
with sdp_kernel(enable_mem_efficient=True, enable_flash=False, enable_math=False):
# See check_gpu_sm86_head_dim_128 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
if isSM86Device and head_dim == 128:
# See check_head_dim_gt64_and_sm_ge86 in pytorch/aten/src/ATen/native/transformers/cuda/sdp_utils.h
if isSM86or89Device and head_dim in range(65, 129):
self.assertRaises(RuntimeError, lambda: F.scaled_dot_product_attention(query, key, value,
dropout_p=dropout_p, is_causal=is_causal))
return
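The tests above force individual backends with the sdp_kernel context manager. A standalone sketch of that pattern on an affected device (assuming a CUDA build of PyTorch 2.0.x running on sm86/sm89 hardware); the try/except only demonstrates the expected failure mode:

import torch
import torch.nn.functional as F

device, dtype = "cuda", torch.float16
q = torch.randn(2, 2, 4, 128, device=device, dtype=dtype)
k = torch.randn(2, 2, 4, 128, device=device, dtype=dtype)
v = torch.randn(2, 2, 4, 128, device=device, dtype=dtype)

# Forcing the memory-efficient backend with head_dim=128 on sm86+ should now
# raise a RuntimeError instead of hitting an illegal memory access.
try:
    with torch.backends.cuda.sdp_kernel(
        enable_flash=False, enable_math=False, enable_mem_efficient=True
    ):
        F.scaled_dot_product_attention(q, k, v, None, 0.0, False)
except RuntimeError as err:
    print(f"mem-efficient backend rejected: {err}")

# The math backend remains available as a fallback for these shapes.
with torch.backends.cuda.sdp_kernel(
    enable_flash=False, enable_math=True, enable_mem_efficient=False
):
    out = F.scaled_dot_product_attention(q, k, v, None, 0.0, False)
print(out.shape)  # torch.Size([2, 2, 4, 128])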
