Torch compile fusion backend prototype #209

Draft: wants to merge 31 commits into base: upstream-main

Commits (31)
b449e5d
Custom torch.compile backend prototype
bnellnm Apr 25, 2024
e9a62ff
add lowering_utils.py
bnellnm Apr 25, 2024
44153ea
torch.compile fusion backend prototype
bnellnm Apr 25, 2024
55d20b3
fix a mess of fusion pass bugs
bnellnm Apr 25, 2024
a2ba837
add fusion failure fallback exception
bnellnm Apr 25, 2024
2b0fc7d
add workaround for symbolic shape issue, fix other stuff
bnellnm Apr 29, 2024
a64f893
meta + signature generation
bnellnm Apr 30, 2024
cb686a1
refactor, use temporary files
bnellnm Apr 30, 2024
86381d8
comment
bnellnm Apr 30, 2024
692fd79
Merge branch 'upstream-main' into torch-compile-fusion-new
bnellnm May 1, 2024
e8a9b6b
wip
bnellnm May 1, 2024
1ce096f
wip registry
bnellnm May 1, 2024
bdd91ed
wip registry
bnellnm May 1, 2024
5f4bb6e
merge
bnellnm May 1, 2024
019910a
replace prints with logging
bnellnm May 2, 2024
5389a6e
refactoring + hacked up support for getitem
bnellnm May 3, 2024
835756b
cleanups + comments
bnellnm May 7, 2024
260cbf8
debugging print
bnellnm May 8, 2024
b8bc74b
move code cache to class scope
bnellnm May 9, 2024
9db9f46
wip
bnellnm May 10, 2024
193b7a6
remove use of split_module
bnellnm May 10, 2024
1188acb
handle dynamic dim wip
bnellnm May 11, 2024
e2f45bd
delete tensors that are no longer needed in c++
bnellnm May 12, 2024
abaab9b
turn down logging
bnellnm May 13, 2024
5bdc042
put symint rejection hack back in
bnellnm May 13, 2024
14791d6
fix 'memory leak'
bnellnm May 13, 2024
67e97ed
comments
bnellnm May 14, 2024
173e654
smarter slice translation
bnellnm May 14, 2024
ba1b6b1
tweaks
bnellnm May 17, 2024
b3b4d8a
add support for methods, add pattern matching
bnellnm May 20, 2024
a865b32
forgot rewrite file
bnellnm May 20, 2024
163 changes: 161 additions & 2 deletions vllm/attention/ops/paged_attn.py
@@ -6,6 +6,8 @@
from vllm import _custom_ops as ops
from vllm.attention.ops.prefix_prefill import context_attention_fwd

from vllm.lowering_utils import vllm_lib, register_vllm_lowering

# Should be the same as PARTITION_SIZE in `paged_attention_v2_launcher`.
_PARTITION_SIZE = 512

@@ -110,7 +112,7 @@ def forward_decode(
and (max_num_partitions == 1 or num_seqs * num_heads > 512))
if use_v1:
# Run PagedAttention V1.
- ops.paged_attention_v1(
+ torch.ops.vllm.paged_attention_v1(
output,
query,
key_cache,
@@ -139,7 +141,7 @@ def forward_decode(
device=output.device,
)
max_logits = torch.empty_like(exp_sums)
- ops.paged_attention_v2(
+ torch.ops.vllm.paged_attention_v2(
output,
exp_sums,
max_logits,
@@ -213,3 +215,160 @@ def copy_blocks(
key_caches = [kv_cache[0] for kv_cache in kv_caches]
value_caches = [kv_cache[1] for kv_cache in kv_caches]
ops.copy_blocks(key_caches, value_caches, src_to_dists)

# Custom op schema/impl registrations needed so torch.compile can trace these kernels.
vllm_lib.define(
"reshape_and_cache(Tensor key, Tensor value, Tensor key_cache, Tensor value_cache, Tensor slot_mapping, str dtype) -> (Tensor, Tensor)"
)


@torch.library.impl(vllm_lib, "reshape_and_cache", "Meta")
def _reshape_and_cache_meta(key, value, key_cache, value_cache, slot_mapping,
dtype):
return key_cache, value_cache


@torch.library.impl(vllm_lib, "reshape_and_cache", "CUDA")
def _reshape_and_cache(key, value, key_cache, value_cache, slot_mapping,
dtype):
cache_ops.reshape_and_cache(key, value, key_cache, value_cache,
slot_mapping, dtype)
return key_cache, value_cache


register_vllm_lowering(torch.ops.vllm.reshape_and_cache, [2, 3])

vllm_lib.define(
"paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, Tensor value_cache, int num_kv_heads, float scale, Tensor block_tables, Tensor context_lens, int block_size, SymInt max_context_len, Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> Tensor"
)
#vllm_lib.define(
# "paged_attention_v1(Tensor out, Tensor query, Tensor key_cache, Tensor value_cache, int num_kv_heads, float scale, Tensor block_tables, Tensor context_lens, int block_size, int max_context_len, Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> Tensor"
#)


@torch.library.impl(vllm_lib, "paged_attention_v1", "Meta")
def _paged_attention_v1_meta(
out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
):
return out


@torch.library.impl(vllm_lib, "paged_attention_v1", "CUDA")
def _paged_attention_v1(
out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
):
ops.paged_attention_v1(
out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
return out


register_vllm_lowering(torch.ops.vllm.paged_attention_v1, [0])

vllm_lib.define(
"paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits, Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache, int num_kv_heads, float scale, Tensor block_tables, Tensor context_lens, int block_size, SymInt max_context_len, Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> Tensor"
)
#vllm_lib.define(
# "paged_attention_v2(Tensor out, Tensor exp_sums, Tensor max_logits, Tensor tmp_out, Tensor query, Tensor key_cache, Tensor value_cache, int num_kv_heads, float scale, Tensor block_tables, Tensor context_lens, int block_size, int max_context_len, Tensor? alibi_slopes, str kv_cache_dtype, float kv_scale) -> Tensor"
#)


@torch.library.impl(vllm_lib, "paged_attention_v2", "Meta")
def _paged_attention_v2_meta(
out,
exp_sums,
max_logits,
tmp_out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
):
return out


@torch.library.impl(vllm_lib, "paged_attention_v2", "CUDA")
def _paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
):
ops.paged_attention_v2(
out,
exp_sums,
max_logits,
tmp_out,
query,
key_cache,
value_cache,
num_kv_heads,
scale,
block_tables,
context_lens,
block_size,
max_context_len,
alibi_slopes,
kv_cache_dtype,
kv_scale,
)
return out


register_vllm_lowering(torch.ops.vllm.paged_attention_v2, [0])
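
The additions above all follow the same torch.library pattern: declare an op schema on the `vllm` library, give it a "Meta" implementation that only reports output shapes and dtypes (so the graph can be traced without running kernels), give it a "CUDA" implementation that forwards to the existing custom op, and finally call `register_vllm_lowering` (defined in `lowering_utils.py`, which is not shown in this diff, so the exact meaning of the index list passed to it isn't covered here). A minimal standalone sketch of that pattern, using a made-up `toy` namespace and op rather than anything from this PR:

```python
import torch

# Illustrative only: a toy namespace and op, not part of this PR.
toy_lib = torch.library.Library("toy", "DEF")

# Declare the schema so the op can appear as a single node in an FX graph.
toy_lib.define("scaled_add(Tensor x, Tensor y, float alpha) -> Tensor")


@torch.library.impl(toy_lib, "scaled_add", "Meta")
def _scaled_add_meta(x, y, alpha):
    # Meta impl: only shape/dtype propagation, no real computation.
    return torch.empty_like(x)


@torch.library.impl(toy_lib, "scaled_add", "CPU")
def _scaled_add_cpu(x, y, alpha):
    # Real implementation; the PR's ops forward to the CUDA kernels instead.
    return x + alpha * y


x, y = torch.randn(4), torch.randn(4)
out = torch.ops.toy.scaled_add(x, y, 0.5)
```

With the schema and Meta impl in place, `torch.ops.toy.scaled_add` shows up as a single opaque node in the graph that torch.compile captures, which is what makes pattern matching and fusion over ops like `torch.ops.vllm.paged_attention_v1` possible.
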
5 changes: 2 additions & 3 deletions vllm/distributed/parallel_state.py
@@ -211,8 +211,7 @@ def get_pipeline_model_parallel_group():

def get_tensor_model_parallel_world_size():
"""Return world size for the tensor model parallel group."""
- return torch.distributed.get_world_size(
-     group=get_tensor_model_parallel_group())
+ return get_tensor_model_parallel_group().size()


def get_pipeline_model_parallel_world_size():
@@ -223,7 +222,7 @@ def get_pipeline_model_parallel_world_size():

def get_tensor_model_parallel_rank():
"""Return my rank for the tensor model parallel group."""
- return torch.distributed.get_rank(group=get_tensor_model_parallel_group())
+ return get_tensor_model_parallel_group().rank()


def get_pipeline_model_parallel_rank():
10 changes: 10 additions & 0 deletions vllm/ex/README.md
@@ -0,0 +1,10 @@
- ex.py - the backend
- ex_builder.py - compiles/loads C++/CUDA torch modules
- testex*.py - various tests

TODO
----
0. fix stuff
- https://github.com/pytorch/pytorch/issues/108446
1. registration mechanism
2. backend code generator
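
For orientation only (ex.py itself is not shown in this diff): a torch.compile backend is just a callable that receives the captured `torch.fx.GraphModule` plus example inputs and returns something callable. The sketch below shows that general shape; the name `fusion_backend` and its body are illustrative assumptions, not the PR's actual implementation.

```python
import torch


def fusion_backend(gm: torch.fx.GraphModule, example_inputs):
    # A real fusion backend would walk/rewrite gm.graph here, e.g. fuse
    # sequences of torch.ops.vllm.* nodes into generated C++/CUDA kernels.
    gm.graph.print_tabular()  # inspect what torch.compile captured
    return gm.forward         # fall back to running the graph as-is


@torch.compile(backend=fusion_backend)
def f(x, y):
    return torch.relu(x) + y


f(torch.randn(8), torch.randn(8))
```
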
51 changes: 51 additions & 0 deletions vllm/ex/code_cache.py
@@ -0,0 +1,51 @@
from typing import Callable, Optional
from vllm.logger import init_logger

logger = init_logger(__name__)

class CodeCache:
    """
    The CodeCache is a simple map from mangled function names to Callables.

    The CodeCache can be used to store the results of compiled code so that
    the same Callable can be reused rather than needing to be recompiled.

    Mangled function names should be generated with (or be compatible with)
    the 'utils.mangle_name' function.

    Note: the CodeCache can be initialized with pre-compiled functions.
    """

    def __init__(self):
        self.cache = dict()

    def lookup_or_create(self, mangled_name: str,
                         generator: Callable) -> Optional[Callable]:
        """
        Look up the Callable for a function based on the 'mangled_name'. If
        the name is not present in the cache, call the supplied 'generator'
        to create the Callable to be associated with the 'mangled_name'. If
        the generator fails for any reason, None is stored in the map (and
        the exception re-raised) so the failed generator is never retried;
        later lookups return None instead of a Callable.
        """
        if mangled_name not in self.cache:
            try:
                self.cache[mangled_name] = generator()
            except Exception:
                self.cache[mangled_name] = None
                raise
        return self.cache[mangled_name]

    def add(self, mangled_name: str, fn: Optional[Callable]) -> bool:
        """
        Add a new entry to the cache. Return False if an entry with the
        given name already exists.
        """
        if mangled_name in self.cache:
            return False
        self.cache[mangled_name] = fn
        return True
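
A small usage sketch of the cache above; the mangled name and the generator are made up for illustration:

```python
cache = CodeCache()


def build_fused_kernel():
    # Stand-in for compiling and loading a C++/CUDA module for a fused
    # subgraph; returns the callable that runs it.
    return lambda x: x * 2


fn = cache.lookup_or_create("fused_silu_mul_f16", build_fused_kernel)
assert fn is not None and fn(21) == 42

# A second lookup with the same mangled name reuses the cached Callable
# instead of invoking the generator again.
assert cache.lookup_or_create("fused_silu_mul_f16", build_fused_kernel) is fn
```
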