forked from vllm-project/vllm
Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor (#195)

- Depending on how we end up parsing `ignore` and `targets` (layer_name vs. layer_type), we may not need layer_name to be added to the linear_method; we will experiment with using a compressed-tensors function in a follow-up PR.
- Initial implementation of CompressedTensors config support + activation quantization for static per-tensor W8A8.
- Includes fused kernels added by @varun-sundar-rabindranath:

```python
from vllm import LLM, SamplingParams
import torch

prompts = [
    "Hello, my name is",
    "The capital of France is",
    "The US president is",
    "The future of AI is",
]
sampling_params = SamplingParams(temperature=0.80, top_p=0.95)

llm = LLM(model="nm-testing/tinyllama-one-shot-static-quant-test",
          enforce_eager=True,
          dtype=torch.float32,
          quantization="sparseml")

outputs = llm.generate(prompts, sampling_params)
for output in outputs:
    prompt = output.prompt
    generated_text = output.outputs[0].text
    print(f"Prompt: {prompt!r}, Generated text: {generated_text!r}")
```

- Verification of the different inputs expected for `targets` and `ignore`: use functions to parse the layer names that can be shared by sparseml and vllm; these would live in compressed-tensors (https://github.com/neuralmagic/compressed-tensors/blob/67005d76107d4659787f1efd53fe7e6b1d192818/src/compressed_tensors/quantization/lifecycle/apply.py#L86).
- Updates to further optimize fake quant.

---------

Co-authored-by: Varun Sundar Rabindranath <varunsundar08@gmail.com>
Co-authored-by: Varun Sundar Rabindranath <varun@neuralmagic.com>
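The commit message above leaves open how `targets` and `ignore` entries are matched (full layer name vs. layer type). As a rough illustration only, matching could look like the sketch below; the function names and the `re:` pattern convention are hypothetical here, not the actual compressed-tensors API (the real logic lives in the `apply.py` file linked above).

```python
import re

# Hypothetical sketch: match a layer either by its full module name
# (e.g. "model.layers.0.mlp.down_proj") or by its class name (e.g. "Linear").
def matches(layer_name: str, layer_type: str, patterns: list[str]) -> bool:
    for pattern in patterns:
        if pattern in (layer_name, layer_type):
            return True
        # An "re:" prefix marking regex patterns is one possible convention.
        if pattern.startswith("re:") and re.match(pattern[3:], layer_name):
            return True
    return False

def should_quantize(layer_name: str, layer_type: str,
                    targets: list[str], ignore: list[str]) -> bool:
    # A layer is quantized if it matches `targets` and is not in `ignore`.
    return (matches(layer_name, layer_type, targets)
            and not matches(layer_name, layer_type, ignore))
```

With shared helpers like these, sparseml and vllm could resolve the same `targets`/`ignore` entries identically without vllm having to thread layer_name through the linear_method.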
1 parent df29793 · commit 4d27a2c
Showing 22 changed files with 691 additions and 54 deletions.
csrc/quantization/compressed_tensors/int8_quant_kernels.cu (50 additions, 0 deletions)

```cpp
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <assert.h>

#include "../../dispatch_utils.h"

// Convert a float to int8 with round-to-nearest-even and saturation,
// via the PTX instruction cvt.rni.sat.s8.f32.
static inline __device__ int8_t float_to_int8_rn(float x)
{
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
}

namespace vllm {

// One thread block per token; threads stride across the hidden dimension,
// dividing each element by the per-tensor scale before converting to int8.
template <typename scalar_t, typename scale_type>
__global__ void quant_kernel(
    const scalar_t* __restrict__ input,
    int8_t* __restrict__ out,
    scale_type scale,
    const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
        float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
}  // namespace vllm

void quant_per_tensor(
    torch::Tensor& out,    // [..., hidden_size]
    torch::Tensor& input,  // [..., hidden_size]
    float scale) {
  assert(input.is_contiguous());
  assert(out.is_contiguous());
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
    vllm::quant_kernel<scalar_t, float><<<grid, block, 0, stream>>>(
        input.data_ptr<scalar_t>(),
        out.data_ptr<int8_t>(),
        scale,
        hidden_size);
  });
}
```
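For intuition, here is a kernel-free sketch of what `quant_per_tensor` computes, expressed in plain PyTorch. `quant_per_tensor_ref` is an illustrative name, not part of this commit; it assumes only that the kernel divides by a single per-tensor scale, rounds to nearest even, and saturates to int8.

```python
import torch

def quant_per_tensor_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    """Reference for the CUDA kernel above: divide by one per-tensor scale,
    round to nearest even (torch.round uses banker's rounding, matching the
    PTX .rni mode), and saturate to int8 range like the .sat modifier."""
    return torch.round(x.float() / scale).clamp(-128, 127).to(torch.int8)

# Example: quantize a batch of activations with a per-tensor scale.
x = torch.randn(4, 1024)
q = quant_per_tensor_ref(x, scale=0.05)
```

Launching one block per token with up to 1024 threads, as the kernel does, keeps each token's row of `hidden_size` elements within a single block, so no inter-block coordination is needed.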