Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for compressed-tensors #159

Open
wants to merge 18 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
1 change: 1 addition & 0 deletions requirements-build.txt
Original file line number Diff line number Diff line change
Expand Up @@ -4,4 +4,5 @@ ninja
packaging
setuptools>=49.4.0
torch==2.2.1
compressed-tensors
wheel
40 changes: 20 additions & 20 deletions tests/models/test_load_compressed_tensors_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,20 +4,23 @@
The expectation is for the inference result in same
behavior
"""
import pytest
from typing import Tuple

import pytest
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

# pair of same models with compressed and ordinary safetensors
MODELS = [
("neuralmagic/llama2.c-stories110M-pruned50", # uncompressed
"dtransposed/llama2.c-stories110M-pruned50-compressed-tensors") # compressed
]
MODELS = [(
"neuralmagic/llama2.c-stories110M-pruned50", # uncompressed
"dtransposed/llama2.c-stories110M-pruned50-compressed-tensors"
) # compressed
]


@pytest.mark.parametrize("model_pair", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
Expand All @@ -31,24 +34,21 @@ def test_models(

model_uncompressed, model_compressed = model_pair


vllm_model_0 = vllm_runner_nm(model_uncompressed,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)
dtype=dtype,
max_model_len=MODEL_MAX_LEN)

vllm_outputs_0 = vllm_model_0.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)
vllm_outputs_0 = vllm_model_0.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

del vllm_model_0

vllm_model_1 = vllm_runner_nm(model_compressed ,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)

vllm_outputs_1 = vllm_model_1.generate_greedy_logprobs(example_prompts,
max_tokens,
num_logprobs)

vllm_model_1 = vllm_runner_nm(model_compressed,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)

vllm_outputs_1 = vllm_model_1.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

del vllm_model_1

Expand Down
76 changes: 59 additions & 17 deletions vllm/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,11 @@
import json
import os
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from enum import Enum
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union

import torch
from compressed_tensors import SPARSITY_CONFIG_NAME
from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig
from packaging.version import Version
from transformers import PretrainedConfig

Expand All @@ -22,6 +23,11 @@
_GB = 1 << 30


class SparsityStructure(Enum):
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
sparse_w16a16 = "sparse_w16a16"
semi_structured_sparse_w16a16 = "semi_structured_sparse_w16a16"


class ModelConfig:
"""Configuration for the model.

Expand Down Expand Up @@ -183,34 +189,70 @@ def _verify_tokenizer_mode(self) -> None:

# UPSTREAM SYNC: keep sparsity
def _verify_sparsity(self) -> None:
supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]
supported_sparsity = {
sparsity_structure.value
for sparsity_structure in SparsityStructure
}

hf_sparsity_config = getattr(self.hf_config, SPARSITY_CONFIG_NAME,
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
None)
if hf_sparsity_config is not None:
if "sparsity_structure" not in hf_sparsity_config:
raise ValueError(
"Detected HuggingFace sparsity config for the model, "
"but it does not have a mandatory "
"attribute: `sparsity_structure` ")
hf_sparsity_structure = str(
hf_sparsity_config["sparsity_structure"]).lower()
sparsity_structure = self._sparsity_structure_from_config(
hf_sparsity_config, dtype=self.dtype)
if self.sparsity is not None:
logger.info(
"Overriding the sparsity structure from the config: "
f"{hf_sparsity_structure} with: {self.sparsity}")
self.sparsity = self.sparsity or hf_sparsity_structure
if self.sparsity not in supported_sparsity:
logger.warning(
logger.info("Overriding the sparsity structure "
"inferred from the config: "
f"{sparsity_structure} with: {self.sparsity}")
self.sparsity = self.sparsity or sparsity_structure
if self.sparsity not in supported_sparsity and self.sparsity is not None: # noqa E501
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is the # noqa E501 necessary? Could the line just get split?

raise ValueError(
f"Unknown sparsity_structure: {self.sparsity}. Must "
f"be one of {supported_sparsity}. Running the models "
"without sparse kernels.")
self.sparsity = None

if self.quantization is not None and self.sparsity is not None:
raise ValueError("Both sparsity and quantization detected. Only "
"one or the other is supported at a time.")

@staticmethod
def _sparsity_structure_from_config(
sparsity_config: Dict[str, Any],
dtype: torch.dtype) -> SparsityStructure:
"""
Translate from the sparsity_config to an appropriate sparsity structure.

:param sparsity_config: A dictionary specifying the sparsity config
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
:return The appropriate sparsity structure as string
"""
supported_sparsity_dtypes = {torch.float16, torch.bfloat16}

# check the validy of sparsity_config
dbogunowicz marked this conversation as resolved.
Show resolved Hide resolved
potentially_missing_keys = set(sparsity_config.keys()).difference(
CompressionConfig.model_fields.keys())
if potentially_missing_keys:
raise ValueError("The detected sparsity_config is "
f"missing keys: {potentially_missing_keys}")

# check for valid dtype
if dtype not in supported_sparsity_dtypes:
logger.warning(
"Sparsity is only supported for float16 and bfloat16 "
"dtypes. Running the models without sparse kernels.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add the dtype in the message?

Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I actually meant the current dtype, but supported dtypes are good too!

Suggested change
"dtypes. Running the models without sparse kernels.")
f"dtypes, not {dtype}. Running the models without sparse kernels.")

return None

# choose the sparsity structure based on the sparsity config
if sparsity_config["sparsity_structure"] in {"unstructured", "0:0"}:
return SparsityStructure.sparse_w16a16

elif sparsity_config["sparsity_structure"] == "2:4":
return SparsityStructure.semi_structured_sparse_w16a16

# if the sparsity config is not recognized, return None
logger.warning("The valid sparsity structure cannot be inferred from "
"the valid sparsity config. Running the models without "
"sparse kernels.")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

could you add the sparsity_config to the message?

return None

def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
rocm_not_supported_quantization = ["awq", "marlin"]
Expand Down