diff options
Diffstat (limited to 'mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp')
| -rw-r--r-- | mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp | 28 |
1 files changed, 21 insertions, 7 deletions
diff --git a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp index 074404add47f..700563460f52 100644 --- a/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp +++ b/mlir/lib/Conversion/AMDGPUToROCDL/AMDGPUToROCDL.cpp @@ -499,7 +499,9 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> { /// and LLVM AMDGPU intrinsics convention. /// /// Specifically: -/// 1. If the element type is bfloat16, bitcast it to i16. +/// 1. If the element type is bfloat16, bitcast it to i16 unless rocdl intrinsic +/// allows bf16. Newer MFMAs support bf16 operand types; see the +/// IntrinsicsAMDGPU.td file for reference. /// 2. If instead we have a more than 64-bit quantity, use a <N / 4 x i32> /// instead, which is what the f8f6f4 intrinsics use. /// 3. If `input` is a vector of N <= 8 bytes, bitcast it to a (N * 8)-bit @@ -509,10 +511,11 @@ struct SchedBarrierOpLowering : public ConvertOpToLLVMPattern<SchedBarrierOp> { /// therefore 8-bit and smaller floats are represented as their corresponding /// `iN` integers. static Value convertMFMAVectorOperand(ConversionPatternRewriter &rewriter, - Location loc, Value input) { + Location loc, Value input, + bool allowBf16 = true) { Type inputType = input.getType(); if (auto vectorType = dyn_cast<VectorType>(inputType)) { - if (vectorType.getElementType().isBF16()) + if (vectorType.getElementType().isBF16() && !allowBf16) return rewriter.create<LLVM::BitcastOp>( loc, vectorType.clone(rewriter.getI16Type()), input); if (vectorType.getElementType().isInteger(8) && @@ -958,12 +961,23 @@ struct MFMAOpLowering : public ConvertOpToLLVMPattern<MFMAOp> { StringRef intrinsicName = isScaled ? std::get<0>(*maybeScaledIntrinsic) : *maybeIntrinsic; + // Determine if we can use bf16 in the intrinsic. Newer MFMAs in gfx950+ + // allow bf16 as the input. For reference check IntrinsicsAMDGPU.td file. 
+ bool allowBf16 = [&]() { + if (chipset < kGfx950) + return false; + if (isScaled) + return true; + return intrinsicName.contains("16x16x32.bf16") || + intrinsicName.contains("32x32x16.bf16"); + }(); OperationState loweredOp(loc, intrinsicName); loweredOp.addTypes(intrinsicOutType); - loweredOp.addOperands( - {convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceA()), - convertMFMAVectorOperand(rewriter, loc, adaptor.getSourceB()), - adaptor.getDestC()}); + loweredOp.addOperands({convertMFMAVectorOperand( + rewriter, loc, adaptor.getSourceA(), allowBf16), + convertMFMAVectorOperand( + rewriter, loc, adaptor.getSourceB(), allowBf16), + adaptor.getDestC()}); if (isScaled) { Value zero = createI32Constant(rewriter, loc, 0); auto [_scaledName, aTypeCode, bTypeCode] = *maybeScaledIntrinsic; |
