[WIP] FLAN-T5 integration #194

Status: Open. Wants to merge 72 commits into base: baseline_commit.

Commits (72):
dd82ba3  t5-small (Feb 18, 2024)
f2fd579  fix (js8544, Feb 29, 2024)
2fb6905  lint (js8544, Feb 29, 2024)
be58c3b  T5 enc/dec example file; linting/formatting (afeldman-nm, Mar 1, 2024)
70837fd  native/vllm t5 comparison test (afeldman-nm, Mar 1, 2024)
42a6e2b  merged upstream-main into enc_dec_t5 (afeldman-nm, Mar 1, 2024)
e3fd30d  Merge branch 'upstream-main' into enc_dec_t5 (afeldman-nm, Mar 2, 2024)
db726e6  Merge pull request #1 from afeldman-nm/enc_dec_t5 (js8544, Mar 2, 2024)
43e920e  remove debug print statements (afeldman-nm, Mar 2, 2024)
431f014  silence warning; legacy=False for tokenizer; lint/format (afeldman-nm, Mar 2, 2024)
37fcf99  Merge branch 'js8544_enc_dec_t5' into enc_dec_t5 (afeldman-nm, Mar 2, 2024)
4bf056b  Merge pull request #2 from afeldman-nm/enc_dec_t5 (js8544, Mar 2, 2024)
8a5060f  fix _make_tensor_with_pad args change which broke decoder scenario (afeldman-nm, Mar 5, 2024)
29d6f44  fixed bug caused by non-handling of self.model_config is None in mode… (afeldman-nm, Mar 5, 2024)
a4950ba  remove commented-out print statements (afeldman-nm, Mar 5, 2024)
9c03760  small cleanup (afeldman-nm, Mar 5, 2024)
9f20ccf  Merge pull request #3 from afeldman-nm/enc_dec_t5 (js8544, Mar 6, 2024)
6d6dccd  arg naming fix (afeldman-nm, Mar 7, 2024)
7035178  Merge branch 'js8544_enc_dec_t5' into enc_dec_t5 (afeldman-nm, Mar 12, 2024)
dbec357  fixed attention_kernels.cu merge conflict; questions about ROCM (afeldman-nm, Mar 12, 2024)
4b2a121  llm_engine.py conflict resolution; removed prefix caching code; Seque… (afeldman-nm, Mar 12, 2024)
a93c17d  actually updated Sequence constructor to take i_encoder_decoder, eos_… (afeldman-nm, Mar 12, 2024)
a62c3af  xformers.py accept incoming changes; replace paged_attention function… (afeldman-nm, Mar 12, 2024)
c31921f  saved changed to xformers woops (afeldman-nm, Mar 12, 2024)
0c78be9  attempt at fixing model_runner conflicts related to encoder/decoder &… (afeldman-nm, Mar 12, 2024)
e25e6b8  encoder/decoder + prefix caching not supported; moved check from llm.… (afeldman-nm, Mar 12, 2024)
7f70d76  refactoring, including: moved enc_dec_attention.py into vllm/model_ex… (afeldman-nm, Mar 12, 2024)
36c8291  existing regressions pass (yay) but encoder/decoder example fails (afeldman-nm, Mar 12, 2024)
08f268a  fixed encoder/decoder reshape and cache bug, but paged attention call… (afeldman-nm, Mar 12, 2024)
b9b0600  augmented paged attention with context_lens, max_context_len, block_t… (afeldman-nm, Mar 12, 2024)
63e9dca  linting/formatting fixes (afeldman-nm, Mar 12, 2024)
4d7e5a8  Merge branch 'upstream-main' into enc_dec_t5_merge_upstream2 (afeldman-nm, Mar 12, 2024)
bb7a219  Merge branch 'upstream-main' into enc_dec_t5_merge_upstream2 (afeldman-nm, Mar 16, 2024)
0b60121  fixed bug introduced during formatting (afeldman-nm, Mar 16, 2024)
d44257e  fixed example (afeldman-nm, Mar 16, 2024)
19c5c4b  Merge branch 'enc_dec_t5' into enc_dec_t5_merge_upstream2 (afeldman-nm, Mar 16, 2024)
c2f97b6  merged upstream (afeldman-nm, Mar 19, 2024)
0536ff5  rolled back some encoder/decoder changes (afeldman-nm, Mar 19, 2024)
7d4972c  merged in upstream-main (afeldman-nm, Mar 22, 2024)
23a5da5  added cross_block_tables to SequenceGroupMetadata (afeldman-nm, Mar 22, 2024)
e32fb9c  SequenceGroupMetadata: added cross_seq_data; optional along with cros… (afeldman-nm, Mar 22, 2024)
ae1c368  added block manager allocation of cross sequence block_tables (afeldman-nm, Mar 22, 2024)
691c2c1  scheduler schedule() support cross block-tables and cross sequences, … (afeldman-nm, Mar 22, 2024)
e240eb4  LLMEngine can build a sequencegroup with cross sequences (afeldman-nm, Mar 22, 2024)
cbfba8e  t5 Sampler does not pass vocab size to constructor; input_metadata.pr… (afeldman-nm, Mar 22, 2024)
501551c  add_request now correctly swaps decoder_prompt, prompt in encoder/de… (afeldman-nm, Mar 22, 2024)
08435e4  Added cross_input_metadata field to InputMetadata (afeldman-nm, Mar 22, 2024)
6e459a2  wip multi blocktable (afeldman-nm, Mar 25, 2024)
8e1ca33  wip (afeldman-nm, Mar 25, 2024)
e097732  plumbing dummy input metadata structures into model (afeldman-nm, Mar 25, 2024)
2a44585  plumbed encoder/decoder input metadata all the way into t5 (afeldman-nm, Mar 25, 2024)
91a4608  first pass at T5 encoder support (afeldman-nm, Mar 26, 2024)
d0c5e36  inefficient but effective & Attention-wrapper-compatible implementati… (afeldman-nm, Mar 27, 2024)
3737d5b  wip cross-attention (afeldman-nm, Mar 28, 2024)
38946ed  first pass at enc/dec support that runs e2e but doesn't produce corre… (afeldman-nm, Apr 1, 2024)
3c39f55  to pass regression tests: removed debug prints (afeldman-nm, Apr 1, 2024)
4ec2fde  wip vllm, examples => fp32 (afeldman-nm, Apr 1, 2024)
38f55ed  works on bsz = 1 (afeldman-nm, Apr 1, 2024)
1aedc80  intermediate activations for prompt_run look right! Decoded token loo… (afeldman-nm, Apr 1, 2024)
c1258b4  wip (afeldman-nm, Apr 2, 2024)
0af1022  passing with t5-small (afeldman-nm, Apr 3, 2024)
9e8d234  vLLM T5 matches native; fixes: decode-phase cross-input-met… (afeldman-nm, Apr 4, 2024)
f5242a0  refactoring out print statements (afeldman-nm, Apr 4, 2024)
de0fd31  fix to pass regression tests (afeldman-nm, Apr 4, 2024)
5a67647  WIP google/flan-t5-xxxx (afeldman-nm, Apr 4, 2024)
ed05d47  removed print statement (afeldman-nm, Apr 4, 2024)
d5a8b92  batched enc/dec example (afeldman-nm, Apr 10, 2024)
f555f5d  wip, trying prompt padding (afeldman-nm, Apr 12, 2024)
2c12b44  bs >1 prefill works (afeldman-nm, Apr 17, 2024)
dba02b2  small change to examples (afeldman-nm, Apr 17, 2024)
db201b6  fix to support case where num prompts != 2 (afeldman-nm, Apr 17, 2024)
ead7c82  set up (failing) flan-t5 test (afeldman-nm, Apr 17, 2024)
4 changes: 4 additions & 0 deletions .clang-format
@@ -0,0 +1,4 @@
# Use the Google style in this project.
BasedOnStyle: Google

ColumnLimit: 120
7 changes: 4 additions & 3 deletions benchmarks/backend_request_func.py
@@ -110,7 +110,8 @@ async def async_request_vllm(
output.ttft = ttft
output.latency = time.perf_counter() - st

# When streaming, '\0' is appended to the end of response.
# When streaming, '\0' is appended
# to the end of the response.
body = data.decode("utf-8").strip("\0")
output.generated_text = json.loads(
body)["text"][0][len(request_func_input.prompt):]
@@ -192,8 +193,8 @@ async def async_request_deepspeed_mii(
output = RequestFuncOutput()
output.prompt_len = request_func_input.prompt_len

# DeepSpeed-MII doesn't support streaming as of Jan 28 2024,
# will use 0 as placeholder.
# DeepSpeed-MII doesn't support streaming
# as of Jan 28 2024, will use 0 as placeholder.
# https://github.com/microsoft/DeepSpeed-MII/pull/311
output.ttft = 0

8 changes: 4 additions & 4 deletions benchmarks/benchmark_serving.py
@@ -293,9 +293,8 @@ def main(args: argparse.Namespace):

# Save to file
base_model_id = model_id.split("/")[-1]
file_name = (
f"{backend}-{args.request_rate}qps-{base_model_id}-{current_dt}.json"
)
file_name = f"{backend}-{args.request_rate}qps-" \
f"{base_model_id}-{current_dt}.json"
with open(file_name, "w") as outfile:
json.dump(result_json, outfile)

@@ -343,7 +342,8 @@ def main(args: argparse.Namespace):
"--tokenizer",
type=str,
help=
"Name or path of the tokenizer, if not using the default tokenizer.",
"Name or path of the tokenizer, if not " \
"using the default model tokenizer.",
)
parser.add_argument(
"--best-of",
16 changes: 12 additions & 4 deletions benchmarks/benchmark_throughput.py
@@ -6,8 +6,8 @@
from typing import List, Optional, Tuple

import torch
from transformers import (AutoModelForCausalLM, AutoTokenizer,
PreTrainedTokenizerBase)
from transformers import (AutoModelForCausalLM, T5ForConditionalGeneration,
AutoTokenizer, PreTrainedTokenizerBase)
from tqdm import tqdm


@@ -125,8 +125,16 @@ def run_hf(
trust_remote_code: bool,
) -> float:
assert not use_beam_search
llm = AutoModelForCausalLM.from_pretrained(
model, torch_dtype=torch.float16, trust_remote_code=trust_remote_code)
if "t5" in model:
llm = T5ForConditionalGeneration.from_pretrained(
model,
torch_dtype=torch.float16,
trust_remote_code=trust_remote_code)
else:
llm = AutoModelForCausalLM.from_pretrained(
model,
torch_dtype=torch.float16,
trust_remote_code=trust_remote_code)
if llm.config.model_type == "llama":
# To enable padding in the HF backend.
tokenizer.pad_token = tokenizer.eos_token
58 changes: 40 additions & 18 deletions csrc/attention/attention_kernels.cu
@@ -1,5 +1,6 @@
/*
* Adapted from https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Adapted from
* https://github.com/NVIDIA/FasterTransformer/blob/release/v5.3_tag/src/fastertransformer/kernels/decoder_masked_multihead_attention/decoder_masked_multihead_attention_template.hpp
* Copyright (c) 2023, The vLLM team.
* Copyright (c) 2020-2023, NVIDIA CORPORATION. All rights reserved.
*
@@ -16,9 +17,10 @@
* limitations under the License.
*/

#include <torch/extension.h>
#include <ATen/cuda/CUDAContext.h>
#include <c10/cuda/CUDAGuard.h>
#include <stdio.h>
#include <torch/extension.h>

#include "attention_dtypes.h"
#include "attention_utils.cuh"
@@ -30,12 +32,12 @@

#define MAX(a, b) ((a) > (b) ? (a) : (b))
#define MIN(a, b) ((a) < (b) ? (a) : (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b) - 1) / (b))
#define DIVIDE_ROUND_UP(a, b) (((a) + (b)-1) / (b))

namespace vllm {

// Utility function for attention softmax.
template<int NUM_WARPS>
template <int NUM_WARPS>
inline __device__ float block_sum(float* red_smem, float sum) {
// Decompose the thread index into warp / lane.
int warp = threadIdx.x / WARP_SIZE;
@@ -93,6 +95,7 @@ __device__ void paged_attention_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, max_seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
@@ -133,6 +136,10 @@
const int num_queries_per_kv = num_heads / num_kv_heads;
const int kv_head_idx = head_idx / num_queries_per_kv;
const float alibi_slope = alibi_slopes == nullptr ? 0.f : alibi_slopes[head_idx];
const float* custom_bias_vec = custom_bias == nullptr
? nullptr
: custom_bias + seq_idx * num_kv_heads * num_context_blocks * BLOCK_SIZE +
kv_head_idx * num_context_blocks * BLOCK_SIZE;

// A vector type to store a part of a key or a query.
// The vector size is configured in such a way that the threads in a thread group
@@ -224,8 +231,10 @@
// Compute dot product.
// This includes a reduction across the threads in the same thread group.
float qk = scale * Qk_dot<scalar_t, THREAD_GROUP_SIZE>::dot(q_vecs[thread_group_offset], k_vecs);
// Add the ALiBi bias if slopes are given.
qk += (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1) : 0;
// Add the custom or ALiBi bias if given.
qk += (custom_bias_vec != nullptr) ? custom_bias_vec[token_idx]
: (alibi_slope != 0) ? alibi_slope * (token_idx - context_len + 1)
: 0;

if (thread_group_offset == 0) {
// Store the partial reductions to shared memory.
@@ -435,13 +444,14 @@ __global__ void paged_attention_v1_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE>(
/* exp_sums */ nullptr, /* max_logits */ nullptr,
out, q, k_cache, v_cache, num_kv_heads, scale, block_tables, context_lens,
max_num_blocks_per_seq, alibi_slopes, q_stride, kv_block_stride, kv_head_stride);
max_num_blocks_per_seq, alibi_slopes, custom_bias, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs, max_num_partitions).
@@ -466,13 +476,14 @@ __global__ void paged_attention_v2_kernel(
const int* __restrict__ context_lens, // [num_seqs]
const int max_num_blocks_per_seq,
const float* __restrict__ alibi_slopes, // [num_heads]
const float* __restrict__ custom_bias, // [num_seqs, num_heads, 1, seq_len]
const int q_stride,
const int kv_block_stride,
const int kv_head_stride) {
paged_attention_kernel<scalar_t, cache_t, HEAD_SIZE, BLOCK_SIZE, NUM_THREADS, IS_FP8_E5M2_KV_CACHE, PARTITION_SIZE>(
exp_sums, max_logits, tmp_out, q, k_cache, v_cache, num_kv_heads, scale,
block_tables, context_lens, max_num_blocks_per_seq, alibi_slopes,
q_stride, kv_block_stride, kv_head_stride);
custom_bias, q_stride, kv_block_stride, kv_head_stride);
}

// Grid: (num_heads, num_seqs).
@@ -592,6 +603,7 @@ __global__ void paged_attention_v2_reduce_kernel(
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
custom_bias_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride);
@@ -613,7 +625,8 @@
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -626,9 +639,11 @@
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* alibi_slopes_ptr =
alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

// NOTE: custom_bias is optional.
const float* custom_bias_ptr = custom_bias ? reinterpret_cast<const float*>(custom_bias.value().data_ptr()) : nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
T* query_ptr = reinterpret_cast<T*>(query.data_ptr());
@@ -688,7 +703,8 @@ void paged_attention_v1_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
custom_bias);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -720,6 +736,7 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
@@ -762,6 +779,7 @@ void paged_attention_v1(
context_lens_ptr, \
max_num_blocks_per_seq, \
alibi_slopes_ptr, \
custom_bias_ptr, \
q_stride, \
kv_block_stride, \
kv_head_stride); \
@@ -794,7 +812,8 @@ void paged_attention_v2_launcher(
torch::Tensor& block_tables,
torch::Tensor& context_lens,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes) {
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias) {
int num_seqs = query.size(0);
int num_heads = query.size(1);
int head_size = query.size(2);
@@ -807,9 +826,10 @@
assert(head_size % thread_group_size == 0);

// NOTE: alibi_slopes is optional.
const float* alibi_slopes_ptr = alibi_slopes ?
reinterpret_cast<const float*>(alibi_slopes.value().data_ptr())
: nullptr;
const float* alibi_slopes_ptr =
alibi_slopes ? reinterpret_cast<const float*>(alibi_slopes.value().data_ptr()) : nullptr;

const float* custom_bias_ptr = custom_bias ? reinterpret_cast<const float*>(custom_bias.value().data_ptr()) : nullptr;

T* out_ptr = reinterpret_cast<T*>(out.data_ptr());
float* exp_sums_ptr = reinterpret_cast<float*>(exp_sums.data_ptr());
@@ -878,7 +898,8 @@ void paged_attention_v2_launcher(
block_tables, \
context_lens, \
max_context_len, \
alibi_slopes);
alibi_slopes, \
custom_bias);

// NOTE(woosuk): To reduce the compilation time, we omitted block sizes
// 1, 2, 4, 64, 128, 256.
@@ -913,6 +934,7 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype) {
if (kv_cache_dtype == "auto") {
if (query.dtype() == at::ScalarType::Float) {
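
Note on the kernel change above: paged_attention_v1/v2 gain an optional custom_bias tensor, documented in the signature comments as [num_seqs, num_heads, 1, max_seq_len] and added to the query-key logits in place of the ALiBi bias when present. Below is a minimal PyTorch sketch, in plain Python, of packing per-sequence, per-head bias rows (for example, the decode-step row of a T5 relative-position bias) into that layout; the helper name pack_decode_bias and the zero padding are illustrative assumptions, not code from this PR.

import torch

def pack_decode_bias(per_seq_bias, num_heads, max_seq_len):
    # per_seq_bias: one [num_heads, context_len] tensor per sequence, e.g. the
    # decode-step row of a T5 relative-position bias (assumption for illustration).
    # Packs them into the [num_seqs, num_heads, 1, max_seq_len] layout that the
    # kernel comment documents; unused tail positions are left at zero.
    num_seqs = len(per_seq_bias)
    packed = torch.zeros(num_seqs, num_heads, 1, max_seq_len, dtype=torch.float32)
    for i, bias in enumerate(per_seq_bias):
        ctx_len = bias.shape[-1]
        packed[i, :, 0, :ctx_len] = bias
    return packed

# Example: two sequences with context lengths 5 and 9, 8 heads, padded to 16 positions.
packed = pack_decode_bias([torch.randn(8, 5), torch.randn(8, 9)],
                          num_heads=8, max_seq_len=16)
print(packed.shape)  # torch.Size([2, 8, 1, 16])
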
2 changes: 2 additions & 0 deletions csrc/ops.h
@@ -14,6 +14,7 @@ void paged_attention_v1(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype);

void paged_attention_v2(
@@ -31,6 +32,7 @@ void paged_attention_v2(
int block_size,
int max_context_len,
const c10::optional<torch::Tensor>& alibi_slopes,
const c10::optional<torch::Tensor>& custom_bias,
const std::string& kv_cache_dtype);

void rms_norm(
94 changes: 94 additions & 0 deletions examples/offline_inference_enc_dec.py
@@ -0,0 +1,94 @@
'''
Affirm T5 model outputs match between vLLM and native PyTorch

Scenarios:
* t5-small, t5-large
* float16, float32, bfloat16, bfloat32
* Custom prompts & num. prompts

Output: for several prompts, compare native PyTorch & vLLM prompt completions
'''
import warnings
import torch
from vllm import LLM, SamplingParams
from transformers import T5Tokenizer, T5ForConditionalGeneration

warnings.filterwarnings("ignore",
category=UserWarning,
module="transformers.generation.utils.*")

hf_model_id = "google/flan-t5-small" # t5-small
dtype = "float32"
prompts = [
#"Who are you?",
#"Who are you?",
#"How do",
#"Who aren't you?",
#"Who aren't you?<pad><pad><pad><pad>", #
"Who are you? Write a very long response.",
]

dtype_obj = getattr(torch, dtype)

# Native PyTorch test

# - Model and tokenizer initialization
tokenizer = T5Tokenizer.from_pretrained(hf_model_id, legacy=False)
model:T5ForConditionalGeneration = T5ForConditionalGeneration.from_pretrained(hf_model_id).to(
dtype=dtype_obj)

# - Assume 'dtype' is already defined, e.g., dtype=torch.float32
# - Tokenizing the prompts list with specified data type
input_ids = tokenizer(prompts,
return_tensors="pt",
padding=True,
truncation=True).input_ids

# - If using GPU, also send input_ids to the same device as the model
if torch.cuda.is_available():
model = model.cuda() # Move model to GPU
input_ids = input_ids.cuda() # Move input_ids to GPU

# - Max token count for both native and vLLM test
max_tokens = 512

# - Generating outputs for all tokenized prompts
native_outputs = model.generate(input_ids,max_length = max_tokens).cpu()

# vLLM test
model: LLM = LLM(hf_model_id,
enforce_eager=True,
dtype=dtype,
gpu_memory_utilization=0.5)

sampling_params = SamplingParams(max_tokens=max_tokens, temperature=0)

vllm_outputs = model.generate(
prompts,
sampling_params=sampling_params
)

print(native_outputs)
print(vllm_outputs)

# Print native & vLLM outputs
i = 0
for native_output, vllm_output in zip(native_outputs, vllm_outputs):
print(f"Prompt {i}:")

prompt = prompts[i] # Get the corresponding prompt for this output
native_generated_text = tokenizer.decode(
native_output, skip_special_tokens=True) # Decode the generated text
vllm_generated_text = vllm_output.outputs[0].text
print(
f"- Prompt: {prompt!r}, Native PyTorch generated text: " \
f"{native_generated_text!r}, " \
f"vLLM generated text: {vllm_generated_text!r}"
)

print("- Asserting textual match")
#assert native_generated_text == vllm_generated_text
print("- Asserting token match")
#assert native_output[1:-1].tolist() == vllm_output.outputs[0].token_ids[:-1]

i += 1
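
The assertions in the example above are still commented out while the integration is WIP. A self-contained sketch of the token-level check they hint at follows: HF T5 generation output begins with a decoder-start/pad token and ends with EOS, so the comparison trims those positions. The helper name and the exact trimming convention are assumptions carried over from the commented-out line, not guarantees of this PR.

from typing import List

def token_ids_match(native_ids: List[int], vllm_ids: List[int]) -> bool:
    # Drop the leading decoder-start token and trailing EOS from the HF output,
    # and the trailing EOS from the vLLM output, then compare (assumed convention).
    return native_ids[1:-1] == vllm_ids[:-1]

print(token_ids_match([0, 32, 45, 1], [32, 45, 1]))  # True
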