
[Core][Distributed] add fast broadcast for tensor dict #4757

Open · wants to merge 31 commits into base: main

Conversation

youkaichao
Member

Part of the ongoing effort of #4440.

Reduces the number of broadcasts from 2 to 1.

Broadcast time (before): 0.38772106170654297ms
Broadcast time (after): 0.128173828125ms
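
For reference, numbers like these can be collected with a small helper along the following lines (a sketch, not the exact benchmark used for this PR; it assumes an initialized tensor-parallel group and that broadcast_tensor_dict is importable from vllm.distributed.communication_op, as on main at the time):

import time
import torch
from vllm.distributed.communication_op import broadcast_tensor_dict

def time_broadcast(tensor_dict, iters=100, src=0):
    broadcast_tensor_dict(tensor_dict, src=src)  # warm-up call
    torch.cuda.synchronize()
    start = time.perf_counter()
    for _ in range(iters):
        broadcast_tensor_dict(tensor_dict, src=src)
    torch.cuda.synchronize()
    return (time.perf_counter() - start) / iters * 1000  # ms per broadcast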

TODO:

  • improve the broadcast for prepare input in the same way.

@youkaichao
Member Author

The current TensorMetadata class is not good for serialization:

from vllm.distributed.communication_op import TensorMetadata
import torch
d = TensorMetadata("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s) # 120
import pickletools
pickletools.dis(s)

output:

    0: \x80 PROTO      4
    2: \x95 FRAME      109
   11: \x8c SHORT_BINUNICODE 'vllm.distributed.communication_op'
   46: \x94 MEMOIZE    (as 0)
   47: \x8c SHORT_BINUNICODE 'TensorMetadata'
   63: \x94 MEMOIZE    (as 1)
   64: \x93 STACK_GLOBAL
   65: \x94 MEMOIZE    (as 2)
   66: \x8c SHORT_BINUNICODE 'cuda'
   72: \x94 MEMOIZE    (as 3)
   73: \x8c SHORT_BINUNICODE 'torch'
   80: \x94 MEMOIZE    (as 4)
   81: \x8c SHORT_BINUNICODE 'float32'
   90: \x94 MEMOIZE    (as 5)
   91: \x93 STACK_GLOBAL
   92: \x94 MEMOIZE    (as 6)
   93: \x8c SHORT_BINUNICODE 'torch'
  100: \x94 MEMOIZE    (as 7)
  101: \x8c SHORT_BINUNICODE 'Size'
  107: \x94 MEMOIZE    (as 8)
  108: \x93 STACK_GLOBAL
  109: \x94 MEMOIZE    (as 9)
  110: )    EMPTY_TUPLE
  111: \x85 TUPLE1
  112: \x94 MEMOIZE    (as 10)
  113: R    REDUCE
  114: \x94 MEMOIZE    (as 11)
  115: \x87 TUPLE3
  116: \x94 MEMOIZE    (as 12)
  117: \x81 NEWOBJ
  118: \x94 MEMOIZE    (as 13)
  119: .    STOP
highest protocol among opcodes = 4

Each TensorMetadata instance takes 120 bytes when pickled.

@youkaichao
Member Author

After a8d1d3a, the serialization size is reduced by more than half (from 120 bytes to 52 bytes):

from vllm import TensorMeta
import torch
d = TensorMeta("cuda", torch.float32, torch.Size([]))
import pickle
s = pickle.dumps(d)
len(s) # 52
import pickletools
pickletools.dis(s)

output:

    0: \x80 PROTO      4
    2: \x95 FRAME      41
   11: \x8c SHORT_BINUNICODE 'vllm'
   17: \x94 MEMOIZE    (as 0)
   18: \x8c SHORT_BINUNICODE 'TensorMeta'
   30: \x94 MEMOIZE    (as 1)
   31: \x93 STACK_GLOBAL
   32: \x94 MEMOIZE    (as 2)
   33: )    EMPTY_TUPLE
   34: \x81 NEWOBJ
   35: \x94 MEMOIZE    (as 3)
   36: ]    EMPTY_LIST
   37: \x94 MEMOIZE    (as 4)
   38: (    MARK
   39: \x8c     SHORT_BINUNICODE 'cuda'
   45: \x94     MEMOIZE    (as 5)
   46: K        BININT1    17
   48: )        EMPTY_TUPLE
   49: e        APPENDS    (MARK at 38)
   50: b    BUILD
   51: .    STOP

@youkaichao
Member Author

With all of the above optimizations, the bytes needed to broadcast BlockMetaData are reduced from 260 to 107.

This benefit will become more significant when we apply the technique to the prepare-input data structures.

@rkooo567 rkooo567 self-assigned this May 12, 2024
Collaborator

@rkooo567 rkooo567 left a comment


"improve the broadcast for prepare input in the same way." -> planning to do this in this PR? Also can you tell me the perf improvement from it?

Also, can you update this test to use tp > 1? I think it can verify the correctness of this change.

metadata_list = []
tensor_list = []
for key, value in tensor_dict.items():
used_keys = keys or tensor_dict.keys()
Collaborator

Should we assert len(keys) == len(tensor_dict)?

Member Author

Not necessary, though. The current code is more flexible without the assert.

@dataclasses.dataclass
class TensorMeta:
    """
    This class is placed here to reduce the size of qualified name,
Collaborator

What about just creating vllm/tensor_meta.py? Is that still too long?

Member Author

Yes, vllm/tensor_meta.py would lead to vllm.tensor_meta.TensorMeta, which is longer than vllm.TensorMeta.
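
For illustration, a small standalone snippet (plain Python, not vLLM code) showing that pickle embeds the defining module path and class name verbatim, which is why the shorter qualified name shrinks every serialized instance:

import pickle
import pickletools

class Demo:
    pass

pickletools.dis(pickle.dumps(Demo()))
# The opcode dump contains SHORT_BINUNICODE '__main__' and 'Demo'.
# With the class exposed as vllm.TensorMeta, those strings are just
# 'vllm' and 'TensorMeta', instead of the much longer
# 'vllm.distributed.communication_op' and 'TensorMetadata'.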

Collaborator

Does tensor_meta make a big difference? If it is only a small difference (double-digit microseconds), I would prefer to avoid it...

tensor_list.append(value)
else:
metadata_list.append((key, value))
if keys is not None:
Collaborator

nit: why don't we just check it at line 213?

if keys is not None:
    metadata_list.append((key, value))
else:
    metadata_list.append(value)

Member Author

I think control flow inside the loop is more expensive (N branch checks) than control flow outside of the loop (1 branch check).

Collaborator

In [12]: def control():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         if b:
    ...:             result.append((i, i))
    ...:         else:
    ...:             result.append(i)

In [16]: def copy():
    ...:     b = True
    ...:     result = []
    ...:     a = [i for i in range(3000)]
    ...:     for i in a:
    ...:         result.append((i, i))
    ...:     result = [value for key, value in result]

In [22]: timeit copy()
192 µs ± 686 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

In [23]: timeit control()
159 µs ± 487 ns per loop (mean ± std. dev. of 7 runs, 10,000 loops each)

Hmm, I actually tried it, and it looks like keeping the control flow in the loop is faster. But I think the perf difference here is not very meaningful (it is premature optimization). I was asking because I thought the alternative is easier to understand, but it's not a strong opinion. I will leave it up to you.


This class represents a dictionary of tensors with bounded metadata.
The upperbound of the buffer size is known a priori. Therefore, we can
pre-allocate a buffer for the metadata, and invoke only one collective
Collaborator

Is this correct because we are now broadcasting using a CPU tensor, so we don't need to broadcast the object size (which is an implementation detail of broadcast_object_list)?

Member Author

The key idea is not the CPU tensor, but that we know the maximum size of the serialized data, so we don't need to broadcast the length first. That length broadcast is indeed an implementation detail of broadcast_object_list.
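
To make the idea concrete, here is a minimal sketch (a hypothetical helper, not the code in this PR) of broadcasting metadata with a single collective call when an upper bound on the pickled size is known on every rank:

import pickle
import torch
import torch.distributed as dist

MAX_METADATA_BYTES = 1024  # assumed upper bound, known a priori on all ranks

def broadcast_metadata(obj=None, src=0):
    # assumes an initialized process group that supports CPU tensors (e.g. gloo)
    buf = torch.zeros(MAX_METADATA_BYTES, dtype=torch.uint8)
    if dist.get_rank() == src:
        data = pickle.dumps(obj)
        assert len(data) <= MAX_METADATA_BYTES
        buf[:len(data)] = torch.tensor(bytearray(data), dtype=torch.uint8)
    dist.broadcast(buf, src=src)  # one collective, no separate length broadcast
    # pickle stops at its STOP opcode, so the trailing zero padding is ignored
    return pickle.loads(buf.numpy().tobytes())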

Collaborator

Can you add to the comment that it relies on that implementation detail?

dtypes). If `cls` is provided, we can know the length of the metadata
roughly and allocate a buffer for it, then broadcasting metadata requires
only one broadcast call. Otherwise, we need to broadcast the metadata
length first, then broadcast the metadata.
Collaborator

Can you add a simple example of how to use TensorDictWithBoundedMetadata in the docstring?

@youkaichao
Member Author

"improve the broadcast for prepare input in the same way."

It will require another PR.

Also can you tell me the perf improvement from it?

For broadcasting blocks to swap/copy, the benefit is:

Broadcast time (before): 0.38772106170654297ms
Broadcast time (after): 0.128173828125ms

I don't have an end-to-end benchmarking.

update test_swap

It would require quite a large modification to the test procedure (separating the test into distributed tests). Meanwhile, the correctness is already checked in https://github.com/vllm-project/vllm/pull/4757/files#diff-cba46ef2b8ff23834781fa3b43794a3f19ffc6b4f1ec2353a8d13d1cdc2d0588R110 .

@youkaichao
Member Author

@rkooo567 can you help take a look at https://buildkite.com/vllm/ci/builds/7258#018f732a-46ad-4e69-a35b-25f5200d0e19 ? The failure looks like a Ray issue: the function cannot access the name List, although it is imported at the top of the file.

@njhill njhill self-requested a review May 13, 2024 19:07
@njhill
Collaborator

njhill commented May 13, 2024

@youkaichao it would be good to check whether there's a non-negligible performance difference in end-to-end tests before introducing the additional complexity; it's not always easy to infer this from a microbenchmark. A simple before/after test generating a decent number of tokens with a TP deployment would be sufficient, I think?

Do you know how much of the latency benefit comes from compressing the number of bytes with the new TensorMeta class vs eliminating one of the broadcasts?

The two broadcast_tensor_dicts (this one and prepare_input_tensors) are done immediately after each other, and it does not look like the second depends on the first; could we combine them?

Especially if they're combined, I'm wondering whether we can avoid the quite convoluted abstractions for what is just a single case.

Implementation-wise, instead of requiring the custom classes, what do you think about this:

  • Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of per class). Its initial size can be like 16 bytes.
  • Use the first byte of the broadcast tensor to indicate the kind of message that follows, either
    1. a regular pickled object
    2. a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.

Then there's no need to maintain special classes. I also don't think there's any need for special handling of the keys; we can just pass lists instead of dicts?
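
A rough sketch of that scheme (hypothetical code with made-up names, only to illustrate the proposal) could look like:

import pickle
import struct
import torch
import torch.distributed as dist

MSG_OBJECT = 0  # remaining bytes hold a regular pickled object
MSG_RESIZE = 1  # next 8 bytes encode the new buffer size; the broadcast is then repeated

_buffer = torch.zeros(16, dtype=torch.uint8)  # single static buffer, grown on demand

def broadcast_object(obj=None, src=0):
    # assumes an initialized process group that supports CPU tensors (e.g. gloo)
    global _buffer
    while True:
        if dist.get_rank() == src:
            data = pickle.dumps(obj)
            if 1 + len(data) > _buffer.numel():
                # tell all ranks to grow their buffer, then broadcast again
                payload = bytes([MSG_RESIZE]) + struct.pack("<Q", 1 + len(data))
            else:
                payload = bytes([MSG_OBJECT]) + data
            _buffer[:len(payload)] = torch.tensor(bytearray(payload), dtype=torch.uint8)
        dist.broadcast(_buffer, src=src)
        raw = _buffer.numpy().tobytes()
        if raw[0] == MSG_RESIZE:
            new_size = struct.unpack("<Q", raw[1:9])[0]
            _buffer = torch.zeros(new_size, dtype=torch.uint8)
            continue  # the object itself arrives in the repeated broadcast
        return pickle.loads(raw[1:])

In the common case where the buffer is already large enough, this is a single broadcast per call.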

@njhill
Copy link
Collaborator

njhill commented May 13, 2024

@youkaichao another reason the above approach might be better - IIUC the get_example_metadata_list approach won't work if the size varies much at runtime (not sure whether that might be the case for prepare_input_tensors)?

@youkaichao
Member Author

First of all, this PR is the first step toward later optimizations. By itself it is a pure benefit, because it reduces the broadcast from twice to once.

The follow-up applying the optimization to prepare input needs to come after the refactor #4681.

Within broadcast_tensor_dict, maintain a single static tensor buffer (instead of per class). Its initial size can be like 16 bytes.
Use the first byte of the broadcast tensor to indicate the kind of message that follows, either
a regular pickled object
a buffer resize, where the following 4 or 8 bytes are the encoded int size that the buffer should be increased to. After this, the broadcast is repeated.

This does not reduce the number of broadcasts. It still requires two broadcasts even if we don't have any tensor data to broadcast.

@rkooo567
Collaborator

rkooo567 commented May 14, 2024

I think the nice benchmarks to back this up are:

  • How much is the input broadcast overhead in e2e latency? -> In our internal benchmark, we found this overhead is "very big" (https://docs.google.com/spreadsheets/d/1GMyebF9XwlLJzpkpRxZrzUNcQTSibHlQ7zifldaDPtI/edit#gid=0), almost as big as the model forward pass at high TP.
  • I think there are 2 parts we can optimize: 1. reduce the overhead of broadcast_object_list; 2. reduce the number of tensor broadcasts (we do one broadcast per tensor). I think this PR tackles 1, and it'd be great to know how much 1 contributes to the e2e broadcasting overhead (I believe @youkaichao already has the number).

Collaborator

@rkooo567 rkooo567 left a comment

For this PR, I think it LGTM.

I prefer to avoid having TensorMeta inside __init__.py unless the perf diff is very big. I will approve it for now, but please resolve the discussion with @njhill before merging it!



@youkaichao
Member Author

@njhill has a proposal to cache the max length of metadata based on the call site; I will wait and see how it works.

@njhill
Collaborator

njhill commented May 16, 2024

@youkaichao I've opened #4844 to show the idea, PTAL!
