Support for compressed-tensors
#159
base: main
Changes from 13 commits
@@ -4,4 +4,5 @@ ninja
 packaging
 setuptools>=49.4.0
 torch==2.2.1
+compressed-tensors
 wheel
@@ -0,0 +1,61 @@
"""Compare the outputs from two identical models:
- one loaded from uncompressed safetensors
- one loaded from `compressed-tensors`.
The expectation is that inference produces the same
behavior for both models.
"""
from typing import Tuple

import pytest
from compare_utils import check_logprobs_close

MODEL_MAX_LEN = 1024

# pair of the same model with compressed and ordinary safetensors
MODELS = [
    (
        "neuralmagic/llama2.c-stories110M-pruned50",  # uncompressed
        "dtransposed/llama2.c-stories110M-pruned50-compressed-tensors",  # compressed
    )
]


@pytest.mark.parametrize("model_pair", MODELS)
@pytest.mark.parametrize("dtype", ["float16"])
@pytest.mark.parametrize("max_tokens", [32])
@pytest.mark.parametrize("num_logprobs", [3])
def test_models(
    vllm_runner_nm,
    example_prompts,
    model_pair: Tuple[str, str],
    dtype: str,
    max_tokens: int,
    num_logprobs: int,
) -> None:

    model_uncompressed, model_compressed = model_pair

    vllm_model_0 = vllm_runner_nm(model_uncompressed,
                                  dtype=dtype,
                                  max_model_len=MODEL_MAX_LEN)

    vllm_outputs_0 = vllm_model_0.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)

    # free the first model to release GPU memory before loading the next
    del vllm_model_0

    vllm_model_1 = vllm_runner_nm(model_compressed,
                                  dtype=dtype,
                                  max_model_len=MODEL_MAX_LEN)

    vllm_outputs_1 = vllm_model_1.generate_greedy_logprobs(
        example_prompts, max_tokens, num_logprobs)

    del vllm_model_1

    # loop through the prompts and compare the top logprobs
    check_logprobs_close(
        outputs_0_lst=vllm_outputs_0,
        outputs_1_lst=vllm_outputs_1,
        name_0="vllm_model_from_uncompressed_weights",
        name_1="vllm_model_from_compressed_weights",
    )
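Note: `check_logprobs_close` comes from the shared test utilities and is not shown in this diff. As a rough illustration of the acceptance criterion it enforces, here is a minimal sketch, assuming each output is a (token_ids, text, logprobs) tuple whose logprobs entry holds, per position, a dict of top candidate token ids to log-probabilities (the real helper's format may differ):

# Illustrative sketch only -- not the actual check_logprobs_close.
def check_logprobs_close_sketch(outputs_0_lst, outputs_1_lst, name_0, name_1):
    for prompt_idx, (out_0, out_1) in enumerate(
            zip(outputs_0_lst, outputs_1_lst)):
        ids_0, _, logprobs_0 = out_0
        ids_1, _, logprobs_1 = out_1
        for pos, (tok_0, tok_1) in enumerate(zip(ids_0, ids_1)):
            if tok_0 == tok_1:
                continue
            # a greedy-token mismatch is tolerated only if each model's
            # choice still ranks among the other's top candidates
            assert tok_0 in logprobs_1[pos] and tok_1 in logprobs_0[pos], (
                f"prompt {prompt_idx}, position {pos}: "
                f"{name_0} picked {tok_0}, {name_1} picked {tok_1}")
            # once the sequences diverge, later tokens are not comparable
            break

The idea is that two representations of the same weights can still differ in floating-point rounding, so exact token equality is relaxed to top-k agreement.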
@@ -2,9 +2,11 @@
 import json
 import os
 from dataclasses import dataclass, fields
-from typing import TYPE_CHECKING, ClassVar, Optional, Union
+from enum import Enum
+from typing import TYPE_CHECKING, Any, ClassVar, Dict, Optional, Union

+import torch
+from compressed_tensors import SPARSITY_CONFIG_NAME, CompressionConfig
 from packaging.version import Version
 from transformers import PretrainedConfig
@@ -21,6 +23,12 @@
 _GB = 1 << 30


+# UPSTREAM SYNC: keep sparsity
+class SparsityStructure(Enum):
+
+    sparse_w16a16 = "sparse_w16a16"
+    semi_structured_sparse_w16a16 = "semi_structured_sparse_w16a16"


 class ModelConfig:
     """Configuration for the model.
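For context, the enum above is the target of the config translation added further down. A hypothetical sketch of the sparsity payload a model config might carry; the only field this diff actually reads is "sparsity_structure" (with the values "unstructured", "0:0", and "2:4"), while the other field names here are assumptions for illustration:

# hypothetical sparsity_config attached to a HF model config under
# whatever attribute name SPARSITY_CONFIG_NAME resolves to
sparsity_config = {
    "format": "sparse-bitmask",            # assumed example value
    "sparsity_structure": "unstructured",  # handled below: "unstructured", "0:0", "2:4"
    "global_sparsity": 0.5,                # assumed example value
}

# mapping implemented by _sparsity_structure_from_config below:
#   "unstructured" or "0:0" -> SparsityStructure.sparse_w16a16
#   "2:4"                   -> SparsityStructure.semi_structured_sparse_w16a16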
@@ -182,29 +190,71 @@ def _verify_tokenizer_mode(self) -> None:

     # UPSTREAM SYNC: keep sparsity
     def _verify_sparsity(self) -> None:
-        supported_sparsity = ["sparse_w16a16", "semi_structured_sparse_w16a16"]
+        supported_sparsity = {
+            sparsity_structure.value
+            for sparsity_structure in SparsityStructure
+        }
+
+        hf_sparsity_config = getattr(self.hf_config, SPARSITY_CONFIG_NAME,
+                                     None)
+        if hf_sparsity_config is not None:
+            sparsity_structure = self._sparsity_structure_from_config(
+                hf_sparsity_config, dtype=self.dtype)
+            if self.sparsity is not None:
+                logger.info("Overriding the sparsity structure "
+                            "inferred from the config: "
+                            f"{sparsity_structure} with: {self.sparsity}")
+            self.sparsity = self.sparsity or sparsity_structure
+        if (self.sparsity not in supported_sparsity) and \
+                (self.sparsity is not None):
+            raise ValueError(
+                f"Unknown sparsity_structure: {self.sparsity}. Must "
+                f"be one of {supported_sparsity}. Running the models "
+                "without sparse kernels.")

Review comment (on lines +151 to +154): This says "Running the models without sparse kernels." but it is raising an error - I think this should just warn and continue with unstructured sparse kernels?

         if self.quantization is not None and self.sparsity is not None:
             raise ValueError("Both sparsity and quantization detected. Only "
                              "one or the other is supported at a time.")

-        if (self.sparsity is not None
-                and self.sparsity not in supported_sparsity):
-            raise ValueError(f"Unknown sparse method: {self.sparsity}. Must "
-                             f"be one of {supported_sparsity}.")
-        hf_sparsity_config = getattr(self.hf_config, "sparsity_config", None)
-        if hf_sparsity_config is not None:
-            hf_sparsity_method = str(
-                hf_sparsity_config["sparse_method"]).lower()
-            if self.sparsity is None:
-                self.sparsity = hf_sparsity_method
-            elif self.sparsity != hf_sparsity_method:
-                raise ValueError(
-                    "Sparsity method specified in the model config "
-                    f"({hf_sparsity_method}) does not match the sparsity "
-                    f"method specified in the `sparsity` argument "
-                    f"({self.sparsity}).")
+    @staticmethod
+    def _sparsity_structure_from_config(
+            sparsity_config: Dict[str, Any],
+            dtype: torch.dtype) -> Optional[str]:
+        """
+        Translate the sparsity_config into an appropriate sparsity structure.
+
+        :param sparsity_config: A dictionary specifying the sparsity config
+        :param dtype: The dtype of the model in question
+        :return: The appropriate sparsity structure as a string, or None
+            if it cannot be inferred
+        """
+        supported_sparsity_dtypes = {torch.float16, torch.bfloat16}
+
+        # check the validity of sparsity_config: reject keys that are not
+        # fields of compressed-tensors' CompressionConfig
+        unexpected_keys = set(sparsity_config.keys()).difference(
+            CompressionConfig.model_fields.keys())
+        if unexpected_keys:
+            raise ValueError("The detected sparsity_config contains "
+                             f"unexpected keys: {unexpected_keys}")
+
+        # check for valid dtype
+        if dtype not in supported_sparsity_dtypes:
+            logger.warning(
+                f"Sparsity is only supported for {supported_sparsity_dtypes} "
+                "dtypes. Running the models without sparse kernels.")
+            return None

Review comment: I actually meant the current dtype, but supported dtypes are good too!

+
+        # choose the sparsity structure based on the sparsity config
+        if sparsity_config["sparsity_structure"] in {"unstructured", "0:0"}:
+            return SparsityStructure.sparse_w16a16.value
+        elif sparsity_config["sparsity_structure"] == "2:4":
+            return SparsityStructure.semi_structured_sparse_w16a16.value
+
+        # if the sparsity config is not recognized, return None
+        logger.warning("A valid sparsity structure cannot be inferred from "
+                       f"the sparsity config:\n{sparsity_config}"
+                       "\nRunning the models without sparse kernels.")
+        return None

     def _verify_quantization(self) -> None:
         supported_quantization = ["awq", "gptq", "squeezellm", "marlin"]
Review comment: Why remove bfloat16 here?
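As background on the "2:4" branch above: N:M semi-structured sparsity means at most N of every M contiguous weights are nonzero, which is the layout the semi-structured sparse kernels require. A minimal sketch of a 2:4 pattern check; the helper below is illustrative and not part of this PR:

import torch

def is_2_4_sparse(weight: torch.Tensor) -> bool:
    """Return True if every contiguous group of 4 values along the
    last dimension has at most 2 nonzeros (the 2:4 pattern)."""
    assert weight.shape[-1] % 4 == 0, "last dim must be divisible by 4"
    groups = weight.reshape(-1, 4)
    nonzeros_per_group = (groups != 0).sum(dim=-1)
    return bool((nonzeros_per_group <= 2).all())

# e.g. the row [0.5, 0.0, -1.2, 0.0] satisfies 2:4,
# while [1.0, 1.0, 1.0, 0.0] (three nonzeros in a group of four) does not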