Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Draft] [FP8] CUTLASS FP8 matrix multiply #4662

Draft
wants to merge 10 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
3 changes: 3 additions & 0 deletions .gitmodules
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
[submodule "third_party/cutlass"]
path = third_party/cutlass
url = https://github.com/NVIDIA/cutlass.git
4 changes: 3 additions & 1 deletion CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@ include(${CMAKE_CURRENT_LIST_DIR}/cmake/utils.cmake)
set(PYTHON_SUPPORTED_VERSIONS "3.8" "3.9" "3.10" "3.11")

# Supported NVIDIA architectures.
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0")
set(CUDA_SUPPORTED_ARCHS "7.0;7.5;8.0;8.6;8.9;9.0;9.0a")

# Supported AMD GPU architectures.
set(HIP_SUPPORTED_ARCHS "gfx906;gfx908;gfx90a;gfx940;gfx941;gfx942;gfx1030;gfx1100")
Expand Down Expand Up @@ -191,6 +191,8 @@ define_gpu_extension_target(
ARCHITECTURES ${VLLM_GPU_ARCHES}
WITH_SOABI)

# NOTE(review): relative include paths resolve against the directory of the
# CMakeLists that is being processed; prefer
# "${CMAKE_CURRENT_SOURCE_DIR}/third_party/..." so this stays correct if the
# target setup moves or is included via add_subdirectory() — confirm intent.
target_include_directories(_C PRIVATE "third_party/cutlass/include" "third_party/cutlass/tools/util/include")

#
# _moe_C extension
#
Expand Down
6 changes: 4 additions & 2 deletions cmake/utils.cmake
Original file line number Diff line number Diff line change
Expand Up @@ -207,8 +207,10 @@ macro(override_gpu_arches GPU_ARCHES GPU_LANG GPU_SUPPORTED_ARCHES)
"CMAKE_CUDA_FLAGS. (${CMAKE_CUDA_FLAGS})")
endif()

message(DEBUG "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(DEBUG "arch flags: ${_CUDA_ARCH_FLAGS}")
message(STATUS "final CMAKE_CUDA_FLAGS: ${CMAKE_CUDA_FLAGS}")
message(STATUS "arch flags: ${_CUDA_ARCH_FLAGS}")
# NOTE(review): this unconditionally overwrites the arch flags computed above,
# forcing an sm_90a-only build for every user regardless of detected/requested
# architectures. Looks like a local debugging leftover — it should be removed
# (or gated behind an option) before this is merged.
set(_CUDA_ARCH_FLAGS "-gencode arch=compute_90a,code=sm_90a")
message(STATUS "arch flags: ${_CUDA_ARCH_FLAGS}")

# Initialize the architecture lists to empty.
set(${GPU_ARCHES})
Expand Down
6 changes: 6 additions & 0 deletions csrc/ops.h
Original file line number Diff line number Diff line change
Expand Up @@ -176,6 +176,12 @@ void dynamic_scaled_fp8_quant(
torch::Tensor& input,
torch::Tensor& scale);

// CUTLASS FP8 GEMM: out = input @ weights.
//   input:     [m, k] fp8 (e4m3)
//   weights:   [k, n] fp8 (e4m3), read column-major
//   out:       [m, n] fp16
//   workspace: caller-allocated uint8 scratch buffer handed to CUTLASS
// Implemented in csrc/quantization/fp8/common.cu; the kernel configuration
// targets SM90 (Hopper) — see csrc/quantization/fp8/fp8_gemm_kernels.h.
void fp8_scaled_gemm(
  torch::Tensor& out,
  torch::Tensor& input,
  torch::Tensor& weights,
  torch::Tensor& workspace);

void moe_align_block_size(
torch::Tensor topk_ids,
int num_experts,
Expand Down
1 change: 1 addition & 0 deletions csrc/pybind.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,7 @@ PYBIND11_MODULE(TORCH_EXTENSION_NAME, m) {
ops.def("squeezellm_gemm", &squeezellm_gemm, "Quantized GEMM for SqueezeLLM");
ops.def("static_scaled_fp8_quant", &static_scaled_fp8_quant, "Compute FP8 quantized tensor for given scaling factor");
ops.def("dynamic_scaled_fp8_quant", &dynamic_scaled_fp8_quant, "Compute FP8 quantized tensor and scaling factor");
ops.def("fp8_scaled_gemm", &fp8_scaled_gemm, "Matrix multiplication with FP8 scaling factors");
ops.def(
"moe_align_block_size",
&moe_align_block_size,
Expand Down
36 changes: 35 additions & 1 deletion csrc/quantization/fp8/common.cu
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,8 @@
#include "cuda_compat.h"
#include "dispatch_utils.h"

#include "fp8_gemm_kernels.h"

namespace vllm {

__device__ __forceinline__ float atomicMaxFloat(float* addr, float value) {
Expand Down Expand Up @@ -45,7 +47,7 @@ __global__ void segmented_max_reduction(
scalar_t tmp = 0.0;
while (i < num_elems) {
float x = static_cast<float>(input[i]);
tmp = max(tmp, fabs(x));
tmp = fmax(tmp, fabs(x));
i += blockDim.x * gridDim.x;
}
cache[threadIdx.x] = tmp;
Expand Down Expand Up @@ -133,3 +135,35 @@ void dynamic_scaled_fp8_quant(
});
}

// Computes out = input @ weights with a CUTLASS FP8 (e4m3) GEMM on SM90.
//
//   input:     [m, k] fp8 e4m3, row-major (operand A)
//   weights:   [k, n] fp8 e4m3, read column-major (operand B)
//   out:       [m, n] fp16, row-major (wired as both C and D operands)
//   workspace: caller-allocated uint8 scratch; must hold at least
//              Gemm::get_workspace_size(arguments) bytes.
//
// The epilogue uses alpha = 1, beta = 0, so `out` is write-only even though it
// is also passed as the C source pointer.
void fp8_scaled_gemm(torch::Tensor& out, torch::Tensor& input, torch::Tensor& weights, torch::Tensor& workspace) {
  Gemm gemm;

  int m = input.size(0);
  int n = weights.size(1);
  int k = weights.size(0);
  int l = 1;  // batch count (single GEMM)

  // Validate shape agreement up front so a mismatch fails with a clear
  // message instead of a CUTLASS kernel error (or silent corruption).
  TORCH_CHECK(input.size(1) == k,
              "fp8_scaled_gemm: input has ", input.size(1),
              " columns but weights has ", k, " rows");
  TORCH_CHECK(out.size(0) == m && out.size(1) == n,
              "fp8_scaled_gemm: out must be [", m, ", ", n, "], got [",
              out.size(0), ", ", out.size(1), "]");

  // Packed (fully contiguous) strides for each operand.
  StrideA stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(m, k, l));
  StrideB stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(n, k, l));
  StrideC stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(m, n, l));
  StrideD stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(m, n, l));

  typename Gemm::Arguments arguments{
      cutlass::gemm::GemmUniversalMode::kGemm,
      {m, n, k, l},
      {reinterpret_cast<cutlass::float_e4m3_t*>(input.data_ptr<c10::Float8_e4m3fn>()), stride_A,
       reinterpret_cast<cutlass::float_e4m3_t*>(weights.data_ptr<c10::Float8_e4m3fn>()), stride_B},
      {
          {1.0f, 0.0f},  // epilogue.thread: alpha = 1, beta = 0
          reinterpret_cast<cutlass::half_t*>(out.data_ptr<c10::Half>()), stride_C,
          reinterpret_cast<cutlass::half_t*>(out.data_ptr<c10::Half>()), stride_D
      }
  };

  size_t workspace_size = Gemm::get_workspace_size(arguments);
  TORCH_CHECK(static_cast<size_t>(workspace.numel()) >= workspace_size,
              "fp8_scaled_gemm: workspace needs ", workspace_size,
              " bytes, got ", workspace.numel());

  // Surface the CUTLASS status string on failure instead of a bare check.
  cutlass::Status status = gemm.can_implement(arguments);
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "fp8_scaled_gemm: can_implement failed: ",
              cutlass::cutlassGetStatusString(status));

  status = gemm.initialize(arguments, workspace.data_ptr<uint8_t>());
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "fp8_scaled_gemm: initialize failed: ",
              cutlass::cutlassGetStatusString(status));

  status = gemm.run();
  TORCH_CHECK(status == cutlass::Status::kSuccess,
              "fp8_scaled_gemm: run failed: ",
              cutlass::cutlassGetStatusString(status));
}

80 changes: 80 additions & 0 deletions csrc/quantization/fp8/fp8_gemm_kernels.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
// CUTLASS 3.x kernel configuration for an SM90 (Hopper) FP8 GEMM:
//   D = A (fp8 e4m3, row-major) x B (fp8 e4m3, column-major), output fp16 row-major.
// Consumed by fp8_scaled_gemm() in csrc/quantization/fp8/common.cu.
#pragma once  // was previously unguarded; double inclusion would redefine every alias below

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"

#include "cute/tensor.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cutlass/gemm/device/gemm_universal_adapter.h"
#include "cutlass/gemm/kernel/gemm_universal.hpp"
#include "cutlass/epilogue/collective/collective_builder.hpp"
#include "cutlass/util/packed_stride.hpp"

// NOTE(review): a using-directive in a header leaks `cute` into every includer.
// Kept for backward compatibility with existing consumers; consider qualifying
// the cute:: names below and removing this before merge.
using namespace cute;

// A matrix configuration
using ElementA = cutlass::float_e4m3_t;                                 // Element type for A matrix operand
using LayoutA = cutlass::layout::RowMajor;                              // Layout type for A matrix operand
constexpr int AlignmentA = 128 / cutlass::sizeof_bits<ElementA>::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes)

// B matrix configuration
using ElementB = cutlass::float_e4m3_t;                                 // Element type for B matrix operand
using LayoutB = cutlass::layout::ColumnMajor;                           // Layout type for B matrix operand
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<ElementB>::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes)

// C matrix configuration
using ElementC = cutlass::half_t;                                       // Element type for C and D matrix operands
using LayoutC = cutlass::layout::RowMajor;                              // Layout type for C and D matrix operands
constexpr int AlignmentC = 128 / cutlass::sizeof_bits<ElementC>::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes)

// D matrix configuration (same as C)
using ElementD = ElementC;
using LayoutD = LayoutC;
constexpr int AlignmentD = AlignmentC;

// Core kernel configurations
using ElementAccumulator = float;                     // Element type for internal accumulation
using ElementCompute = float;                         // Element type for epilogue computation
using ArchTag = cutlass::arch::Sm90;                  // Tag indicating the minimum SM that supports the intended feature
using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag
using TileShape = Shape<_64,_64,_256>;                // Threadblock-level tile size
using ClusterShape = Shape<_1,_1,_1>;                 // Shape of the threadblocks in a cluster
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;
using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized;
using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto;

// Epilogue: `void` as the C-source element type means the epilogue reads no C
// tensor (beta is effectively unused) — it only writes D.
using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    TileShape, ClusterShape,
    EpilogueTileType,
    ElementAccumulator, ElementCompute,
    void, LayoutC, AlignmentC,
    ElementD, LayoutD, AlignmentD,
    EpilogueSchedule
  >::CollectiveOp;

// Mainloop: stage count is auto-derived, carving out the epilogue's shared
// memory so both collectives fit in the SM90 shared-memory budget.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    ArchTag, OperatorClass,
    ElementA, LayoutA, AlignmentA,
    ElementB, LayoutB, AlignmentB,
    ElementAccumulator,
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAutoCarveout<
      static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))
    >,
    KernelSchedule
  >::CollectiveOp;

using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
    Shape<int,int,int,int>, // Indicates ProblemShape (m, n, k, batch)
    CollectiveMainloop,
    CollectiveEpilogue
>;

// Device-side adapter and its stride types, used by the host launcher.
using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;

using StrideA = typename Gemm::GemmKernel::StrideA;
using StrideB = typename Gemm::GemmKernel::StrideB;
using StrideC = typename Gemm::GemmKernel::StrideC;
using StrideD = typename Gemm::GemmKernel::StrideD;



1 change: 1 addition & 0 deletions third_party/cutlass
Submodule cutlass added at 033d9e