support ut of flash attention for more cases
runzhech committed Apr 29, 2024
1 parent b0e2ec6 commit 875bba2
Showing 4 changed files with 261 additions and 51 deletions.
47 changes: 24 additions & 23 deletions paddle/phi/kernels/xpu/flash_attn_grad_kernel.cc
@@ -91,7 +91,8 @@ void FlashAttnGradKernel(const Context& ctx,
kvlod_vec.data(), static_cast<int64_t>(kvlod_vec.size()), nullptr};

// get seed offset
- const int64_t* seed_offset_data = seed_offset.data<int64_t>();
+ const int32_t* seed_offset_data = seed_offset.data<int32_t>();

// template<typename T, typename TACCUM, typename TGEMM, typename TID = int>
// int mha_varlen_bwd(xdnn::Context* ctx, const T* dout, const T* q, const T*
// k, const T* v, const T* out, const TACCUM* softmax_lse, T* dq, T* dk, T*
@@ -106,28 +107,28 @@ void FlashAttnGradKernel(const Context& ctx,
// dv_maxptr = nullptr, const float* do_maxptr = nullptr);
int r = baidu::xpu::xfa::mha_varlen_bwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
- dout_data,                    // dout
- q_data,                       // q
- k_data,                       // k
- v_data,                       // v
- out_data,                     // out
- softmax_lse_data,             // softmax_lse
- dq_data,                      // dq
- dk_data,                      // dk
- dv_data,                      // dv
- qlod,                         // lod_seqlens_q
- kvlod,                        // lod_seqlens_k
- seqlen_q,                     // max_seqlen_q
- seqlen_k,                     // max_seqlen_k
- num_heads,                    // head_num
- num_heads_k,                  // head_num_k
- head_size,                    // head_dim
- 1.0f / std::sqrt(head_size),  // softmax_scale
- dropout,                      // p_dropout
- static_cast<uint64_t>(seed_offset_data[0]),  // seed
- causal,                       // is_causal
- nullptr,                      // attn_mask
- bias_data                     // bias
+ dout_data,                    // dout
+ q_data,                       // q
+ k_data,                       // k
+ v_data,                       // v
+ out_data,                     // out
+ softmax_lse_data,             // softmax_lse
+ dq_data,                      // dq
+ dk_data,                      // dk
+ dv_data,                      // dv
+ qlod,                         // lod_seqlens_q
+ kvlod,                        // lod_seqlens_k
+ seqlen_q,                     // max_seqlen_q
+ seqlen_k,                     // max_seqlen_k
+ num_heads,                    // head_num
+ num_heads_k,                  // head_num_k
+ head_size,                    // head_dim
+ 1.0f / std::sqrt(head_size),  // softmax_scale
+ dropout,                      // p_dropout
+ seed_offset_data[0],          // seed
+ causal,                       // is_causal
+ nullptr,                      // attn_mask
+ bias_data                     // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_bwd");
#else
61 changes: 34 additions & 27 deletions paddle/phi/kernels/xpu/flash_attn_kernel.cc
@@ -14,7 +14,7 @@

#include "paddle/phi/kernels/flash_attn_kernel.h"
#include "paddle/phi/backends/xpu/enforce_xpu.h"
#include "paddle/phi/core/enforce.h"
#include "paddle/phi/common/memory_utils.h"
#include "paddle/phi/core/kernel_registry.h"

#ifdef PADDLE_WITH_XPU_XHPC
@@ -237,12 +237,20 @@ void FlashAttnKernel(const Context& ctx,

// generate seed offset
seed_offset->Resize({2});
- int64_t* seed_offset_data = ctx.template HostAlloc<int64_t>(seed_offset);
+ int32_t* seed_offset_data = ctx.template HostAlloc<int32_t>(seed_offset);
if (fixed_seed_offset.get_ptr()) {
- const int64_t* fixed_seed_offset_data =
-     fixed_seed_offset.get_ptr()->data<int64_t>();
- seed_offset_data[0] = fixed_seed_offset_data[0];
- seed_offset_data[1] = fixed_seed_offset_data[1];
+ if ((fixed_seed_offset->place()).GetType() == phi::AllocationType::XPU) {
+   memory_utils::Copy(phi::CPUPlace(),
+                      seed_offset_data,
+                      fixed_seed_offset->place(),
+                      fixed_seed_offset->data<int32_t>(),
+                      sizeof(int32_t) * 2);
+ } else {
+   const int32_t* fixed_seed_offset_data =
+       fixed_seed_offset->data<int32_t>();
+   seed_offset_data[0] = fixed_seed_offset_data[0];
+   seed_offset_data[1] = fixed_seed_offset_data[1];
+ }
} else {
std::pair<uint64_t, uint64_t> seed_offset_pair;
uint64_t inc = batch_size * num_heads * 32;
@@ -253,8 +261,8 @@
auto* gen = ctx.GetGenerator();
seed_offset_pair = gen->IncrementOffset(inc);
}
- seed_offset_data[0] = static_cast<int64_t>(seed_offset_pair.first);
- seed_offset_data[1] = static_cast<int64_t>(seed_offset_pair.second);
+ seed_offset_data[0] = static_cast<int32_t>(seed_offset_pair.first);
+ seed_offset_data[1] = static_cast<int32_t>(seed_offset_pair.second);
}

// raw pointers
@@ -264,7 +272,6 @@ void FlashAttnKernel(const Context& ctx,
const XPUType* v_data = reinterpret_cast<const XPUType*>(v.data<T>());
XPUType* out_data = reinterpret_cast<XPUType*>(out->data<T>());
float* softmax_lse_data = softmax_lse->data<float>();

const float* bias_data = nullptr;
if (attn_mask.get_ptr() != nullptr) {
bias_data = attn_mask->data<float>();
@@ -281,24 +288,24 @@
// nullptr);
int r = baidu::xpu::xfa::mha_varlen_fwd<XPUType, float, tfloat32, int>(
ctx.x_context(),
- q_data,                       // q
- k_data,                       // k
- v_data,                       // v
- out_data,                     // out
- softmax_lse_data,             // softmax_lse
- qlod,                         // lod_seqlens_q
- kvlod,                        // lod_seqlens_k
- seqlen_q,                     // max_seqlen_q
- seqlen_k,                     // max_seqlen_k
- num_heads,                    // head_num
- num_heads_k,                  // head_num_k
- head_size,                    // head_dim
- 1.0f / std::sqrt(head_size),  // softmax_scale
- dropout,                      // p_dropout
- static_cast<uint64_t>(seed_offset_data[0]),  // seed
- causal,                       // is_causal
- nullptr,                      // attn_mask
- bias_data                     // bias
+ q_data,                       // q
+ k_data,                       // k
+ v_data,                       // v
+ out_data,                     // out
+ softmax_lse_data,             // softmax_lse
+ qlod,                         // lod_seqlens_q
+ kvlod,                        // lod_seqlens_k
+ seqlen_q,                     // max_seqlen_q
+ seqlen_k,                     // max_seqlen_k
+ num_heads,                    // head_num
+ num_heads_k,                  // head_num_k
+ head_size,                    // head_dim
+ 1.0f / std::sqrt(head_size),  // softmax_scale
+ dropout,                      // p_dropout
+ seed_offset_data[0],          // seed
+ causal,                       // is_causal
+ nullptr,                      // attn_mask
+ bias_data                     // bias
);
PADDLE_ENFORCE_XDNN_SUCCESS(r, "mha_varlen_fwd");
#else
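
Note: the forward kernel above keeps the {seed, offset} pair as two host-side int32 values and, when the caller's fixed_seed_offset tensor is XPU-resident, copies it to host memory first. The sketch below shows how such a pair can be supplied from Python, mirroring the new unit test further down. It is illustrative only and not part of this commit; it assumes an XPU3 device with XHPC, that paddle.nn.functional.flash_attention.flash_attention is importable in this build and accepts a fixed_seed_offset keyword, and that dropout is supported on this path (the exact signature may differ across Paddle versions).

# Hedged sketch (not part of this commit): pass a fixed [seed, offset] pair,
# stored as int32 as the XPU kernel above now reads it, so that repeated runs
# reuse the same dropout randomness.
import numpy as np
import paddle
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_device("xpu")
q = paddle.randn([1, 128, 2, 32]).astype("float16")
k = paddle.randn([1, 128, 2, 32]).astype("float16")
v = paddle.randn([1, 128, 2, 32]).astype("float16")

# Two int32 elements: [seed, offset]. Created on the current device (XPU), so
# it exercises the new device-to-host copy branch in the forward kernel.
seed_offset = paddle.to_tensor(np.array([12345, 0], dtype=np.int32))

# dropout > 0 so the seed actually matters (dropout support here is assumed).
out1, _ = flash_attention(q, k, v, dropout=0.1, causal=True,
                          fixed_seed_offset=seed_offset)
out2, _ = flash_attention(q, k, v, dropout=0.1, causal=True,
                          fixed_seed_offset=seed_offset)
print(float(paddle.abs(out1 - out2).max()))  # expected to print 0.0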
6 changes: 5 additions & 1 deletion test/xpu/op_test_xpu.py
@@ -192,7 +192,11 @@ def check_grad_with_place(
if not core.is_bfloat16_supported(place):
return

- if self.dtype == np.float16 or self.dtype == np.uint16:
+ if (
+     self.dtype == np.float16
+     or self.dtype == np.uint16
+     or user_defined_grads is not None
+ ):
max_relative_error = 0.1
return super().check_grad_with_place(
place,
198 changes: 198 additions & 0 deletions test/xpu/test_flash_attention_v2_op_xpu.py
@@ -0,0 +1,198 @@
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import random
import unittest

import numpy as np
from get_test_cover_info import (
XPUOpTestWrapper,
create_test_class,
get_xpu_op_support_types,
)
from op_test import convert_float_to_uint16
from op_test_xpu import XPUOpTest

import paddle
import paddle.nn.functional as F
from paddle.base import core


def get_triangle_upper_mask(x):
mask = paddle.full_like(x, -1e4)
mask.stop_gradient = True
mask = paddle.triu(mask, diagonal=1)
mask.stop_gradient = True
return mask.astype(np.float32)


def attention_naive(q, k, v, bias, is_causal=True):
origin_dtype = q.dtype
assert k.dtype == origin_dtype
assert v.dtype == origin_dtype
if q.dtype != paddle.float32:
q = paddle.cast(q, "float32")
k = paddle.cast(k, "float32")
v = paddle.cast(v, "float32")
# real calculation
qt = paddle.transpose(q, [0, 2, 1, 3])
kt = paddle.transpose(k, [0, 2, 1, 3])
vt = paddle.transpose(v, [0, 2, 1, 3])
scale = 1.0 / np.sqrt(q.shape[-1])
s = paddle.matmul(qt, paddle.transpose(kt, [0, 1, 3, 2]))
s = paddle.scale(s, scale)
if bias is not None:
s = s + bias
if is_causal:
mask = get_triangle_upper_mask(s)
s = s + mask
softmax_lse = paddle.logsumexp(s, axis=3)
p = F.softmax(s)
o = paddle.matmul(p, vt)
o = paddle.cast(o, np.float32)
o = paddle.transpose(o, [0, 2, 1, 3])
return o, softmax_lse


def is_flashattn_supported():
xpu_version = core.get_xpu_device_version(0)
if xpu_version != core.XPUVersion.XPU3:
return False
xhpc_version = paddle.version.xpu_xhpc()
if xhpc_version == 'False':
return False
return True


class XPUTestFlashAttentionOp(XPUOpTestWrapper):
def __init__(self):
self.op_name = "flash_attn"
self.use_dynamic_create_class = False

@unittest.skipIf(
not is_flashattn_supported(), "only available on XPU3 with XHPC"
)
class TestFlashAttentionOp(XPUOpTest):
def setUp(self):
self.op_type = "flash_attn"
self.tmp_seed = random.randint(1, 65536)
paddle.seed(self.tmp_seed)
self.init_dtype()
self.set_attrs()
self.set_shape()
self.init_data()
self.tolerance = 5e-4

def init_dtype(self):
self.dtype = self.in_type

def set_shape(self):
# [b, l, h, d]
self.q_shape = [1, 128, 2, 32]
self.k_shape = [1, 128, 2, 32]
self.v_shape = [1, 128, 2, 32]
# [b, h, l, l]
self.bias_shape = [1, 2, 128, 128]

def set_attrs(self):
self.is_causal = True
self.with_bias = False

def init_data(self):
q = np.random.random(self.q_shape)
k = np.random.random(self.k_shape)
v = np.random.random(self.v_shape)
q_ = paddle.to_tensor(q, stop_gradient=False)
k_ = paddle.to_tensor(k, stop_gradient=False)
v_ = paddle.to_tensor(v, stop_gradient=False)
# fixed the seed & offset to pass the check of seed_offset
fixed_seed_offset = paddle.to_tensor(
np.array([self.tmp_seed, 0]).astype(np.int32)
)
self.inputs = {
"q": convert_float_to_uint16(q)
if self.dtype == np.uint16
else q.astype(self.dtype),
"k": convert_float_to_uint16(k)
if self.dtype == np.uint16
else k.astype(self.dtype),
"v": convert_float_to_uint16(v)
if self.dtype == np.uint16
else v.astype(self.dtype),
"fixed_seed_offset": fixed_seed_offset,
}
bias_ = None
if self.with_bias:
bias = np.random.random(self.bias_shape).astype(np.float32)
self.inputs["attn_mask"] = bias
bias_ = paddle.to_tensor(bias, stop_gradient=True)

out, softmax_lse = attention_naive(
q_, k_, v_, bias=bias_, is_causal=self.is_causal
)
out.backward()
self.dq = q_.grad.numpy()
self.dk = k_.grad.numpy()
self.dv = v_.grad.numpy()
self.dout = paddle.ones_like(out, dtype=self.dtype)
self.attrs = {
'dropout': 0.0,
'causal': self.is_causal,
'return_softmax': False,
'rng_name': '',
}
softmax_lse = softmax_lse.numpy()
self.outputs = {
"out": convert_float_to_uint16(out)
if self.dtype == np.uint16
else out.astype(self.dtype),
"softmax": np.array([]), # not used
"softmax_lse": softmax_lse,
"seed_offset": fixed_seed_offset,
}

def test_check_output(self):
self.check_output_with_place(
paddle.XPUPlace(0), atol=self.tolerance, rtol=self.tolerance
)

def test_check_grad(self):
self.check_grad(
['q', 'k', 'v'],
'out',
user_defined_grads=[self.dq, self.dk, self.dv],
user_defined_grad_outputs=self.dout,
numeric_grad_delta=self.tolerance,
max_relative_error=self.tolerance,
)

# class TestFlashAttentionOp2_with_bias(TestFlashAttentionOp):
# def set_attrs(self):
# self.is_causal = True
# self.with_bias = True

# class TestFlashAttentionOp3_uncausal_with_bias(TestFlashAttentionOp): //WIP

# def set_attrs(self):
# self.is_causal = False
# self.with_bias = True


support_types = get_xpu_op_support_types("flash_attn")
for stype in support_types:
create_test_class(globals(), XPUTestFlashAttentionOp, stype)

if __name__ == '__main__':
paddle.disable_static()
unittest.main()
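
Note: the eager-mode sketch below reproduces the new test's comparison outside the OpTest harness: flash attention on XPU against the same naive softmax reference. It is illustrative only and not part of this commit; it assumes an XPU3 device with XHPC, that paddle.nn.functional.flash_attention.flash_attention is importable in this build, and that float16 is a registered dtype for this kernel.

# Hedged sketch (not part of this commit): compare the XPU flash attention
# output against a naive softmax reference, as the test above does.
import numpy as np
import paddle
import paddle.nn.functional as F
from paddle.nn.functional.flash_attention import flash_attention

paddle.set_device("xpu")
q = paddle.randn([1, 128, 2, 32]).astype("float16")  # [batch, seqlen, heads, head_dim]
k = paddle.randn([1, 128, 2, 32]).astype("float16")
v = paddle.randn([1, 128, 2, 32]).astype("float16")

# Flash attention path (dropout off, causal masking on).
out_fa, _ = flash_attention(q, k, v, dropout=0.0, causal=True)

# Naive reference in float32: softmax(scale * Q @ K^T + causal_mask) @ V.
qt = paddle.transpose(q.astype("float32"), [0, 2, 1, 3])
kt = paddle.transpose(k.astype("float32"), [0, 2, 1, 3])
vt = paddle.transpose(v.astype("float32"), [0, 2, 1, 3])
scores = paddle.matmul(qt, kt, transpose_y=True) * (1.0 / np.sqrt(q.shape[-1]))
scores = scores + paddle.triu(paddle.full_like(scores, -1e4), diagonal=1)
ref = paddle.transpose(paddle.matmul(F.softmax(scores), vt), [0, 2, 1, 3])

# The two results should agree to within float16 tolerance.
print(float(paddle.abs(out_fa.astype("float32") - ref).max()))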
