[WIP] afeldman-nm/encoder decoder #22

Closed
82 changes: 82 additions & 0 deletions examples/offline_inference_english_sr.py
@@ -0,0 +1,82 @@
# Based on this tutorial
# https://huggingface.co/docs/transformers/model_doc/whisper

# TODO: add the HuggingFace 'transformers', 'datasets', 'soundfile', and 'librosa' packages
# to requirements.txt (or work around the dependency) in order to obtain the audio dataset
from datasets import load_dataset
from transformers import WhisperProcessor, WhisperForConditionalGeneration
from vllm import LLM, SamplingParams

# TODO: a vLLM audio frontend should perform audio tokenization.
# This code is a stand-in for such a frontend.
ds = load_dataset("hf-internal-testing/librispeech_asr_dummy", "clean", split="validation")
processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
audio_sample = ds[0]["audio"]
waveform = audio_sample["array"] # (93680,) floating-point values
processor_params = {
    "sampling_rate": audio_sample["sampling_rate"]  # 16 kHz
}

# Log-mel input features with shape (1, 80, 3000)
input_features = processor(
    waveform, sampling_rate=processor_params["sampling_rate"], return_tensors="pt"
).input_features

# Sampling params for the transcription decoder's text tokens (greedy decoding).
sampling_params = SamplingParams(temperature=0.0)

# Create an LLM.
#
# Equivalent to:
#
#     processor = WhisperProcessor.from_pretrained("openai/whisper-tiny")
#     model = WhisperForConditionalGeneration.from_pretrained("openai/whisper-tiny")
#
# The HuggingFace model identifier is used to pull in (1) the pretrained model weights
# and config and (2) the WhisperProcessor weights associated with the model.
llm = LLM(model="openai/whisper-tiny")

# Encode audio.
#
# This is a change both from how decoder-only LLMs work *and* from how the HuggingFace
# transformers Whisper workflow operates:
# - For encoder/decoder (E/D) models, LLM.generate() is equivalent to
#
#       input_features = processor(
#           waveform, sampling_rate=processor_params["sampling_rate"], return_tensors="pt"
#       ).input_features
#
#       predicted_ids = model.generate(input_features)
#
#   i.e. LLM.generate() also performs the encoding step.
#
# The vLLM convention appears to be to abstract tokenization/preprocessing behind .generate().
predicted_ids = llm.generate(prompt_token_ids=input_features)

# Decode the predicted token ids into a transcription.
#
# .batch_decode() is not yet a method of LLM. Its proposed signature matches
# LLM.generate(tokens, sampling_params), with added kwargs to support typical
# use-cases, e.g. skip_special_tokens.
#
# Equivalent to:
#
#     transcription = processor.batch_decode(predicted_ids, skip_special_tokens=True)
#
# The HuggingFace output is a list of transcription strings.
#
# vLLM's typical decoder-only behavior is that each element of the output list is a
# data structure with output.prompt and output.outputs[...].text, so we would most
# likely wrap the decoder outputs to respect this data structure.
#
# NOTE: sampling_params is passed below to mirror the proposed LLM.batch_decode()
# signature; the HuggingFace processor's batch_decode() does not itself accept it.
transcriptions = processor.batch_decode(predicted_ids, sampling_params, skip_special_tokens=True)

# Print the outputs.
for transcription in transcriptions:
    # transcription_predicted_ids = transcription.predicted_ids
    transcription_text = transcription.outputs[0].text
    print(f"Transcription: {transcription_text!r}")
87 changes: 87 additions & 0 deletions vllm/model_executor/layers/attention.py
@@ -19,6 +19,93 @@
# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

class EncoderAttention(nn.Module):
    """Layer with encoder-style attention.

    Attends over the full encoder input sequence via xformers
    memory-efficient attention; key/value caching is optional.
    """

    def __init__(
        self,
        num_heads: int,
        head_size: int,
        scale: float,
        num_kv_heads: Optional[int] = None,
    ) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.head_size = head_size
        self.scale = float(scale)
        self.num_kv_heads = num_heads if num_kv_heads is None else num_kv_heads

        assert self.num_heads % self.num_kv_heads == 0
        self.num_queries_per_kv = self.num_heads // self.num_kv_heads

        if self.head_size not in _SUPPORTED_HEAD_SIZES:
            raise ValueError(f"head_size ({self.head_size}) is not supported. "
                             f"Supported head sizes: {_SUPPORTED_HEAD_SIZES}.")

    def forward(
        self,
        query: torch.Tensor,
        key: torch.Tensor,
        value: torch.Tensor,
        key_cache: Optional[torch.Tensor],
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        """EncoderAttention forward pass.

        Args:
            query: shape [batch_size, seq_len, num_heads * head_size]
            key: shape [batch_size, seq_len, num_kv_heads * head_size]
            value: shape [batch_size, seq_len, num_kv_heads * head_size]
            key_cache: optional key cache; if None, new keys are not cached
            value_cache: optional value cache; if None, new values are not cached
            input_metadata: metadata for the inputs (slot mapping, attention bias, etc.)

        Returns:
            Attention output of shape [batch_size, seq_len, hidden_size].
        """
        batch_size, seq_len, hidden_size = query.shape
        # Reshape the query, key, and value tensors.
        query = query.view(-1, self.num_heads, self.head_size)
        key = key.view(-1, self.num_kv_heads, self.head_size)
        value = value.view(-1, self.num_kv_heads, self.head_size)

        # Reshape the keys and values and store them in the cache.
        # If key_cache and value_cache are not provided, the new key and value
        # vectors will not be cached. This happens during the initial memory
        # profiling run.
        if key_cache is not None and value_cache is not None:
            cache_ops.reshape_and_cache(
                key,
                value,
                key_cache,
                value_cache,
                input_metadata.slot_mapping.flatten(),
                input_metadata.kv_cache_dtype,
            )

        # As of Nov 2023, xformers only supports MHA. For MQA/GQA,
        # project the key and value tensors to the desired number of
        # heads.
        # TODO(woosuk): Use MQA/GQA kernels for higher performance.
        query = query.view(query.shape[0], self.num_kv_heads,
                           self.num_queries_per_kv, query.shape[-1])
        key = key[:, :, None, :].expand(key.shape[0], self.num_kv_heads,
                                        self.num_queries_per_kv,
                                        key.shape[-1])
        value = value[:, :, None, :].expand(value.shape[0],
                                            self.num_kv_heads,
                                            self.num_queries_per_kv,
                                            value.shape[-1])

        # TODO (afeldman-nm): assume for now that prefix-enabled attention does not
        # need to be supported by the encoder. Confirm later.
        # TODO (afeldman-nm): assume for now that masked attention, ALiBi, etc. are
        # not required for the encoder. Correct later.

        # xformers expects an explicit leading batch dimension, so unsqueeze the
        # (num_tokens, ...) tensors before the kernel call.
        out = xops.memory_efficient_attention_forward(
            query.unsqueeze(0),
            key.unsqueeze(0),
            value.unsqueeze(0),
            attn_bias=input_metadata.attn_bias,
            p=0.0,
            scale=self.scale,
            op=xops.fmha.MemoryEfficientAttentionFlashAttentionOp[0] if
            (is_hip()) else None,
        )
        output = out.view_as(query)

        # Reshape the output back to [batch_size, seq_len, hidden_size].
        return output.view(batch_size, seq_len, hidden_size)


class PagedAttention(nn.Module):
"""MHA/MQA/GQA layer with PagedAttention.
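As a sanity check on the intended call pattern, here is a minimal, hypothetical sketch (not part of this diff) of how a Whisper-style encoder block might wire up EncoderAttention; the block name, projection-layer layout, and import paths are assumptions for illustration only:

# Hypothetical usage sketch for EncoderAttention (illustration only).
from typing import Optional

import torch
import torch.nn as nn

from vllm.model_executor.input_metadata import InputMetadata
from vllm.model_executor.layers.attention import EncoderAttention


class EncoderSelfAttentionBlock(nn.Module):
    """Hypothetical encoder self-attention block built around EncoderAttention."""

    def __init__(self, hidden_size: int, num_heads: int) -> None:
        super().__init__()
        head_size = hidden_size // num_heads
        # Fused QKV projection and output projection (assumed layout).
        self.qkv_proj = nn.Linear(hidden_size, 3 * hidden_size)
        self.out_proj = nn.Linear(hidden_size, hidden_size)
        self.attn = EncoderAttention(num_heads, head_size, scale=head_size**-0.5)

    def forward(
        self,
        hidden_states: torch.Tensor,        # [batch_size, seq_len, hidden_size]
        key_cache: Optional[torch.Tensor],  # None during memory profiling
        value_cache: Optional[torch.Tensor],
        input_metadata: InputMetadata,
    ) -> torch.Tensor:
        # Project to Q/K/V, run encoder attention, then project the output back.
        q, k, v = self.qkv_proj(hidden_states).chunk(3, dim=-1)
        attn_out = self.attn(q, k, v, key_cache, value_cache, input_metadata)
        return self.out_proj(attn_out)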
1 change: 1 addition & 0 deletions vllm/model_executor/models/__init__.py
@@ -42,6 +42,7 @@
"Qwen2ForCausalLM": ("qwen2", "Qwen2ForCausalLM"),
"RWForCausalLM": ("falcon", "FalconForCausalLM"),
"StableLMEpochForCausalLM": ("stablelm", "StablelmForCausalLM"),
"WhisperForConditionalGeneration": ("whisper","WhisperForConditionalGeneration")
}

# Models not supported by ROCm.
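For context, a rough sketch of how this registry entry would be exercised; the dict name _MODELS and the lookup flow are assumptions based on vLLM's model registry, not part of this diff. The architecture string from the HuggingFace config is used as the key to find the module and class to load:

# Hypothetical illustration only: resolving the Whisper architecture via the registry.
from transformers import AutoConfig

_MODELS = {
    # ... existing entries ...
    "WhisperForConditionalGeneration": ("whisper", "WhisperForConditionalGeneration"),
}

config = AutoConfig.from_pretrained("openai/whisper-tiny")
arch = config.architectures[0]         # "WhisperForConditionalGeneration"
module_name, cls_name = _MODELS[arch]  # -> ("whisper", "WhisperForConditionalGeneration")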