summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp')
-rw-r--r--mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp34
1 files changed, 31 insertions, 3 deletions
diff --git a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
index 21042aff529c..77f972e0e589 100644
--- a/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
+++ b/mlir/lib/Dialect/AMDGPU/Transforms/EmulateAtomics.cpp
@@ -13,7 +13,9 @@
#include "mlir/Dialect/Arith/IR/Arith.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlow.h"
#include "mlir/Dialect/ControlFlow/IR/ControlFlowOps.h"
+#include "mlir/Dialect/Vector/IR/VectorOps.h"
#include "mlir/IR/BuiltinAttributes.h"
+#include "mlir/IR/TypeUtilities.h"
#include "mlir/Transforms/DialectConversion.h"
namespace mlir::amdgpu {
@@ -86,6 +88,23 @@ static void patchOperandSegmentSizes(ArrayRef<NamedAttribute> attrs,
}
}
+// A helper function to flatten a vector value to a scalar containing its bits,
+// returning the value itself if othetwise.
+static Value flattenVecToBits(ConversionPatternRewriter &rewriter, Location loc,
+ Value val) {
+ auto vectorType = dyn_cast<VectorType>(val.getType());
+ if (!vectorType)
+ return val;
+
+ int64_t bitwidth =
+ vectorType.getElementTypeBitWidth() * vectorType.getNumElements();
+ Type allBitsType = rewriter.getIntegerType(bitwidth);
+ auto allBitsVecType = VectorType::get({1}, allBitsType);
+ Value bitcast = rewriter.create<vector::BitCastOp>(loc, allBitsVecType, val);
+ Value scalar = rewriter.create<vector::ExtractOp>(loc, bitcast, 0);
+ return scalar;
+}
+
template <typename AtomicOp, typename ArithOp>
LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
AtomicOp atomicOp, Adaptor adaptor,
@@ -113,6 +132,7 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
rewriter.setInsertionPointToEnd(loopBlock);
Value prevLoad = loopBlock->getArgument(0);
Value operated = rewriter.create<ArithOp>(loc, data, prevLoad);
+ dataType = operated.getType();
SmallVector<NamedAttribute> cmpswapAttrs;
patchOperandSegmentSizes(origAttrs, cmpswapAttrs, DataArgAction::Duplicate);
@@ -126,8 +146,8 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
// an int->float bitcast is introduced to account for the fact that cmpswap
// only takes integer arguments.
- Value prevLoadForCompare = prevLoad;
- Value atomicResForCompare = atomicRes;
+ Value prevLoadForCompare = flattenVecToBits(rewriter, loc, prevLoad);
+ Value atomicResForCompare = flattenVecToBits(rewriter, loc, atomicRes);
if (auto floatDataTy = dyn_cast<FloatType>(dataType)) {
Type equivInt = rewriter.getIntegerType(floatDataTy.getWidth());
prevLoadForCompare =
@@ -146,9 +166,17 @@ LogicalResult RawBufferAtomicByCasPattern<AtomicOp, ArithOp>::matchAndRewrite(
void mlir::amdgpu::populateAmdgpuEmulateAtomicsPatterns(
ConversionTarget &target, RewritePatternSet &patterns, Chipset chipset) {
// gfx10 has no atomic adds.
- if (chipset >= Chipset(10, 0, 0) || chipset < Chipset(9, 0, 8)) {
+ if (chipset.majorVersion == 10 || chipset < Chipset(9, 0, 8)) {
target.addIllegalOp<RawBufferAtomicFaddOp>();
}
+ // gfx11 has no fp16 atomics
+ if (chipset.majorVersion == 11) {
+ target.addDynamicallyLegalOp<RawBufferAtomicFaddOp>(
+ [](RawBufferAtomicFaddOp op) -> bool {
+ Type elemType = getElementTypeOrSelf(op.getValue().getType());
+ return !isa<Float16Type, BFloat16Type>(elemType);
+ });
+ }
// gfx9 has no to a very limited support for floating-point min and max.
if (chipset.majorVersion == 9) {
if (chipset >= Chipset(9, 0, 0xa) && chipset != Chipset(9, 4, 1)) {