Support for compressed-tensors #159

Open · wants to merge 18 commits into base: main

Changes from 8 commits
61 changes: 61 additions & 0 deletions tests/models/test_load_compressed_tensors_model.py
@@ -0,0 +1,61 @@
"""Compare the outputs from identical models:
- one that is loaded from uncompressed safetensors
- one that is loaded form `compressed-tensors`.
The expectation is for the inference result in same
behavior
"""
import pytest
from typing import Tuple
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

# pairs of the same model with compressed and ordinary safetensors
MODELS = [
    ("neuralmagic/llama2.c-stories110M-pruned50",  # uncompressed
     "dtransposed/llama2.c-stories110M-pruned50-compressed-tensors")  # compressed
Author: will move this model to neuralmagic repo before landing

]

@pytest.mark.parametrize("model_pair", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
Member: Why remove bfloat16 here?

@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
    vllm_runner_nm,
    example_prompts,
    model_pair: Tuple[str, str],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:
    model_uncompressed, model_compressed = model_pair

    vllm_model_0 = vllm_runner_nm(model_uncompressed,
                                  dtype=dtype,
                                  max_model_len=MODEL_MAX_LEN)
    vllm_outputs_0 = vllm_model_0.generate_greedy_logprobs(example_prompts,
                                                           max_tokens,
                                                           num_logprobs)
    del vllm_model_0

    vllm_model_1 = vllm_runner_nm(model_compressed,
                                  dtype=dtype,
                                  max_model_len=MODEL_MAX_LEN)
    vllm_outputs_1 = vllm_model_1.generate_greedy_logprobs(example_prompts,
                                                           max_tokens,
                                                           num_logprobs)
    del vllm_model_1

    # loop through the prompts and compare the top logprobs
    check_logprobs_close(
        outputs_0_lst=vllm_outputs_0,
        outputs_1_lst=vllm_outputs_1,
        name_0="vllm_model_from_uncompressed_weights",
        name_1="vllm_model_from_compressed_weights",
    )
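For reference, a minimal sketch of running only this comparison test, assuming a working vLLM development install and that the repository's pytest fixtures (`vllm_runner_nm`, `example_prompts`) are available in the test environment:

# Hypothetical invocation; path matches the test file added in this PR.
import pytest

pytest.main([
    "tests/models/test_load_compressed_tensors_model.py",
    "-k", "test_models",
])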
41 changes: 23 additions & 18 deletions vllm/config.py
@@ -5,6 +5,7 @@
from typing import TYPE_CHECKING, ClassVar, Optional, Union

import torch
from compressed_tensors import SPARSITY_CONFIG_NAME
from packaging.version import Version
from transformers import PretrainedConfig

@@ -184,28 +185,32 @@ def _verify_sparsity(self) -> None:
    def _verify_sparsity(self) -> None:
        supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]

        hf_sparsity_config = getattr(self.hf_config, SPARSITY_CONFIG_NAME,
                                     None)
        if hf_sparsity_config is not None:
            if "sparsity_structure" not in hf_sparsity_config:
                raise ValueError(
                    "Detected HuggingFace sparsity config for the model, "
                    "but it does not have a mandatory "
                    "attribute: `sparsity_structure`")
            hf_sparsity_structure = str(
                hf_sparsity_config["sparsity_structure"]).lower()
            if self.sparsity is not None:
                logger.info(
                    "Overriding the sparsity structure from the config: "
                    f"{hf_sparsity_structure} with: {self.sparsity}")
            self.sparsity = self.sparsity or hf_sparsity_structure
            if self.sparsity not in supported_sparsity:
                logger.warning(
                    f"Unknown sparsity_structure: {self.sparsity}. Must "
                    f"be one of {supported_sparsity}. Running the models "
                    "without sparse kernels.")
                self.sparsity = None

        if self.quantization is not None and self.sparsity is not None:
            raise ValueError("Both sparsity and quantization detected. Only "
                             "one or the other is supported at a time.")

        if (self.sparsity is not None
                and self.sparsity not in supported_sparsity):
            raise ValueError(f"Unknown sparse method: {self.sparsity}. Must "
                             f"be one of {supported_sparsity}.")

        hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None)
        if hf_sparsity_config is not None:
            hf_sparsity_method = str(
                hf_sparsity_config["sparse_method"]).lower()
            if self.sparsity is None:
                self.sparsity = hf_sparsity_method
            elif self.sparsity != hf_sparsity_method:
                raise ValueError(
                    "Sparsity method specified in the model config "
                    f"({hf_sparsity_method}) does not match the sparsity "
                    f"method specified in the `sparsity` argument "
                    f"({self.sparsity}).")

    def _verify_quantization(self) -> None:
        supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
        rocm_not_supported_quantization = ["awq", "marlin"]
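For context on the data `_verify_sparsity` consumes: a checkpoint saved with compressed-tensors carries a sparsity config entry in its HuggingFace config, looked up via `SPARSITY_CONFIG_NAME`. A minimal sketch of the shape the code above expects; the concrete values are illustrative assumptions, not taken from a real checkpoint:

# Illustrative sketch only: the attribute name is resolved through
# compressed_tensors.SPARSITY_CONFIG_NAME; the value below is assumed.
hf_sparsity_config = {
    # mandatory; _verify_sparsity raises ValueError if it is missing
    "sparsity_structure": "sparse_w16a16",
}
# After lower-casing, the structure must be one of
# ["sparse_w16a16", "semi_structured_sparse_w16a16"] for the sparse
# kernels to engage; otherwise vLLM logs a warning and falls back to
# dense execution (self.sparsity = None).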
24 changes: 19 additions & 5 deletions vllm/model_executor/weight_utils.py
@@ -11,6 +11,8 @@
import huggingface_hub.constants
import numpy as np
import torch
from compressed_tensors import infer_compressor_from_model_config
from compressed_tensors.config import CompressionFormat
from huggingface_hub import HfFileSystem, snapshot_download
from safetensors.torch import load_file, safe_open, save_file
from tqdm.auto import tqdm
@@ -293,11 +295,23 @@ def hf_model_weights_iterator(
                param = np.load(f)
            yield name, torch.from_numpy(param)
    elif use_safetensors:
        for st_file in hf_weights_files:
            with safe_open(st_file, framework="pt") as f:
                for name in f.keys():  # noqa: SIM118
                    param = f.get_tensor(name)
                    yield name, param
        # UPSTREAM SYNC: needed for loading compressed tensors
        # (see neural-magic/compressed-tensors repository)
        compressor = infer_compressor_from_model_config(hf_folder)
        compression_format = compressor.config.format if compressor else None
        if compressor is None or compression_format == CompressionFormat.dense_sparsity.value:  # noqa E501
            for st_file in hf_weights_files:
                with safe_open(st_file, framework="pt") as f:
                    for name in f.keys():  # noqa: SIM118
                        param = f.get_tensor(name)
                        yield name, param
        else:
            # a non-trivial (i.e. not dense) compressor was inferred;
            # the model's weights are compressed (sparse), so they
            # need to be decompressed before loading
            for name, param in compressor.decompress(hf_folder):
                yield name, param

    else:
        for bin_file in hf_weights_files:
            state = torch.load(bin_file, map_location="cpu")
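A minimal usage sketch of the decompression path above, restricted to the compressed-tensors calls that actually appear in this diff (the local folder path is a placeholder for a downloaded checkpoint):

from compressed_tensors import infer_compressor_from_model_config
from compressed_tensors.config import CompressionFormat

hf_folder = "/path/to/local/model/snapshot"  # assumed local checkpoint dir

# Returns None when the model config carries no compression config.
compressor = infer_compressor_from_model_config(hf_folder)
if (compressor is not None and
        compressor.config.format != CompressionFormat.dense_sparsity.value):
    # Weights are stored compressed; iterate over them decompressed,
    # exactly as hf_model_weights_iterator does above.
    for name, param in compressor.decompress(hf_folder):
        print(name, tuple(param.shape))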