Skip to content

Commit

Permalink
Make it possible to lower fp8 tt.splat.
Browse files Browse the repository at this point in the history
Before the fix, `tt.splat` was lowered to e.g.
```
%14 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> f8E4M3FNUZ
```
which LLVM rejected.

Translating the result type through typeConverter is what is done in other similar places. It results in
```
%14 = "llvm.mlir.constant"() <{value = 0.000000e+00 : f8E4M3FNUZ}> : () -> i8
```
and it's accepted by LLVM. During the MLIR to LLVM lowering, the fp8 value is converted to i8 with the correct binary representation.

The `isFloat()` function that is updated happened to have just one caller (in ArithConstantSplatOpConversion)

PiperOrigin-RevId: 634065311
  • Loading branch information
mooskagh authored and Copybara-Service committed May 15, 2024
1 parent 05f53e4 commit ee7382e
Show file tree
Hide file tree
Showing 2 changed files with 42 additions and 0 deletions.
41 changes: 41 additions & 0 deletions third_party/triton/temporary/fp8_splat.patch
Original file line number Diff line number Diff line change
@@ -0,0 +1,41 @@
diff --git a/include/triton/Conversion/MLIRTypes.h b/include/triton/Conversion/MLIRTypes.h
--- a/include/triton/Conversion/MLIRTypes.h
+++ b/include/triton/Conversion/MLIRTypes.h
@@ -26,12 +26,6 @@ inline Type f32Ty(MLIRContext *ctx) { re
inline Type f64Ty(MLIRContext *ctx) { return FloatType::getF64(ctx); }
inline Type bf16Ty(MLIRContext *ctx) { return FloatType::getBF16(ctx); }

-inline bool isFloat(Type type) {
- return type.isF32() || type.isF64() || type.isF16() || type.isF128();
-}
-
-inline bool isInt(Type type) { return type.isIntOrFloat() && !isFloat(type); }
-
} // namespace type
} // namespace triton
} // namespace mlir
diff --git a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
--- a/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
+++ b/lib/Conversion/TritonGPUToLLVM/ViewOpToLLVM.cpp
@@ -74,17 +74,18 @@ struct ArithConstantSplatOpConversion
auto values = mlir::dyn_cast<SplatElementsAttr>(op.getValue());
auto elemType = values.getElementType();
Attribute val;
- if (elemType.isBF16() || type::isFloat(elemType)) {
+ if (isa<FloatType>(elemType)) {
val = values.getValues<FloatAttr>()[0];
- } else if (type::isInt(elemType)) {
+ } else if (isa<IntegerType>(elemType)) {
val = values.getValues<IntegerAttr>()[0];
} else {
llvm::errs() << "ArithConstantSplatOpConversion get unsupported type: "
<< value.getType() << "\n";
return failure();
}
- auto constOp = rewriter.create<LLVM::ConstantOp>(loc, elemType, val);
auto typeConverter = getTypeConverter();
+ auto constOp = rewriter.create<LLVM::ConstantOp>(
+ loc, typeConverter->convertType(elemType), val);
auto llStruct = SplatOpConversion::convertSplatLikeOp(
elemType, op.getType(), constOp, typeConverter, rewriter, loc);
rewriter.replaceOp(op, llStruct);
1 change: 1 addition & 0 deletions third_party/triton/temporary/series.bzl
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ internal patch during the next triton integration process.
"""

temporary_patch_list = [
"//third_party/triton/temporary:fp8_splat.patch",
"//third_party/triton/temporary:mma_limit_pred.patch",
"//third_party/triton/temporary:fix_register_constraints.patch",
]

0 comments on commit ee7382e

Please sign in to comment.