Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Core] Implement sharded state loader #4690

Merged
merged 22 commits into from May 16, 2024

Conversation

aurickq
Copy link
Contributor

@aurickq aurickq commented May 8, 2024

This PR implements a new model loader that directly loads the sharded states of each worker when using DistributedGPUExecutor. When using tensor parallelism, this avoids each worker reading the full checkpoint just to load a small shard of it. Our tests using Arctic (#4652) showed a 10x improvement in model loading speed from NVMe when using 8x tensor parallelism.

For quantization methods like DeepSpeed's FP_Quantize (also used in #4652) that quantize after loading, this PR allows easy creation of a quantized checkpoint that is directly loaded into each worker, further speeding up model loading.

This PR is separated out from #4652.

BEFORE SUBMITTING, PLEASE READ THE CHECKLIST BELOW AND FILL IN THE DESCRIPTION ABOVE


PR Checklist (Click to Expand)

Thank you for your contribution to vLLM! Before submitting the pull request, please ensure the PR meets the following criteria. This helps vLLM maintain the code quality and improve the efficiency of the review process.

PR Title and Classification

Only specific types of PRs will be reviewed. The PR title is prefixed appropriately to indicate the type of change. Please use one of the following:

  • [Bugfix] for bug fixes.
  • [CI/Build] for build or continuous integration improvements.
  • [Doc] for documentation fixes and improvements.
  • [Model] for adding a new model or improving an existing model. Model name should appear in the title.
  • [Frontend] For changes on the vLLM frontend (e.g., OpenAI API server, LLM class, etc.)
  • [Kernel] for changes affecting CUDA kernels or other compute kernels.
  • [Core] for changes in the core vLLM logic (e.g., LLMEngine, AsyncLLMEngine, Scheduler, etc.)
  • [Hardware][Vendor] for hardware-specific changes. Vendor name should appear in the prefix (e.g., [Hardware][AMD]).
  • [Misc] for PRs that do not fit the above categories. Please use this sparingly.

Note: If the PR spans more than one category, please include all relevant prefixes.

Code Quality

The PR need to meet the following code quality standards:

  • We adhere to Google Python style guide and Google C++ style guide.
  • Pass all linter checks. Please use format.sh to format your code.
  • The code need to be well-documented to ensure future contributors can easily understand the code.
  • Include sufficient tests to ensure the project to stay correct and robust. This includes both unit tests and integration tests.
  • Please add documentation to docs/source/ if the PR modifies the user-facing behaviors of vLLM. It helps vLLM user understand and utilize the new features or changes.

Notes for Large Changes

Please keep the changes as concise as possible. For major architectural changes (>500 LOC excluding kernel/data/config/test), we would expect a GitHub issue (RFC) discussing the technical design and justification. Otherwise, we will tag it with rfc-required and might not go through the PR.

What to Expect for the Reviews

The goal of the vLLM team is to be a transparent reviewing machine. We would like to make the review process transparent and efficient and make sure no contributor feel confused or frustrated. However, the vLLM team is small, so we need to prioritize some PRs over others. Here is what you can expect from the review process:

  • After the PR is submitted, the PR will be assigned to a reviewer. Every reviewer will pick up the PRs based on their expertise and availability.
  • After the PR is assigned, the reviewer will provide status update every 2-3 days. If the PR is not reviewed within 7 days, please feel free to ping the reviewer or the vLLM team.
  • After the review, the reviewer will put an action-required label on the PR if there are changes required. The contributor should address the comments and ping the reviewer to re-review the PR.
  • Please respond to all comments within a reasonable time frame. If a comment isn't clear or you disagree with a suggestion, feel free to ask for clarification or discuss the suggestion.

Thank You

Finally, thank you for taking the time to read these guidelines and for your interest in contributing to vLLM. Your contributions make vLLM a great tool for everyone!

@Yard1
Copy link
Collaborator

Yard1 commented May 9, 2024

We should add a test to ensure it's working correctly (loaded weights and outputs are the same as with the default loader)

Comment on lines 44 to 62
def main(args):
engine_args = EngineArgs.from_cli_args(args)
model_path = engine_args.model
if not Path(model_path).is_dir():
raise ValueError("model path must be a local directory")
# Create LLM instance from arguments
llm = LLM(**dataclasses.asdict(engine_args))
# Prepare output directory
Path(args.output).mkdir(exist_ok=True)
# Dump worker states to output directory
model_executor = llm.llm_engine.model_executor
model_executor.save_sharded_state(path=args.output,
pattern=args.pattern,
max_size=5 * 1024**3)
# Copy metadata files to output directory
for file in os.listdir(model_path):
if not any(
file.endswith(ext) for ext in (".bin", ".pt", ".safetensors")):
shutil.copy(f"{model_path}/{file}", args.output)
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

shouldn't this set the config.json or quant_config.json next to the model weights to inform vLLM loading the model what type of quantization the model checkpoint is in?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Currently, it's just copying the config.json and quant_config.json from the input checkpoint that's being converted, which works for the use cases we've tested. Actually, I am not sure if it's correct to override these configs (or add a quant_config.json where it wasn't there previously) because then the config may mismatch the final loaded states?

@aurickq
Copy link
Contributor Author

aurickq commented May 9, 2024

We should add a test to ensure it's working correctly (loaded weights and outputs are the same as with the default loader)

Sure, that makes sense. What's the spec of the test runner machine we should target? OK to assume cuda device is present?

@Yard1
Copy link
Collaborator

Yard1 commented May 9, 2024

@aurickq IIRC we are using L4 GPUs in CI, @simon-mo to confirm

@aurickq
Copy link
Contributor Author

aurickq commented May 9, 2024

@Yard1 added test, please take a look

examples/save_sharded_state.py Outdated Show resolved Hide resolved
examples/save_sharded_state.py Outdated Show resolved Hide resolved
tests/test_sharded_state_loader.py Outdated Show resolved Hide resolved
tests/test_sharded_state_loader.py Outdated Show resolved Hide resolved
vllm/model_executor/model_loader/loader.py Outdated Show resolved Hide resolved
@simon-mo
Copy link
Collaborator

One question I have is that can this be implemented using safetensor's partial read? safetensors have all the metadata in headers so you can access the tensors partially

@aurickq
Copy link
Contributor Author

aurickq commented May 10, 2024

One question I have is that can this be implemented using safetensor's partial read? safetensors have all the metadata in headers so you can access the tensors partially

Conceptually, I think so. Though for larger models like Arctic we prefer this implementation for a few reasons:

  1. With the current implementation and multi-node inference, we may download only the required files to each node.
  2. The deepspeedfp quantizer will reshape the parameters and may change its number of elements, which makes it a bit tricky to calculate the right slice to load (need to understand internals of the quantizer to slice the quantized weights).

@njhill
Copy link
Collaborator

njhill commented May 10, 2024

@aurickq curious how this relates to #3729?

@aurickq
Copy link
Contributor Author

aurickq commented May 10, 2024

@aurickq curious how this relates to #3729?

Ah, I hadn't seen that PR, thanks for bringing it up! From what I understand, both PRs address the problem of model loading speed, particularly for larger models, but the approaches seem pretty different. #3729 modifies the existing model loading path so each worker loads a different set of tensors, then broadcasts the shards to each other. This PR loads the model once using the default path and dumps each worker state to disk, then subsequent model loads can direct read these states, skipping the default model loading path altogether.

Some first thoughts on the pros/cons:

Copy link
Collaborator

@Yard1 Yard1 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks, two last comments

vllm/model_executor/model_loader/loader.py Outdated Show resolved Hide resolved
vllm/model_executor/model_loader/loader.py Outdated Show resolved Hide resolved
@WoosukKwon WoosukKwon self-requested a review May 10, 2024 23:32
Copy link
Collaborator

@WoosukKwon WoosukKwon left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM! The code looks clean and mergable. Thanks for addressing the comments from @Yard1.

@zhisbug zhisbug merged commit 30e7543 into vllm-project:main May 16, 2024
41 of 48 checks passed
robertgshaw2-neuralmagic pushed a commit to neuralmagic/nm-vllm that referenced this pull request May 19, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
dtrifiro pushed a commit to dtrifiro/vllm that referenced this pull request May 21, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
tybalex pushed a commit to tybalex/vllm-function-call that referenced this pull request May 25, 2024
Co-authored-by: Woosuk Kwon <woosuk.kwon@berkeley.edu>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
None yet
Projects
None yet
Development

Successfully merging this pull request may close these issues.

None yet

9 participants