summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp')
-rw-r--r--mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp49
1 files changed, 36 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
index 561a7e569eb2..288f7ab9f302 100644
--- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
+++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp
@@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() < 2)
return rewriter.notifyMatchFailure(op, "not a multi-element reduction");
@@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> {
}
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform());
+ loc, extracted, op.getOp(), op.getUniform(), clusterSize);
if (numElems == 1) {
res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx);
continue;
@@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy || vecTy.getNumElements() != 1)
return rewriter.notifyMatchFailure(op, "not a single-element reduction");
@@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final
Location loc = op.getLoc();
Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0);
Value reduce = rewriter.create<gpu::SubgroupReduceOp>(
- loc, extracted, op.getOp(), op.getUniform());
+ loc, extracted, op.getOp(), op.getUniform(), clusterSize);
rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce);
return success();
}
@@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final
/// type, respectively. For example, with `input` of type `f16`, `packFn` could
/// build ops to cast the value to `i32` to perform shuffles, while `unpackFn`
/// would cast it back to `f16` to perform arithmetic reduction on. Assumes that
-/// the subgroup is `subgroupSize` lanes wide and reduces across all of them.
+/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of
+/// `clusterSize` lanes, reducing all lanes in each cluster in parallel.
static Value createSubgroupShuffleReduction(
OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode,
- unsigned subgroupSize, function_ref<Value(Value)> packFn,
- function_ref<Value(Value)> unpackFn) {
+ unsigned clusterSize, unsigned subgroupSize,
+ function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) {
+ assert(llvm::isPowerOf2_32(clusterSize));
assert(llvm::isPowerOf2_32(subgroupSize));
+ assert(clusterSize <= subgroupSize);
// Lane value always stays in the original type. We use it to perform arith
// reductions.
Value laneVal = input;
// Parallel reduction using butterfly shuffles.
- for (unsigned i = 1; i < subgroupSize; i <<= 1) {
+ for (unsigned i = 1; i < clusterSize; i <<= 1) {
Value shuffled = builder
.create<gpu::ShuffleOp>(loc, packFn(laneVal), i,
/*width=*/subgroupSize,
@@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
Type valueTy = op.getType();
unsigned elemBitwidth =
getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth();
@@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final
auto identityFn = [](Value v) { return v; };
rewriter.replaceOp(op, createSubgroupShuffleReduction(
rewriter, loc, op.getValue(), op.getOp(),
- subgroupSize, identityFn, identityFn));
+ effectiveClusterSize, subgroupSize, identityFn,
+ identityFn));
return success();
}
@@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final
return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt);
};
- rewriter.replaceOp(op, createSubgroupShuffleReduction(
- rewriter, loc, op.getValue(), op.getOp(),
- subgroupSize, packFn, unpackFn));
+ rewriter.replaceOp(
+ op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(),
+ op.getOp(), effectiveClusterSize,
+ subgroupSize, packFn, unpackFn));
return success();
}
@@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final
LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op,
PatternRewriter &rewriter) const override {
+ std::optional<uint32_t> clusterSize = op.getClusterSize();
+ if (clusterSize && *clusterSize > subgroupSize)
+ return op.emitOpError()
+ << "cluster size " << *clusterSize
+ << " is greater than subgroup size " << subgroupSize;
+ unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize);
+
auto vecTy = dyn_cast<VectorType>(op.getType());
if (!vecTy)
return rewriter.notifyMatchFailure(op, "value type is not a vector");
@@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final
return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec);
};
- Value res =
- createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(),
- subgroupSize, packFn, unpackFn);
+ Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput,
+ op.getOp(), effectiveClusterSize,
+ subgroupSize, packFn, unpackFn);
if (vecBitwidth < shuffleBitwidth) {
res = rewriter.create<vector::ExtractStridedSliceOp>(