
Support for compressed-tensors #159

Open · wants to merge 18 commits into main

Conversation

dbogunowicz

The goal of this PR is to support weight loading from the compressed safetensors representation.
The compressed safetensors representation was introduced by Neural Magic and implemented by @Satrat.

robertgshaw2-neuralmagic (Collaborator) left a comment:

Can you add an error message for the case where use_safetensors=False but we detect a compressor?
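
A minimal sketch of what such a guard could look like; the helper name and the compression_config attribute are illustrative assumptions, not this PR's actual code:

# Hypothetical guard in weight_utils.py; the helper name and the
# compression_config attribute are illustrative, not this PR's code.
def _check_compressed_tensors_supported(hf_config, use_safetensors: bool) -> None:
    compression_config = getattr(hf_config, "compression_config", None)
    if compression_config is not None and not use_safetensors:
        raise ValueError(
            "The model weights are stored in the compressed-tensors format, "
            "which is only supported for safetensors checkpoints. "
            "Please load the model with use_safetensors=True.")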

vllm/model_executor/weight_utils.py (outdated, resolved)
vllm/model_executor/weight_utils.py (outdated, resolved)
tests/model_executor/test_weight_utils.py (outdated, resolved)
robertgshaw2-neuralmagic (Collaborator) commented Apr 3, 2024

Really like how the .decompress() function looks. Makes the interface really clean

@dbogunowicz one additional thing that needs to be done:

Right now, the user has to manually specify that the sparse kernels should be used:

from vllm import LLM

# loads as sparse
model = LLM("/path/to/sparse/model", sparsity="sparse_w16a16")

# loads as dense
model = LLM("/path/to/sparse/model")

Ideally, we should automatically detect from the config whether the model is sparse and, if so, load it with the sparse kernels:

from vllm import LLM

# loads as sparse
model = LLM("/path/to/sparse/model")

This is how things work for quantization. When I originally integrated the sparse kernels, I left a placeholder for this logic here.

Can you add this?
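
For illustration, a rough sketch of what that auto-detection could look like (the sparsity_config attribute and the helper name are assumptions, not the actual placeholder in the code):

# Illustrative only: infer the sparsity kernel from the model config
# when the user did not pass sparsity= explicitly.
def _infer_sparsity(hf_config, user_sparsity=None):
    if user_sparsity is not None:
        return user_sparsity
    sparsity_config = getattr(hf_config, "sparsity_config", None)
    if sparsity_config is None:
        return None  # dense model: run without sparse kernels
    if sparsity_config.get("sparsity_structure") == "2:4":
        return "semi_structured_sparse_w16a16"
    return "sparse_w16a16"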

robertgshaw2-neuralmagic (Collaborator) commented Apr 3, 2024

Finally, please add some end-to-end testing that loads the compressed model and runs inference.

I would suggest the following format:

  • Take an existing small sparse model (neuralmagic/llama2.c-stories110M-pruned50)
  • Save a compressed version, push this model up to nm-testing
  • Use the tests/models/test_model_logprobs.py format to compare the outputs of the existing uncompressed version with the compressed version (a rough sketch follows below)
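
A rough sketch of such a test: the model pair is the one used later in this PR's test file, but the prompts, parameters, and greedy text comparison are illustrative, not the actual test_model_logprobs.py logic:

import pytest

from vllm import LLM, SamplingParams

# The uncompressed checkpoint and its compressed-tensors counterpart
# (the same pair used in this PR's weight-loading test).
MODEL_PAIRS = [
    ("neuralmagic/llama2.c-stories110M-pruned50",
     "dtransposed/llama2.c-stories110M-pruned50-compressed-tensors"),
]


@pytest.mark.parametrize("model_pair", MODEL_PAIRS)
def test_compressed_matches_uncompressed(model_pair):
    uncompressed, compressed = model_pair
    prompts = ["Once upon a time"]
    params = SamplingParams(temperature=0.0, max_tokens=32)

    ref_out = LLM(model=uncompressed, dtype="float16").generate(prompts, params)
    cmp_out = LLM(model=compressed, dtype="float16").generate(prompts, params)

    # With greedy decoding both checkpoints should generate identical text.
    assert ref_out[0].outputs[0].text == cmp_out[0].outputs[0].text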

requirements.txt (outdated, resolved)
dbogunowicz changed the title from "Support for compressed safetensor weights" to "Support for compressed-tensors" on Apr 23, 2024
# pair of same models with compressed and ordinary safetensors
MODELS = [
    ("neuralmagic/llama2.c-stories110M-pruned50",  # uncompressed
     "dtransposed/llama2.c-stories110M-pruned50-compressed-tensors")  # compressed
dbogunowicz (Author):

Will move this model to the neuralmagic repo before landing.

vllm/config.py (outdated, resolved)
vllm/config.py (outdated)

            "inferred from the config: "
            f"{sparsity_structure} with: {self.sparsity}")
        self.sparsity = self.sparsity or sparsity_structure
        if self.sparsity not in supported_sparsity and self.sparsity is not None: # noqa E501
Member:

Is the # noqa E501 necessary? Could the line just get split?
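
For example, the condition could be wrapped instead (a sketch reusing the names from the snippet above; the raise body is taken from the later comment on lines +202 to +205):

# Splitting the condition keeps the line within the length limit
# without the noqa suppression.
if (self.sparsity is not None
        and self.sparsity not in supported_sparsity):
    raise ValueError(
        f"Unknown sparsity_structure: {self.sparsity}. "
        f"Must be one of {supported_sparsity}.")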

vllm/config.py (outdated, resolved)
vllm/config.py (outdated)

Comment on lines 239 to 240:

    "Sparsity is only supported for float16 and bfloat16 "
    "dtypes. Running the models without sparse kernels.")
Member:

could you add the dtype in the message?

vllm/config.py (outdated)

Comment on lines 251 to 253:

logger.warning("The valid sparsity structure cannot be inferred from "
               "the valid sparsity config. Running the models without "
               "sparse kernels.")
Member:

could you add the sparsity_config to the message?

dbogunowicz and others added 3 commits on April 24, 2024, 16:27 (co-authored by Michael Goin <michael@neuralmagic.com>)
vllm/config.py (outdated)

@@ -238,21 +239,21 @@ def _sparsity_structure_from_config(
     # check for valid dtype
     if dtype not in supported_sparsity_dtypes:
         logger.warning(
-            "Sparsity is only supported for float16 and bfloat16 "
+            f"Sparsity is only supported for {supported_sparsity_dtypes}"
             "dtypes. Running the models without sparse kernels.")
Member:

I actually meant the current dtype, but supported dtypes are good too!

Suggested change:

-            "dtypes. Running the models without sparse kernels.")
+            f"dtypes, not {dtype}. Running the models without sparse kernels.")

@@ -20,7 +20,7 @@

 @pytest.mark.parametrize("model_pair", MODELS)
-@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+@pytest.mark.parametrize("dtype", ["float16"])
Member:

Why remove bfloat16 here?

vllm/config.py (outdated)

-        "the valid sparsity config. Running the models without "
-        "sparse kernels.")
+        "the valid sparsity config:\n{sparsity_config}"
+        "\n Running the models without sparse kernels.")
Member:

Suggested change:

-        "\n Running the models without sparse kernels.")
+        "\nRunning the models without sparse kernels.")

@@ -4,7 +4,7 @@

 from vllm import CompletionOutput, LLMEngine, SamplingParams

-MODEL = "meta-llama/llama-2-7b-hf"
+MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
Member:

This is an upstream test AFAIK, so we should avoid landing changes here

vllm/config.py (resolved)
Comment on lines +202 to +205:

raise ValueError(
    f"Unknown sparsity_structure: {self.sparsity}. Must "
    f"be one of {supported_sparsity}. Running the models "
    "without sparse kernels.")
Member:

This says "Running the models without sparse kernels." but is raising an error - I think this should just warn and continue with unstructured sparse kernels?
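
For example, a minimal sketch of that warn-and-fall-back behaviour, reusing the names from the snippet above (illustrative only, not this PR's code):

# Warn instead of raising, then fall back to the unstructured kernel.
if self.sparsity is not None and self.sparsity not in supported_sparsity:
    logger.warning(
        f"Unknown sparsity_structure: {self.sparsity}. Must be one of "
        f"{supported_sparsity}. Running the models with unstructured "
        "sparse kernels (sparse_w16a16).")
    self.sparsity = "sparse_w16a16"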

vllm/config.py (outdated)

Comment on lines 238 to 249:

# choose the sparsity structure based on the sparsity config
if sparsity_config["sparsity_structure"] in {"unstructured", "0:0"}:
    return SparsityStructures['sparse_w16a16'].name

elif sparsity_config["sparsity_structure"] == "2:4":
    return SparsityStructures['semi_structured_sparse_w16a16'].name

# if the sparsity config is not recognized, return None
logger.warning("The valid sparsity structure cannot be inferred from "
               "the valid sparsity config:\n{sparsity_config}"
               "\nRunning the models without sparse kernels.")
return None
Member:

Why not just make the exception for 2:4 and use unstructured kernels for all other cases?
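
A rough sketch of that alternative, reusing the names from the snippet above (illustrative only):

# Treat 2:4 as the only special case; everything else gets the
# unstructured kernel.
if sparsity_config["sparsity_structure"] == "2:4":
    return SparsityStructures['semi_structured_sparse_w16a16'].name
return SparsityStructures['sparse_w16a16'].name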

Comment on lines -17 to +25:

-_SPARSITY_CONFIG_REGISTRY = {
-    "sparse_w16a16": SparseW16A16Config,
-    "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
-}
+# UPSTREAM SYNC: where we keep the sparsity configs
+sparsity_structure_meta = namedtuple('SparsityStructure', ['name', 'config'])
+
+SparsityStructures = dict(
+    sparse_w16a16=sparsity_structure_meta("sparse_w16a16", SparseW16A16Config),
+    semi_structured_sparse_w16a16=sparsity_structure_meta(
+        "semi_structured_sparse_w16a16", SemiStructuredSparseW16A16Config),
+)
Member:

We want to keep the "registry" structure to match up with quantization; see the quantization/__init__.py file:

QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "fp8": Fp8Config,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "marlin": MarlinConfig,
}

So could we keep it as SPARSITY_METHODS and a raw dict?
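
For illustration, a minimal sketch of what that could look like (the name SPARSITY_METHODS comes from the comment above; the config classes are the ones from this PR):

# Raw-dict registry mirroring QUANTIZATION_METHODS.
SPARSITY_METHODS = {
    "sparse_w16a16": SparseW16A16Config,
    "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
}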
