Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor #195

Merged Apr 30, 2024
Changes from all commits
Commits (30)
18adcee
initial commit
dsikka Apr 18, 2024
38dcd67
add quant/dequant functions
dsikka Apr 18, 2024
263749a
add csrc files needed for cuda kernels
dsikka Apr 18, 2024
bbe0a70
add updated model runner
dsikka Apr 18, 2024
5a93cb7
add more files
dsikka Apr 18, 2024
e822fef
fix model_runner to match upstream main
dsikka Apr 18, 2024
0c271e4
update
dsikka Apr 18, 2024
3b02d6e
update
dsikka Apr 19, 2024
e09160b
fix model loading
dsikka Apr 19, 2024
dcb1e59
for fake quant, just use torch
dsikka Apr 19, 2024
48956bc
remove if
dsikka Apr 22, 2024
35d2d96
update to run end-to-end; verify with dense matmul for correctness
dsikka Apr 22, 2024
1dfa7f6
update to use ops.quant for weight quantization
dsikka Apr 23, 2024
b2c39a1
fix gibberish
dsikka Apr 23, 2024
e8d1886
Compression config cutlass (#205)
varun-sundar-rabindranath Apr 23, 2024
b840eae
clean-up; separate into separate schemes; add to scheme checking
dsikka Apr 24, 2024
6868f97
format
dsikka Apr 24, 2024
6c89aa9
remove print; update todo
dsikka Apr 24, 2024
a0a9a75
fix rebase
dsikka Apr 24, 2024
14d5f25
update unquant
dsikka Apr 24, 2024
e5f391f
Compression config perf fix (#207)
varun-sundar-rabindranath Apr 25, 2024
ddb10d8
add update supported_list; update params_dtype
dsikka Apr 25, 2024
0c5f2a0
Merge branch 'ds-quant' into compression_config
dsikka Apr 25, 2024
540c159
PR comments
dsikka Apr 29, 2024
677f02c
more comments
dsikka Apr 29, 2024
bd99627
Compression config - cleanup (#215)
varun-sundar-rabindranath Apr 29, 2024
cf61e07
cleanup
dsikka Apr 29, 2024
96fea65
make layer name optional; update create_weights in quant linear methods
dsikka Apr 30, 2024
093e688
cleanup
dsikka Apr 30, 2024
681fb3b
Merge branch 'ds-quant' into compression_config
dsikka Apr 30, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
  "csrc/layernorm_kernels.cu"
  "csrc/quantization/squeezellm/quant_cuda_kernel.cu"
  "csrc/quantization/gptq/q_gemm.cu"
  "csrc/quantization/compressed_tensors/int8_quant_kernels.cu"
  "csrc/quantization/fp8/fp8_cuda_kernels.cu"
  "csrc/cuda_utils_kernels.cu"
  "csrc/moe_align_block_size_kernels.cu"
5 changes: 5 additions & 0 deletions csrc/ops.h
@@ -146,6 +146,11 @@ void gptq_shuffle(
  torch::Tensor q_perm,
  int bit);

void quant_per_tensor(
  torch::Tensor& out,
  torch::Tensor& input,
  float scale);

void scaled_fp8_quant(
  torch::Tensor& out,
  torch::Tensor& input,
9 changes: 9 additions & 0 deletions csrc/pybind.cpp
@@ -79,6 +79,15 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
    &moe_align_block_size,
    "Aligning the number of tokens to be processed by each expert such that it is divisible by the block size.");

  ops.def(
    "quant_per_tensor",
    py::overload_cast<
      torch::Tensor&,
      torch::Tensor&,
      float>(&quant_per_tensor),
    "Per-tensor Quantization");


  // Cache ops
  pybind11::module cache_ops = m.def_submodule("cache_ops", "vLLM cache ops");
  cache_ops.def(
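A minimal usage sketch for the newly registered op, assuming the extension builds as the vllm._C module (the name used by upstream vLLM at the time); the output tensor must be pre-allocated as int8 with the same shape as the input:

import torch
from vllm._C import ops  # assumption: extension module name as in upstream vLLM

# Static per-tensor activation quantization: one float scale for the whole tensor.
x = torch.randn(4, 4096, dtype=torch.float16, device="cuda")
scale = 0.02  # assumed static activation scale, e.g. calibrated offline
x_q = torch.empty_like(x, dtype=torch.int8)
ops.quant_per_tensor(x_q, x, scale)  # (out, input, scale), matching the binding above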
50 changes: 50 additions & 0 deletions csrc/quantization/compressed_tensors/int8_quant_kernels.cu
@@ -0,0 +1,50 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <assert.h>

#include "../../dispatch_utils.h"

static inline __device__ int8_t float_to_int8_rn(float x)
{
  uint32_t dst;
  asm volatile("cvt.rni.sat.s8.f32 %0, %1;" : "=r"(dst) : "f"(x));
  return reinterpret_cast<const int8_t&>(dst);
}

namespace vllm {

template <typename scalar_t, typename scale_type>
__global__ void quant_kernel(
  const scalar_t* __restrict__ input,
  int8_t* __restrict__ out,
  scale_type scale,
  const int hidden_size) {
  const int tid = threadIdx.x;
  const int token_idx = blockIdx.x;

  for (int i = tid; i < hidden_size; i += blockDim.x) {
    out[token_idx * hidden_size + i] =
      float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
  }
}
} // namespace vllm

void quant_per_tensor(
  torch::Tensor& out,   // [..., hidden_size]
  torch::Tensor& input, // [..., hidden_size]
  float scale) {
  assert(input.is_contiguous());
  assert(out.is_contiguous());
  int hidden_size = input.size(-1);
  int num_tokens = input.numel() / hidden_size;
  dim3 grid(num_tokens);
  dim3 block(std::min(hidden_size, 1024));
  const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
  VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
    vllm::quant_kernel<scalar_t, float><<<grid, block, 0, stream>>>(
      input.data_ptr<scalar_t>(),
      out.data_ptr<int8_t>(),
      scale,
      hidden_size);
  });
}
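The kernel's behavior (divide by the per-tensor scale, round to nearest-even, saturate to the int8 range) can be sanity-checked against a plain PyTorch reference; this is only an illustrative sketch, not part of the PR:

import torch

def quant_per_tensor_ref(x: torch.Tensor, scale: float) -> torch.Tensor:
    # torch.round is round-half-to-even, matching PTX cvt.rni.sat.s8.f32 above.
    return torch.clamp(torch.round(x.float() / scale), -128, 127).to(torch.int8)

# Example comparison against the CUDA path (assumes the extension is built):
# x = torch.randn(2, 1024, dtype=torch.float16, device="cuda")
# x_q = torch.empty_like(x, dtype=torch.int8)
# ops.quant_per_tensor(x_q, x, 0.05)
# torch.testing.assert_close(x_q.cpu(), quant_per_tensor_ref(x.cpu(), 0.05), atol=1, rtol=0)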
1 change: 1 addition & 0 deletions requirements-cuda.txt
@@ -7,3 +7,4 @@ nvidia-ml-py # for pynvml package
vllm-nccl-cu12>=2.18,<2.19 # for downloading nccl library
torch == 2.2.1
xformers == 0.0.25 # Requires PyTorch 2.2.1
nvidia-cutlass == 3.5.0