Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name
vs layer_type), we may not need layer_name to be added to the
linear_method. We will experiment with using a compressed-tensors function
in a follow-up PR.

- Initial implementation of CompressedTensors config support + Activation
Quantization for static per-tensor W8A8 (a minimal sketch of the per-tensor
arithmetic follows this list)
- Includes fused kernels added by @varun-sundar-rabindranath

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is"
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test", enforce_eager=True, dtype=torch.float32, quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and
`ignore`: use functions to parse the layer names that can be shared
by SparseML and vLLM; these would live in compressed-tensors
(https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86)
- Updates to further optimize fake quant
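
For reference, here is a minimal sketch of the arithmetic that static per-tensor W8A8 performs, written in plain PyTorch rather than the fused kernels added by this PR. Both scales are fixed values calibrated offline; the concrete numbers below are purely illustrative.

```python
import torch

def quantize_static(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Static per-tensor quantization: divide by a fixed scale,
    # round to nearest, and saturate to the int8 range.
    return torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)

a_scale, w_scale = 0.05, 0.02   # static per-tensor scales (illustrative)
a = torch.randn(8, 128)         # activations [tokens, hidden]
w = torch.randn(64, 128)        # weights [out_features, hidden]
a_q = quantize_static(a, a_scale)
w_q = quantize_static(w, w_scale)

# int8 x int8 matmul accumulated in int32, then dequantized by the product of scales.
y = torch.matmul(a_q.to(torch.int32), w_q.to(torch.int32).t()).float() * (a_scale * w_scale)
```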

---------

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
3 people committed Apr 30, 2024
1 parent df29793 commit 4d27a2c
Showing 22 changed files with 691 additions and 54 deletions.
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -156,6 +156,11 @@ void dynamic_scaled_fp8_quant(
torch::Tensor& input,
torch::Tensor& scale);

void quant_per_tensor(
  torch::Tensor& out,
  torch::Tensor& input,
  float scale);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
9 changes: 9 additions & 0 deletions csrc/pybind.cpp
@@ -80,6 +80,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
&moe_align_block_size,
"Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

ops.def(
  "quant_per_tensor",
  py::overload_cast<
    torch::Tensor&,
    torch::Tensor&,
    float>(&quant_per_tensor),
  "Per-tensor Quantization");


// Cache ops
pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
cache_ops.def(
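Once the extension is rebuilt, the binding registered above is callable directly from Python. The sketch below is hypothetical: the module path (`vllm._C.ops`) and the scale value are assumptions for illustration, not part of this diff.

```python
import torch
from vllm import _C  # compiled extension exposing the ops submodule (assumed path)

x = torch.randn(4, 4096, device="cuda", dtype=torch.float16)
out = torch.empty_like(x, dtype=torch.int8)
_C.ops.quant_per_tensor(out, x, 0.05)  # writes saturate(round(x / 0.05)) into out as int8
```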
50 changes: 50 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -0,0 +1,50 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <assert.h>

#include "../../dispatch_utils.h"

// Convert a float to int8 with round-to-nearest-even and saturation,
// using the PTX cvt.rni.sat.s8.f32 instruction.
static inline __device__ int8_t float_to_int8_rn(float x)
{
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
}

namespace vllm {

// One thread block per token; threads stride over the hidden dimension and
// write out[i] = saturate(round(input[i] / scale)) as int8.
template <typename scalar_t, typename scale_type>
__global__ void quant_kernel(
  const scalar_t* __restrict__ input,
  int8_t* __restrict__ out,
  scale_type scale,
  const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
      float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
} // namespace vllm

void quant_per_tensor(
  torch::Tensor& out,    // [..., hidden_size]
  torch::Tensor& input,  // [..., hidden_size]
  float scale) {
  assert(input.is_contiguous());
  assert(out.is_contiguous());
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
    vllm::quant_kernel<scalar_t, float><<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(),
      out.data_ptr<int8_t>(),
      scale,
      hidden_size);
  });
}
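
For a quick sanity check, the kernel's math can be mirrored in a line of PyTorch and compared element-wise against the output of `quant_per_tensor` as invoked in the previous sketch; both the PTX `cvt.rni.sat.s8.f32` above and `torch.round` round ties to even, so the results should match. A minimal reference (not part of this commit):

```python
import torch

def quant_per_tensor_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # Mirrors the CUDA kernel: divide by the static scale,
    # round to nearest (ties to even), and saturate to int8.
    return torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)

# e.g. torch.equal(out, quant_per_tensor_ref(x, 0.05)) with out, x from the previous sketch
```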
1 change: 1 addition & 0 deletions requirements-cuda.txt
@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1
nvidia-cutlass == 3.5.0
