Support for compressed-tensors #159

Open: wants to merge 18 commits into base: main
Changes from 16 commits
1 change: 1 addition & 0 deletions requirements-common.txt
@@ -4,6 +4,7 @@ psutil
sentencepiece # Required for LLaMA tokenizer.
numpy
requests
compressed-tensors
py-cpuinfo
transformers >= 4.39.1 # Required for StarCoder2 & Llava.
fastapi
70 changes: 35 additions & 35 deletions tests/engine/test_stop_strings.py
@@ -4,7 +4,7 @@

from vllm import CompletionOutput, LLMEngine, SamplingParams

MODEL = "meta-llama/llama-2-7b-hf"
MODEL = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
Review comment (Member): This is an upstream test AFAIK, so we should avoid landing changes here.
MAX_TOKENS = 200


@@ -16,66 +16,66 @@ def vllm_model(vllm_runner):
@pytest.mark.skip_global_cleanup
def test_stop_basic(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["."],
stop=["in"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=".")
expected_output="\nVLLM is a company that specializes ",
expected_reason="in")

_test_stopping(vllm_model.model.llm_engine,
stop=["."],
stop=["in"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization.",
expected_reason=".")
expected_output="\nVLLM is a company that specializes in",
expected_reason="in")


@pytest.mark.skip_global_cleanup
def test_stop_multi_tokens(vllm_model):
_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer organization. We are a ",
expected_reason="group of peo")
_test_stopping(vllm_model.model.llm_engine,
stop=["providing virtual", "short"],
include_in_output=False,
expected_output="\nVLLM is a company that specializes in ",
expected_reason="providing virtual")

_test_stopping(
vllm_model.model.llm_engine,
stop=["group of peo", "short"],
include_in_output=True,
expected_output=
"VLLM is a 100% volunteer organization. We are a group of peo",
expected_reason="group of peo")
_test_stopping(vllm_model.model.llm_engine,
stop=["providing virtual", "short"],
include_in_output=True,
expected_output=
"\nVLLM is a company that specializes in providing virtual",
expected_reason="providing virtual")


@pytest.mark.skip_global_cleanup
def test_stop_partial_token(vllm_model):
_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
stop=["izes"],
include_in_output=False,
expected_output="VLLM is a 100% volunteer or",
expected_reason="gani")
expected_output="\nVLLM is a company that special",
expected_reason="izes")

_test_stopping(vllm_model.model.llm_engine,
stop=["gani"],
stop=["izes"],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organi",
expected_reason="gani")
expected_output="\nVLLM is a company that specializes",
expected_reason="izes")


@pytest.mark.skip_global_cleanup
def test_stop_token_id(vllm_model):
# token id 13013 => " organization"
# token id 6901 => "virtual"

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
include_in_output=False,
expected_output="VLLM is a 100% volunteer",
expected_reason=13013)
_test_stopping(
vllm_model.model.llm_engine,
stop_token_ids=[6901],
include_in_output=False,
expected_output="\nVLLM is a company that specializes in providing",
expected_reason=6901)

_test_stopping(vllm_model.model.llm_engine,
stop_token_ids=[13013],
stop_token_ids=[6901],
include_in_output=True,
expected_output="VLLM is a 100% volunteer organization",
expected_reason=13013)
expected_output=
"\nVLLM is a company that specializes in providing virtual",
expected_reason=6901)


def _test_stopping(llm_engine: LLMEngine,
61 changes: 61 additions & 0 deletions tests/models/test_load_compressed_tensors_model.py
@@ -0,0 +1,61 @@
"""
Compare the outputs from identical models:
- one that is loaded from uncompressed safetensors
- one that is loaded from `compressed-tensors`.
The expectation is that inference results in the same behavior.
"""
from typing import Tuple

import pytest
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

# pair of the same model with weights saved as
# compressed-tensors (compressed)
# and ordinary safetensors (uncompressed)
MODELS = [("neuralmagic/llama2.c-stories110M-pruned50",
"nm-testing/llama2.c-stories110M-pruned50-compressed-tensors")]


@pytest.mark.parametrize("model_pair", MODELS)
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
vllm_runner_nm,
example_prompts,
model_pair: Tuple[str, str],
dtype: str,
max_tokens: int,
num_logprobs: int,
) -> None:

model_uncompressed, model_compressed = model_pair

vllm_model_0 = vllm_runner_nm(model_uncompressed,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)

vllm_outputs_0 = vllm_model_0.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

del vllm_model_0

vllm_model_1 = vllm_runner_nm(model_compressed,
dtype=dtype,
max_model_len=MODEL_MAX_LEN)

vllm_outputs_1 = vllm_model_1.generate_greedy_logprobs(
example_prompts, max_tokens, num_logprobs)

del vllm_model_1

# make sure that both weight formats result in the same
# model, i.e. the same outputs
check_logprobs_close(
outputs_0_lst=vllm_outputs_0,
outputs_1_lst=vllm_outputs_1,
name_0="vllm_model_from_uncompressed_weights",
name_1="vllm_model_from_compressed_weights",
)
78 changes: 60 additions & 18 deletions vllm/config.py
@@ -2,13 +2,15 @@
import json
import os
from dataclasses import dataclass, fields
from typing import TYPE_CHECKING, ClassVar, Optional, Union
from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union

import torch
from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig
from packaging.version import Version
from transformers import PretrainedConfig

from vllm.logger import init_logger
from vllm.model_executor.layers.sparsity import SparsityStructures
from vllm.transformers_utils.config import get_config, get_hf_text_config
from vllm.utils import (get_cpu_memory, get_nvcc_cuda_version, is_cpu, is_hip,
is_neuron)
@@ -182,29 +184,69 @@ def _verify_tokenizer_mode(self) -> None:

# UPSTREAM SYNC: keep sparsity
def _verify_sparsity(self) -> None:
supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]
supported_sparsity = set(SparsityStructures.keys())

hf_sparsity_config = getattr(self.hf_config, SPARSITY_CONFIG_NAME,
None)

if hf_sparsity_config is not None:
sparsity_structure = self._sparsity_structure_from_config(
hf_sparsity_config, dtype=self.dtype)
if self.sparsity is not None:
logger.info("Overriding the sparsity structure "
"inferred from the config: "
f"{sparsity_structure} with: {self.sparsity}")
self.sparsity = self.sparsity or sparsity_structure
if (self.sparsity not in supported_sparsity) and \
(self.sparsity is not None):
raise ValueError(
f"Unknown sparsity_structure: {self.sparsity}. Must "
f"be one of {supported_sparsity}. Running the models "
"without sparse kernels.")
Review comment on lines +151 to +154 (Member): This says "Running the models without sparse kernels." but is raising an error - I think this should just warn and continue with unstructured sparse kernels?
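A minimal sketch of that suggestion (illustrative, not part of this diff), reusing supported_sparsity, logger, and the SparsityStructures registry already imported in this file:

# Sketch: warn and fall back to the unstructured sparse kernel instead of raising
if self.sparsity is not None and self.sparsity not in supported_sparsity:
    logger.warning(
        f"Unknown sparsity_structure: {self.sparsity}. Must be one of "
        f"{supported_sparsity}. Falling back to unstructured sparse kernels.")
    self.sparsity = SparsityStructures['sparse_w16a16'].name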


if self.quantization is not None and self.sparsity is not None:
raise ValueError("Both sparsity and quantization detected. Only "
"one or the other is supported at a time.")

if (self.sparsity is not None
and self.sparsity not in supported_sparsity):
raise ValueError(f"Unknown sparse method: {self.sparsity}. Must "
f"be one of {supported_sparsity}.")
@staticmethod
def _sparsity_structure_from_config(sparsity_config: Dict[str, Any],
dtype: torch.dtype) -> str:
"""
Translate from the sparsity_config to an appropriate sparsity structure.

:param sparsity_config: A dictionary specifying the sparsity config
:param dtype: The dtype of the model in question
:return: The appropriate sparsity structure as a string
"""
supported_sparsity_dtypes = {torch.float16, torch.bfloat16}

# check the validity of sparsity_config
unexpected_keys = set(sparsity_config.keys()).difference(
CompressionConfig.model_fields.keys())
if unexpected_keys:
raise ValueError("The detected sparsity_config contains "
f"unexpected keys: {unexpected_keys}")

# check for valid dtype
if dtype not in supported_sparsity_dtypes:
logger.warning(
f"Sparsity is only supported for {supported_sparsity_dtypes}"
f"dtypes, not {dtype}. "
"Running the models without sparse kernels.")
return None

hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None)
if hf_sparsity_config is not None:
hf_sparsity_method = str(
hf_sparsity_config["sparse_method"]).lower()
if self.sparsity is None:
self.sparsity = hf_sparsity_method
elif self.sparsity != hf_sparsity_method:
raise ValueError(
"Sparsity method specified in the model config "
f"({hf_sparsity_method}) does not match the sparsity "
f"method specified in the `sparsity` argument "
f"({self.sparsity}).")
# choose the sparsity structure based on the sparsity config
if sparsity_config["sparsity_structure"] in {"unstructured", "0:0"}:
return SparsityStructures['sparse_w16a16'].name

elif sparsity_config["sparsity_structure"] == "2:4":
return SparsityStructures['semi_structured_sparse_w16a16'].name

# if the sparsity config is not recognized, return None
logger.warning("The valid sparsity structure cannot be inferred from "
"the valid sparsity config:\n{sparsity_config}"
"\nRunning the models without sparse kernels.")
return None
Review comment (Member): Why not just make the exception for 2:4 and use unstructured kernels for all other cases?
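A minimal sketch of that alternative (illustrative, not part of this diff), keyed on the same sparsity_config field and registry names used above:

# Sketch: special-case 2:4, fall back to the unstructured kernel otherwise
if sparsity_config["sparsity_structure"] == "2:4":
    return SparsityStructures['semi_structured_sparse_w16a16'].name
# any other structure (e.g. "unstructured", "0:0") runs on the unstructured kernel
return SparsityStructures['sparse_w16a16'].name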


def _verify_quantization(self) -> None:
supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
28 changes: 20 additions & 8 deletions vllm/model_executor/layers/sparsity/__init__.py
@@ -1,4 +1,5 @@
import importlib.util
from collections import namedtuple
from typing import Type

is_magic_wand_available = importlib.util.find_spec("magic_wand") is not None
@@ -14,16 +15,27 @@
from vllm.model_executor.layers.sparsity.sparse_w16a16 import ( # noqa: E402
SparseW16A16Config)

_SPARSITY_CONFIG_REGISTRY = {
"sparse_w16a16": SparseW16A16Config,
"semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
}
# UPSTREAM SYNC: where we keep the sparsity configs
sparsity_structure_meta = namedtuple('SparsityStructure', ['name', 'config'])

SparsityStructures = dict(
sparse_w16a16=sparsity_structure_meta("sparse_w16a16", SparseW16A16Config),
semi_structured_sparse_w16a16=sparsity_structure_meta(
"semi_structured_sparse_w16a16", SemiStructuredSparseW16A16Config),
)
Review comment on lines -17 to +25 (Member): We want to keep to the "registry" structure to match up with quantization; see the quantization/__init__.py file:
QUANTIZATION_METHODS = {
    "aqlm": AQLMConfig,
    "awq": AWQConfig,
    "fp8": Fp8Config,
    "gptq": GPTQConfig,
    "squeezellm": SqueezeLLMConfig,
    "marlin": MarlinConfig,
}

So could we keep it as SPARSITY_METHODS and a raw dict?
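For reference, a sketch of what that could look like here; the SPARSITY_METHODS name is the reviewer's suggestion, not code in this PR:

SPARSITY_METHODS = {
    "sparse_w16a16": SparseW16A16Config,
    "semi_structured_sparse_w16a16": SemiStructuredSparseW16A16Config,
}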


def get_sparsity_config(sparsity: str) -> Type[SparsityConfig]:
if sparsity not in _SPARSITY_CONFIG_REGISTRY:
raise ValueError(f"Invalid sparsity method: {sparsity}")
return _SPARSITY_CONFIG_REGISTRY[sparsity]

# UPSTREAM SYNC: needed for sparsity
def get_sparsity_config(
model_config: "ModelConfig") -> Type[SparsityConfig]: # noqa: F821
# fetch the sparsity config from the model config
sparsity = model_config.sparsity
if sparsity not in SparsityStructures:
raise ValueError(
f"Invalid sparsity method: {sparsity}. "
f"Available sparsity methods: {list(SparsityStructures.keys())}")
sparsity_cls = SparsityStructures[sparsity].config
return sparsity_cls()


__all__ = [
4 changes: 2 additions & 2 deletions vllm/model_executor/model_loader.py
@@ -6,10 +6,10 @@
import torch.nn as nn

from vllm.config import DeviceConfig, ModelConfig
from vllm.model_executor.layers.sparsity import get_sparsity_config
from vllm.model_executor.models import ModelRegistry
from vllm.model_executor.models.llava import LlavaForConditionalGeneration
from vllm.model_executor.weight_utils import (get_quant_config,
get_sparse_config,
initialize_dummy_weights)

_VISION_MODEL_CLASSES = [
@@ -75,7 +75,7 @@ def get_model(model_config: ModelConfig, device_config: DeviceConfig,
linear_method = quant_config.get_linear_method()
# UPSTREAM SYNC: needed to support sparsity
if model_config.sparsity is not None:
sparse_config = get_sparse_config(model_config)
sparse_config = get_sparsity_config(model_config)
capability = torch.cuda.get_device_capability()
capability = capability[0] * 10 + capability[1]
if capability < sparse_config.get_min_capability():