
[Core][Distributed] refactor custom allreduce to support multiple tp groups #4754

Merged: 26 commits merged into vllm-project:main from ca_refactor on May 13, 2024

Conversation

@youkaichao (Member) commented May 10, 2024:

Previously, custom allreduce was attached to a module and bound only to the world group.

With this PR, it is correctly bound to the TP group (a rough sketch of the idea follows below).

  • remove import inside function (need to remove default group None option)
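For illustration, here is a minimal sketch of the idea: create the tensor-parallel subgroups with plain torch.distributed and bind the communicator to the caller's TP group rather than the world group. The `Comm` class and `init_tp_comm` helper are hypothetical stand-ins, not the actual vLLM `CustomAllreduce` API, and the sketch assumes `torch.distributed` is already initialized with a world size divisible by `tp_size`.

```python
import torch
import torch.distributed as dist


class Comm:
    """Stand-in communicator whose all_reduce only spans the group it holds."""

    def __init__(self, group: dist.ProcessGroup):
        self.group = group

    def all_reduce(self, t: torch.Tensor) -> torch.Tensor:
        dist.all_reduce(t, group=self.group)  # reduces across self.group only
        return t


def init_tp_comm(tp_size: int) -> Comm:
    """Split the world into TP groups of tp_size ranks and bind to ours."""
    world_size = dist.get_world_size()
    rank = dist.get_rank()
    my_group = None
    # Every rank must call new_group() for every group, even those it is not in.
    for start in range(0, world_size, tp_size):
        ranks = list(range(start, start + tp_size))
        group = dist.new_group(ranks)
        if rank in ranks:
            my_group = group
    # The point of this PR: bind the communicator to the TP group,
    # not to the world group.
    return Comm(group=my_group)
```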

@WoosukKwon self-assigned this May 12, 2024

@WoosukKwon (Collaborator) commented:

@hanzhi713 Could you please also take a look if you have time?

@WoosukKwon (Collaborator) left a review:

@youkaichao Thanks for submitting the PR and many thanks for walking through the PR offline. Please check out my comments, most of which are style issues.

tests/distributed/test_comm_ops.py (outdated review thread, resolved)
tests/distributed/test_custom_all_reduce.py (outdated review thread, resolved)
Comment on an excerpt from vllm/distributed/parallel_state.py:

```python
from vllm.distributed.device_communicators.pynccl import PyNcclCommunicator
_TP_PYNCCL_COMMUNICATOR = PyNcclCommunicator(
    group=_TP_CPU_GROUP,
    device=_LOCAL_RANK,
)

# Initialize a custom fast all-reduce implementation.
if _ENABLE_CUSTOM_ALL_REDUCE:
    from vllm.distributed.device_communicators.custom_all_reduce import (
```
Collaborator: Why do we need lazy import here?

@youkaichao (Member, Author):
The circular import:

vllm/distributed/__init__.py imports parallel_state and communication_op. If parallel_state imported from vllm.distributed.device_communicators.custom_all_reduce at the top level, that would be a circular import, because custom_all_reduce imports get_tensor_model_parallel_cpu_group from parallel_state.

Therefore, either parallel_state or custom_all_reduce has to use a lazy import to break the cycle.

I use the lazy import in parallel_state, to be consistent with how we import pynccl.
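For illustration, a minimal two-file sketch of the lazy-import pattern; the module names `a.py` and `b.py` and everything in them are hypothetical, not the actual vLLM layout:

```python
# a.py (plays the role of parallel_state)
_COMM = None


def get_cpu_group():
    return "cpu-group"  # placeholder for the real process group object


def init_communicator():
    global _COMM
    # Lazy import: b imports get_cpu_group from this module at import time,
    # so importing b at the top of a.py would create a circular import.
    # Importing inside the function defers it until both modules are loaded.
    from b import Communicator
    _COMM = Communicator(group=get_cpu_group())


# b.py (plays the role of custom_all_reduce)
from a import get_cpu_group  # safe: a.py does not import b at module level


class Communicator:
    def __init__(self, group=None):
        self.group = group if group is not None else get_cpu_group()
```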

vllm/distributed/communication_op.py (outdated review thread, resolved)

Comment on lines 38 to 39:

```python
for sz in test_sizes:
    for dtype in [torch.float32, torch.float16, torch.bfloat16]:
```
Collaborator:
nit: Maybe we can use dtype as an input parameter so that the dtypes can be tested separately?

@youkaichao (Member, Author):
Quite difficult. The input of test_custom_allreduce is coupled with the input of multi_process_tensor_parallel, which is used elsewhere. In other words, we cannot modify the input parameters inside just tests/distributed/test_custom_all_reduce.py 🤣
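For reference, a hypothetical sketch of what the suggestion would look like with dtype as a pytest parameter; this was not applied in the PR because the worker signature is shared via multi_process_tensor_parallel, and the single-process body below is only a stand-in for the real distributed check:

```python
import pytest
import torch

# Hypothetical sizes; the real test_sizes list in the test file may differ.
TEST_SIZES = [512, 4096, 32768]


@pytest.mark.parametrize(
    "dtype",
    [torch.float32, torch.float16, torch.bfloat16],
    ids=["float32", "float16", "bfloat16"],
)
@pytest.mark.parametrize("sz", TEST_SIZES)
def test_allreduce_shapes(sz: int, dtype: torch.dtype) -> None:
    # Each (size, dtype) pair becomes its own pytest case, so a single dtype
    # can be selected or fail independently (e.g. pytest -k bfloat16).
    t = torch.ones(sz, dtype=dtype)
    assert t.shape == (sz,) and t.dtype == dtype
```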

vllm/distributed/device_communicators/custom_all_reduce.py (two outdated review threads, resolved)

@WoosukKwon removed their assignment May 12, 2024
@bingfengyiren commented May 12, 2024 via email

@WoosukKwon (Collaborator) left a comment:

LGTM! Thanks for addressing my comments!

@youkaichao enabled auto-merge (squash) May 12, 2024 23:26

@hanzhi713 (Contributor): LGTM

auto-merge was automatically disabled May 13, 2024 00:46 (base branch was modified)

@WoosukKwon merged commit 702bee4 into vllm-project:main May 13, 2024
53 of 55 checks passed
@youkaichao deleted the ca_refactor branch May 13, 2024 00:49
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024