Commit

Add naive dequant
mgoin committed Apr 15, 2024
1 parent 46b49ac commit 0373b69
Showing 1 changed file with 33 additions and 11 deletions.
44 changes: 33 additions & 11 deletions vllm/model_executor/layers/quantization/fp8.py
@@ -1,19 +1,30 @@
+from functools import lru_cache
 from typing import Any, Dict, List, Optional
 
 import torch
+import torch.nn.functional as F
 from magic_wand import CompressedStorageFormat
 
 from vllm.logger import init_logger
 from vllm.model_executor.layers.linear import (LinearMethodBase,
                                                set_weight_attrs)
 from vllm.model_executor.layers.parameters import LazyCompressedParameter
 from vllm.model_executor.layers.quantization.base_config import (
     QuantizationConfig)
 
 logger = init_logger(__name__)
 
 
+@lru_cache(None)
+def warn_once(msg):
+    logger.warning(msg)
+
+
 def fp8_quantize(
         weight,
         qdtype: torch.dtype = torch.float8_e4m3fn
 ) -> tuple[torch.Tensor, torch.Tensor]:
+    dtype = weight.dtype
     finfo = torch.finfo(qdtype)
     # Calculate the scale as dtype max divided by absmax
     scale = finfo.max / weight.abs().max().clamp(min=1e-12)
@@ -24,7 +35,7 @@ def fp8_quantize(
     # Return both float8 data and the inverse scale (as float),
     # as both required as inputs to torch._scaled_mm
     qweight = qweight.to(qdtype)
-    scale = scale.float().reciprocal()
+    scale = scale.to(dtype).reciprocal()
     return qweight, scale


@@ -72,8 +83,8 @@ def get_supported_act_dtypes(self) -> List[torch.dtype]:
     def get_min_capability(self) -> int:
         # FP8 hardware support is required because
         # torch._scaled_mm is only supported on CUDA devices with
-        # compute capability >= 9.0 or 8.9, or ROCm MI300+");
-        return 89
+        # compute capability >= 9.0 or 8.9, or ROCm MI300+
+        return 70
 
     @staticmethod
     def get_config_filenames() -> List[str]:
@@ -140,13 +151,24 @@ def apply_weights(self,
 
         qx, xscale = fp8_quantize(x)
 
-        output, _ = torch._scaled_mm(
-            qx,
-            w.compressed_data.values,
-            out_dtype=self.dtype,
-            scale_a=xscale,
-            scale_b=w.compressed_data.scale,
-            bias=bias,
-        )
+        cuda_compute_capability = torch.cuda.get_device_capability()
+        if cuda_compute_capability >= (9, 0):
+            output, _ = torch._scaled_mm(
+                qx,
+                w.compressed_data.values,
+                out_dtype=self.dtype,
+                scale_a=xscale,
+                scale_b=w.compressed_data.scale,
+                bias=bias,
+            )
+        else:
+            # For NVIDIA SM < 9.0
+            warn_once("FP8 hardware support doesn't exist for "
+                      "NVIDIA SM < 9.0. Up-conversion to "
+                      "original dtype will be used.")
+
+            output = F.linear(
+                qx.to(self.dtype) * xscale, w.compressed_data.decompress(),
+                bias)
 
         return output
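
For reference, here is a minimal standalone sketch of the naive dequant idea in isolation: quantize a tensor to float8_e4m3fn with a per-tensor scale, then up-convert back to the original dtype and fall back to a plain F.linear, as the new code path does when FP8 matmul hardware is unavailable. The tensor names and shapes below are invented for illustration, the clamp-and-cast line inside fp8_quantize is an assumption (that part of the function is not shown in this diff), and a PyTorch build with float8_e4m3fn support (2.1 or newer) is assumed. This is a sketch of the idea, not the vLLM code path, which operates on compressed parameters rather than plain tensors.

import torch
import torch.nn.functional as F


def fp8_quantize(t, qdtype=torch.float8_e4m3fn):
    finfo = torch.finfo(qdtype)
    # Per-tensor scale: qdtype max divided by the input's absmax.
    scale = finfo.max / t.abs().max().clamp(min=1e-12)
    # Assumed quantization step: scale, clamp to the fp8 range, then cast.
    qt = (t * scale).clamp(min=finfo.min, max=finfo.max).to(qdtype)
    # Return the inverse scale alongside the fp8 data.
    return qt, scale.to(t.dtype).reciprocal()


# Made-up activation and weight tensors, just to exercise the fallback.
x = torch.randn(4, 64, dtype=torch.float32)
w = torch.randn(32, 64, dtype=torch.float32)

qx, xscale = fp8_quantize(x)
qw, wscale = fp8_quantize(w)

# Naive dequant: up-convert the fp8 data back to the original dtype,
# reapply the inverse scales, and use a regular matmul. This mirrors what
# the commit does on GPUs below SM 9.0 instead of calling torch._scaled_mm.
out = F.linear(qx.to(x.dtype) * xscale, qw.to(w.dtype) * wscale)

# The result should match F.linear(x, w) up to fp8 rounding error.
print((out - F.linear(x, w)).abs().max())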

1 comment on commit 0373b69

@github-actions

Benchmark suite: bigger_is_better
Commits compared: current 0373b69 vs. previous 788b4e5

Common configuration: VLLM Engine throughput - synthetic; model NousResearch/Llama-2-7b-chat-hf; max_model_len 4096; benchmark_throughput with use-all-available-gpus, input-len 256, output-len 128, num-prompts 1000; GPU NVIDIA A10G x 1; vllm_version 0.2.0; python_version 3.10.12 (main, Mar 7 2024, 18:39:53) [GCC 9.4.0]; torch_version 2.2.1+cu121.

request_throughput: current 3.8013765409089246 prompts/s | previous 3.818269346642975 prompts/s | ratio 1.00
token_throughput: current 1459.7285917090271 tokens/s | previous 1466.2154291109025 tokens/s | ratio 1.00

This comment was automatically generated by workflow using github-action-benchmark.
