Merge branch 'main' into release/v0.4.0
dakinggg committed Nov 21, 2023
2 parents 11df92c + 1793c36 · commit 4638092
Showing 2 changed files with 11 additions and 4 deletions.
8 changes: 4 additions & 4 deletions llmfoundry/models/layers/attention.py
@@ -296,11 +296,11 @@ def flash_attn_fn(
         # we use .view to modify {key, value}_unpad appropriately

         key_unpad = repeat_kv_for_gqa(
-            key_unpad.view(batch_size, seqlen, kv_n_heads, -1),
-            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
+            key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
+            n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
         value_unpad = repeat_kv_for_gqa(
-            value_unpad.view(batch_size, seqlen, kv_n_heads, -1),
-            n_heads // kv_n_heads).view(batch_size * seqlen, n_heads, -1)
+            value_unpad.view(1, value_unpad.size(0), kv_n_heads, -1),
+            n_heads // kv_n_heads).view(value_unpad.size(0), n_heads, -1)

         dropout_p = dropout_p if training else 0.0

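The attention.py change fixes the grouped-query-attention (GQA) path of flash_attn_fn for unpadded inputs: once padding has been stripped, key_unpad and value_unpad hold one row per real token, so their first dimension is the total unpadded token count rather than batch_size * seqlen. Viewing them as (1, key_unpad.size(0), ...) before repeat_kv_for_gqa makes the head repetition independent of how much padding was removed. A minimal sketch of the idea follows; the helper only mimics the repeat-kv behaviour the call site expects, and the name repeat_kv_for_gqa_sketch and the shapes are illustrative, not the repository's code.

import torch

def repeat_kv_for_gqa_sketch(hidden: torch.Tensor, n_rep: int) -> torch.Tensor:
    # Repeat each kv head n_rep times along the head dimension,
    # mimicking the behaviour expected of llm-foundry's repeat_kv_for_gqa.
    if n_rep == 1:
        return hidden
    b, s, kv_n_heads, d = hidden.shape
    hidden = hidden[:, :, :, None, :].expand(b, s, kv_n_heads, n_rep, d)
    return hidden.reshape(b, s, kv_n_heads * n_rep, d)

batch_size, seqlen, n_heads, kv_n_heads, head_dim = 2, 6, 8, 2, 4

# After unpadding, key_unpad holds one row per real (non-padded) token,
# e.g. sequences of length 6 and 4 give 10 rows, not batch_size * seqlen = 12.
total_nnz = 10
key_unpad = torch.randn(total_nnz, kv_n_heads, head_dim)

# Old reshape: only valid when there is no padding (total_nnz == batch_size * seqlen).
try:
    key_unpad.view(batch_size, seqlen, kv_n_heads, -1)
except RuntimeError as err:
    print('old view fails with padding:', err)

# New reshape: treat all unpadded tokens as a single batch of size total_nnz.
key_rep = repeat_kv_for_gqa_sketch(
    key_unpad.view(1, key_unpad.size(0), kv_n_heads, -1),
    n_heads // kv_n_heads).view(key_unpad.size(0), n_heads, -1)
print(key_rep.shape)  # torch.Size([10, 8, 4])

When no padding is removed the two reshapes coincide, which is why the bug only surfaced on padded batches.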
7 changes: 7 additions & 0 deletions tests/test_flash_triton_torch.py
@@ -54,12 +54,14 @@ def allclose_helper(t0: torch.Tensor,
 @pytest.mark.parametrize(
     'attn_type',
     ['multihead_attention', 'multiquery_attention', 'grouped_query_attention'])
+@pytest.mark.parametrize('pad_attention_mask', [True, False])
 def test_attn_impl(attn_impl_0: str,
                    attn_impl_1: str,
                    clip_qkv: bool,
                    qk_ln: bool,
                    pos_emb_config: dict,
                    attn_type: str,
+                   pad_attention_mask: bool,
                    device: str = 'cuda'):
     """Compare all attn impl with each other.
@@ -98,6 +100,11 @@ def test_attn_impl(attn_impl_0: str,

     attention_mask = torch.ones(n, s).to(device).bool()

+    if pad_attention_mask:
+        # zero out the first third of the attention mask
+        # to simulate padding
+        attention_mask[:, :s // 3] = 0
+
     def gen_bias(attn_impl: str):
         causal = True
         attn_bias = None
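The test change adds a pad_attention_mask parameter so the comparison between attention implementations also runs with a partially zeroed mask, which is exactly the case that triggers the unpadded-shape issue fixed above. A rough, self-contained illustration of why padding changes the shapes involved; the boolean-indexing line is only a stand-in for flash-attn's bert_padding.unpad_input, and the sizes are made up for the example.

import torch

n, s = 2, 6                       # illustrative batch size and sequence length
kv_n_heads, head_dim = 2, 4

attention_mask = torch.ones(n, s).bool()
attention_mask[:, :s // 3] = 0    # mask out the first third, as in the test

key = torch.randn(n, s, kv_n_heads, head_dim)

# Stand-in for flash-attn's bert_padding.unpad_input: keep only tokens whose
# mask entry is True, collapsing (batch, seq) into one token dimension.
key_unpad = key[attention_mask]   # shape: (total_nnz, kv_n_heads, head_dim)

print(key_unpad.shape[0], n * s)  # 8 vs 12: fewer rows than batch_size * seqlen

With part of each sequence masked out, the unpadded tensors have fewer rows than batch_size * seqlen, which is the situation the new views in attention.py handle.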
