Initial CompressedTensors config + Activation Quantization support for static W8A8 per tensor #195

Merged · 30 commits · Apr 30, 2024
Changes from 23 commits
Commits
18adcee
initial commit
dsikka Apr 18, 2024
38dcd67
add quant/dequant functions
dsikka Apr 18, 2024
263749a
add csrc files needed for cuda kernels
dsikka Apr 18, 2024
bbe0a70
add updated model runner
dsikka Apr 18, 2024
5a93cb7
add more files
dsikka Apr 18, 2024
e822fef
fix model_runner to match upstream main
dsikka Apr 18, 2024
0c271e4
update
dsikka Apr 18, 2024
3b02d6e
update
dsikka Apr 19, 2024
e09160b
fix model loading
dsikka Apr 19, 2024
dcb1e59
for fake quant, just use torch
dsikka Apr 19, 2024
48956bc
remove if
dsikka Apr 22, 2024
35d2d96
update to run end-to-end; verify with dense matmul for correctness
dsikka Apr 22, 2024
1dfa7f6
update to use ops.quant for weight quantization
dsikka Apr 23, 2024
b2c39a1
fix gibberish
dsikka Apr 23, 2024
e8d1886
Compression config cutlass (#205)
varun-sundar-rabindranath Apr 23, 2024
b840eae
clean-up; separate into separate schemes; add to scheme checking
dsikka Apr 24, 2024
6868f97
format
dsikka Apr 24, 2024
6c89aa9
remove print; update todo
dsikka Apr 24, 2024
a0a9a75
fix rebase
dsikka Apr 24, 2024
14d5f25
update unquant
dsikka Apr 24, 2024
e5f391f
Compression config perf fix (#207)
varun-sundar-rabindranath Apr 25, 2024
ddb10d8
add update supported_list; update params_dtype
dsikka Apr 25, 2024
0c5f2a0
Merge branch 'ds-quant' into compression_config
dsikka Apr 25, 2024
540c159
PR comments
dsikka Apr 29, 2024
677f02c
more comments
dsikka Apr 29, 2024
bd99627
Compression config - cleanup (#215)
varun-sundar-rabindranath Apr 29, 2024
cf61e07
cleanup
dsikka Apr 29, 2024
96fea65
make layer name optional; update create_weights in quant linear methods
dsikka Apr 30, 2024
093e688
cleanup
dsikka Apr 30, 2024
681fb3b
Merge branch 'ds-quant' into compression_config
dsikka Apr 30, 2024
1 change: 1 addition & 0 deletions CMakeLists.txt
@@ -167,6 +167,7 @@ set(VLLM_EXT_SRC
"csrc/layernorm_kernels.cu"
"csrc/quantization/squeezellm/quant_cuda_kernel.cu"
"csrc/quantization/gptq/q_gemm.cu"
"csrc/quantization/smoothquant/fused_kernels.cu"
"csrc/quantization/fp8/fp8_cuda_kernels.cu"
"csrc/cuda_utils_kernels.cu"
"csrc/moe_align_block_size_kernels.cu"
1 change: 1 addition & 0 deletions csrc/attention/attention_dtypes.h
@@ -5,3 +5,4 @@
#include "dtype_float32.cuh"
#include "dtype_bfloat16.cuh"
#include "dtype_fp8.cuh"
#include "dtype_int8.cuh"
8 changes: 8 additions & 0 deletions csrc/attention/dtype_float32.cuh
@@ -86,6 +86,14 @@ inline __device__ float4 add(float4 a, float4 b) {
return c;
}

// Note: for compilation, the float4 overload above seems to be unnecessary.
inline __device__ Float4_ add(Float4_ a, Float4_ b) {
Float4_ c;
c.x = add(a.x, b.x);
c.y = add(a.y, b.y);
return c;
}

// Vector multiplication.
template<>
inline __device__ float mul<float, float>(float a, float b) {
49 changes: 49 additions & 0 deletions csrc/attention/dtype_int8.cuh
@@ -0,0 +1,49 @@
#pragma once

#include <stdint.h>
#include "attention_generic.cuh"
#include "dtype_float32.cuh"

namespace vllm {
// define int8 vector types for quantization of kv cache

template<>
struct Vec<int8_t, 1> {
using Type = int8_t;
};

template<>
struct Vec<int8_t, 2> {
using Type = int16_t;
};

template<>
struct Vec<int8_t, 4> {
using Type = int32_t;
};

template<>
struct Vec<int8_t, 8> {
using Type = int64_t;
};

template<>
struct FloatVec<int8_t> {
using Type = float;
};

template<>
struct FloatVec<int16_t> {
using Type = float2;
};

template<>
struct FloatVec<int32_t> {
using Type = Float4_;
};

template<>
struct FloatVec<int64_t> {
using Type = Float8_;
};
}  // namespace vllm
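
The specializations above give the attention kernels a packed integer type for moving VEC_SIZE int8 KV-cache values at once, plus a matching float vector type for dequantized math. The sketch below only illustrates how such traits are typically consumed; load_packed_int8_example is a hypothetical name, not part of this PR, and the actual int8-to-float conversion helpers live elsewhere in the attention code.

#include "dtype_int8.cuh"  // Vec<int8_t, N> and FloatVec<...> specializations above

template <int VEC_SIZE>
__device__ void load_packed_int8_example(const int8_t* __restrict__ ptr) {
  // Packed word carrying VEC_SIZE int8 values, e.g. Vec<int8_t, 4>::Type == int32_t,
  // so four quantized values travel as a single 32-bit load.
  using QuantVec = typename vllm::Vec<int8_t, VEC_SIZE>::Type;
  // Float vector used for dequantized accumulation, e.g. FloatVec<int32_t>::Type == Float4_.
  using DequantVec = typename vllm::FloatVec<QuantVec>::Type;
  static_assert(sizeof(QuantVec) == VEC_SIZE * sizeof(int8_t),
                "packed width must match the vector size");
  QuantVec packed = *reinterpret_cast<const QuantVec*>(ptr);  // vectorized load
  // A real kernel would expand `packed` into a DequantVec and apply the
  // dequantization scale; those conversion helpers are outside this header.
  (void)packed;
}
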
10 changes: 9 additions & 1 deletion csrc/dispatch_utils.h
@@ -9,12 +9,20 @@
#define VLLM_DISPATCH_CASE_FLOATING_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__)
AT_DISPATCH_CASE(at::ScalarType::BFloat16, __VA_ARGS__) \

#define VLLM_DISPATCH_CASE_QUANT_TYPES(...) \
VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Char, __VA_ARGS__)

#define VLLM_DISPATCH_FLOATING_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_FLOATING_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_QUANT_TYPES(TYPE, NAME, ...) \
AT_DISPATCH_SWITCH( \
TYPE, NAME, VLLM_DISPATCH_CASE_QUANT_TYPES(__VA_ARGS__))

#define VLLM_DISPATCH_CASE_FLOATING_AND_BYTE_TYPES(...) \
AT_DISPATCH_CASE(at::ScalarType::Float, __VA_ARGS__) \
AT_DISPATCH_CASE(at::ScalarType::Half, __VA_ARGS__) \
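
VLLM_DISPATCH_QUANT_TYPES behaves like VLLM_DISPATCH_FLOATING_TYPES but also dispatches on at::ScalarType::Char, so scalar_t can resolve to int8_t inside the lambda. The launcher below is a hypothetical usage sketch, not code from this PR; copy_kernel and copy_launcher are made-up names.

#include <torch/extension.h>
#include "dispatch_utils.h"

// Illustrative kernel: scalar_t may be float, at::Half, at::BFloat16, or int8_t
// once Char is included in the dispatch set.
template <typename scalar_t>
__global__ void copy_kernel(const scalar_t* __restrict__ in,
                            scalar_t* __restrict__ out,
                            int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) out[i] = in[i];
}

void copy_launcher(torch::Tensor& out, torch::Tensor& input) {
  const int n = static_cast<int>(input.numel());
  VLLM_DISPATCH_QUANT_TYPES(input.scalar_type(), "copy_kernel", [&] {
    copy_kernel<scalar_t><<<(n + 255) / 256, 256>>>(
        input.data_ptr<scalar_t>(), out.data_ptr<scalar_t>(), n);
  });
}
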
10 changes: 10 additions & 0 deletions csrc/ops.h
@@ -146,6 +146,16 @@ void gptq_shuffle(
torch::Tensor q_perm,
int bit);

void quant(
torch::Tensor& out,
torch::Tensor& input,
float scale);

void quant(
torch::Tensor& out,
torch::Tensor& input,
torch::Tensor& scale);

void scaled_fp8_quant(
torch::Tensor& out,
torch::Tensor& input,
15 changes: 15 additions & 0 deletions csrc/pybind.cpp
@@ -49,6 +49,21 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
"fused_add_rms_norm",
&fused_add_rms_norm,
"In-place fused Add and RMS Normalization");
ops.def(
"quant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
float>(&quant),
"Quant.");
ops.def(
"quant",
py::overload_cast<
torch::Tensor&,
torch::Tensor&,
torch::Tensor&>(
&quant),
"Per-token quant.");

// Rotary embedding
ops.def(
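
These bindings expose two quant overloads: one taking a single float scale (static per-tensor quantization) and one taking a float tensor that the kernel fills with per-token scales. The C++ sketch below only illustrates the calling convention, assuming the declarations from csrc/ops.h are in scope; the shapes and names are illustrative, not from the PR.

#include <torch/extension.h>
#include "ops.h"  // declares both quant() overloads added in this PR

void quant_example() {
  // Hypothetical activation tensor: 4 tokens, hidden size 4096.
  torch::Tensor x = torch::randn({4, 4096},
      torch::dtype(torch::kFloat16).device(torch::kCUDA));
  torch::Tensor q = torch::empty(x.sizes(), x.options().dtype(torch::kChar));

  // Overload 1: static per-tensor scale supplied by the caller.
  quant(q, x, 0.05f);

  // Overload 2: dynamic per-token scales; the kernel writes amax / 127 per row.
  torch::Tensor scales = torch::empty({4}, x.options().dtype(torch::kFloat));
  quant(q, x, scales);
}
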
91 changes: 91 additions & 0 deletions csrc/quantization/smoothquant/fused_kernels.cu
@@ -0,0 +1,91 @@
#include <ATen/cuda/CUDAContext.h>
#include <torch/extension.h>
#include <assert.h>

#include "../../dispatch_utils.h"
#include "../../reduction_utils.cuh"
#include "quant_utils.cuh"

namespace vllm {

template <typename scalar_t, typename scale_type, bool use_per_token_quant>
__global__ void quant_kernel(
const scalar_t* __restrict__ input,
int8_t* __restrict__ out,
scale_type scale,
const int hidden_size) {
const int tid = threadIdx.x;
const int token_idx = blockIdx.x;

if constexpr (use_per_token_quant) {
float amax_val = 0.0f;
const float zero = 0.0f;

for (int i = tid; i < hidden_size; i += blockDim.x) {
float val = (float)input[token_idx * hidden_size + i];
val = val > zero ? val : -val;
if (val > amax_val)
amax_val = val;
}

__shared__ float s_amax;
const float block_amax_val = blockReduceMax(amax_val);
if (tid == 0) {
s_amax = block_amax_val;
scale[token_idx] = block_amax_val / 127.0f;
}
__syncthreads();

float tmp_scale = 127.0f / s_amax;
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) * tmp_scale);
}
} else {
for (int i = tid; i < hidden_size; i += blockDim.x) {
out[token_idx * hidden_size + i] =
float_to_int8_rn(((float)input[token_idx * hidden_size + i]) / scale);
}
}
}
} // namespace vllm

void quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
float scale) {
assert(input.is_contiguous());
assert(out.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, float, false><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale,
hidden_size);
});
}

void quant(
torch::Tensor& out, // [..., hidden_size]
torch::Tensor& input, // [..., hidden_size]
torch::Tensor& scale) { // [num_tokens]
assert(input.is_contiguous());
assert(out.is_contiguous());
int hidden_size = input.size(-1);
int num_tokens = input.numel() / hidden_size;
dim3 grid(num_tokens);
dim3 block(std::min(hidden_size, 1024));
const cudaStream_t stream = at::cuda::getCurrentCUDAStream();
VLLM_DISPATCH_FLOATING_TYPES(input.scalar_type(), "quant_kernel", [&] {
vllm::quant_kernel<scalar_t, float*, true><<<grid, block, 0, stream>>>(
input.data_ptr<scalar_t>(),
out.data_ptr<int8_t>(),
scale.data_ptr<float>(),
hidden_size);
});
}
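
For reference, the per-token path above scales each token row so that its absolute maximum maps to 127 and stores scale[t] = max|x_t| / 127 for later dequantization (e.g. a row whose amax is 6.35 gets scale 0.05, so a value of 1.27 quantizes to round(1.27 / 0.05) = 25). The CPU helper below is an illustrative restatement of that logic, not code from the PR; like the kernel, it assumes each row has a nonzero amax, and it mirrors the assumed round-to-nearest, saturating behavior of float_to_int8_rn.

#include <algorithm>
#include <cmath>
#include <cstdint>
#include <vector>

// CPU reference for the per-token quant kernel above (illustrative only).
void quant_per_token_reference(const std::vector<float>& input,  // [num_tokens * hidden_size]
                               std::vector<int8_t>& out,         // [num_tokens * hidden_size]
                               std::vector<float>& scale,        // [num_tokens]
                               int hidden_size) {
  const int num_tokens = static_cast<int>(input.size()) / hidden_size;
  out.resize(input.size());
  scale.resize(num_tokens);
  for (int t = 0; t < num_tokens; ++t) {
    float amax = 0.0f;
    for (int i = 0; i < hidden_size; ++i)
      amax = std::max(amax, std::fabs(input[t * hidden_size + i]));
    scale[t] = amax / 127.0f;               // matches scale[token_idx] in the kernel
    const float inv_scale = 127.0f / amax;  // assumes amax > 0, as the kernel does
    for (int i = 0; i < hidden_size; ++i) {
      // Assumed float_to_int8_rn semantics: round to nearest, saturate to int8 range.
      float q = std::nearbyint(input[t * hidden_size + i] * inv_scale);
      out[t * hidden_size + i] =
          static_cast<int8_t>(std::max(-128.0f, std::min(127.0f, q)));
    }
  }
}
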