Prototype FP8Linear W8A8 runtime quantization #190
Adds FP8 quantization at runtime for both weights and activations, using torch.float8_e4m3fn as the storage dtype.
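A minimal sketch of what the dynamic per-tensor quantization step looks like; the helper name `quantize_fp8` and the exact scaling scheme are assumptions for illustration, not necessarily what this PR implements:

```python
import torch

FP8_DTYPE = torch.float8_e4m3fn
FP8_MAX = torch.finfo(FP8_DTYPE).max  # 448.0 for e4m3fn

def quantize_fp8(t: torch.Tensor):
    """Dynamic per-tensor quantization to float8_e4m3fn.

    Returns the FP8 tensor and a float32 scale such that
    t ~= t_fp8.to(t.dtype) * scale.
    """
    amax = t.abs().max().clamp(min=1e-12)   # avoid a zero scale for all-zero tensors
    scale = (amax / FP8_MAX).float()
    t_fp8 = (t / scale).clamp(-FP8_MAX, FP8_MAX).to(FP8_DTYPE)
    return t_fp8, scale
```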
torch._scaled_mm provides a W8A8 linear kernel for FP8, but in torch==2.2.1 it is only supported on CUDA devices with compute capability >= 9.0. Support has been extended to compute capability 8.9 and to ROCm MI300+ on main, but that won't land in a stable release for a while.
This means that on CUDA devices with compute capability < 9.0 (currently everything below Hopper), the weights are dequantized back to higher precision before the matmul, offering no compute savings.
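For illustration, a hedged sketch of how an FP8Linear forward pass could branch between the two cases described above, reusing the quantize_fp8 helper sketched earlier. The class layout and the private torch._scaled_mm call (which returns an (output, amax) tuple on torch 2.2.x) are assumptions, not the PR's actual code:

```python
import torch
import torch.nn.functional as F

class FP8Linear(torch.nn.Linear):
    """Hypothetical W8A8 FP8 linear layer using runtime (dynamic) quantization."""

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Assumes a 2-D input of shape (tokens, in_features).
        w_fp8, w_scale = quantize_fp8(self.weight)  # helper sketched above

        if x.is_cuda and torch.cuda.get_device_capability(x.device) >= (9, 0):
            # Hopper path: quantize activations too and run the FP8 kernel.
            x_fp8, x_scale = quantize_fp8(x)
            # torch._scaled_mm is a private API; on torch 2.2.x it returns
            # (out, amax) and expects the second operand in column-major layout.
            out, _ = torch._scaled_mm(
                x_fp8, w_fp8.t(),
                scale_a=x_scale, scale_b=w_scale,
                bias=self.bias, out_dtype=x.dtype,
            )
            return out

        # Pre-Hopper fallback: dequantize the weight and run the matmul in the
        # original precision -- weight memory is saved, but there are no
        # compute savings.
        w = w_fp8.to(x.dtype) * w_scale
        return F.linear(x, w, self.bias)
```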
Original precision bfloat16:
Quantized to FP8, specifically float8_e4m3fn: