diff options
Diffstat (limited to 'mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp')
| -rw-r--r-- | mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp | 34 |
1 files changed, 31 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp index 21042aff529c..77f972e0e589 100644 --- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp +++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp @@ -13,7 +13,9 @@ #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlow.h" #include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h" +#include "mlir/Dialect/Vector/IR/VectorOps.h" #include "mlir/IR/BuiltinAttributes.h" +#include "mlir/IR/TypeUtilities.h" #include "mlir/Transforms/DialectConversion.h" namespace mlir::amdgpu { @@ -86,6 +88,23 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs, } } +// A helper function to flatten a vector value to a scalar containing its bits, +// returning the value itself otherwise. +static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc, + Value val) { + auto vectorType = dyn_cast<VectorType>(val.getType()); + if (!vectorType) + return val; + + int64_t bitwidth = + vectorType.getElementTypeBitWidth() * vectorType.getNumElements(); + Type allBitsType = rewriter.getIntegerType(bitwidth); + auto allBitsVecType = VectorType::get({1}, allBitsType); + Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val); + Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0); + return scalar; +} + template <typename AtomicOp, typename ArithOp> LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite( AtomicOp atomicOp, Adaptor adaptor, @@ -113,6 +132,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite( rewriter.setInsertionPointToEnd(loopBlock); Value prevLoad = loopBlock->getArgument(0); Value operated = rewriter.create<ArithOp>(loc, data, prevLoad); + dataType = operated.getType(); SmallVector<NamedAttribute> cmpswapAttrs; patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate); @@ 
-126,8 +146,8 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite( // an int->float bitcast is introduced to account for the fact that cmpswap // only takes integer arguments. - Value prevLoadForCompare = prevLoad; - Value atomicResForCompare = atomicRes; + Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad); + Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes); if (auto floatDataTy = dyn_cast<FloatType>(dataType)) { Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth()); prevLoadForCompare = @@ -146,9 +166,17 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite( void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns( ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) { // gfx10 has no atomic adds. - if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) { + if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) { target.addIllegalOp<RawBufferAtomicFaddOp>(); } + // gfx11 has no fp16 atomics + if (chipset.majorVersion == 11) { + target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>( + [](RawBufferAtomicFaddOp op) -> bool { + Type elemType = getElementTypeOrSelf(op.getValue().getType()); + return !isa<Float16Type, BFloat16Type>(elemType); + }); + } // gfx9 has no to a very limited support for floating-point min and max. if (chipset.majorVersion == 9) { if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) { |
