summaryrefslogtreecommitdiff
path: root/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
-rw-r--r--mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp28
1 files changed, 21 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
index 074404add47f..700563460f52 100644
--- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
+++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp
@@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// and LLVM AMDGPU intrinsics convention.
///
/// Specifically:
-/// 1. If the element type is bfloat16, bitcast it to i16.
+/// 1. If the element type is bfloat16, bitcast it to i16 unless the ROCDL
+/// intrinsic allows bf16. Newer MFMAs accept bf16 operand types; see the
+/// IntrinsicsAMDGPU.td file for reference.
/// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32>
/// instead, which is what the f8f6f4 intrinsics use.
/// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit
@@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> {
/// therefore 8-bit and smaller floats are represented as their corresponding
/// `iN` integers.
static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter,
- Location loc, Value input) {
+ Location loc, Value input,
+ bool allowBf16 = true) {
Type inputType = input.getType();
if (auto vectorType = dyn_cast<VectorType>(inputType)) {
- if (vectorType.getElementType().isBF16())
+ if (vectorType.getElementType().isBF16() && !allowBf16)
return rewriter.create<LLVM::BitcastOp>(
loc, vectorType.clone(rewriter.getI16Type()), input);
if (vectorType.getElementType().isInteger(8) &&
@@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> {
StringRef intrinsicName =
isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic;
+ // Determine if we can use bf16 in the intrinsic. Newer MFMAs on gfx950+
+ // allow bf16 inputs; see the IntrinsicsAMDGPU.td file for reference.
+ bool allowBf16 = [&]() {
+ if (chipset < kGfx950)
+ return false;
+ if (isScaled)
+ return true;
+ return intrinsicName.contains("16x16x32.bf16") ||
+ intrinsicName.contains("32x32x16.bf16");
+ }();
OperationState loweredOp(loc, intrinsicName);
loweredOp.addTypes(intrinsicOutType);
- loweredOp.addOperands(
- {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()),
- convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()),
- adaptor.getDestC()});
+ loweredOp.addOperands({convertMFMAVectorOperand(
+ rewriter, loc, adaptor.getSourceA(), allowBf16),
+ convertMFMAVectorOperand(
+ rewriter, loc, adaptor.getSourceB(), allowBf16),
+ adaptor.getDestC()});
if (isScaled) {
Value zero = createI32Constant(rewriter, loc, 0);
auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic;