diff options
Diffstat (limited to 'mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp')
| -rw-r--r-- | mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp | 49 |
1 files changed, 36 insertions, 13 deletions
diff --git a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp index 561a7e569eb2..288f7ab9f302 100644 --- a/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp +++ b/mlir/lib/Dialect/GPU/Transforms/SubgroupReduceLowering.cpp @@ -50,6 +50,8 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, PatternRewriter &rewriter) const override { + std::optional<uint32_t> clusterSize = op.getClusterSize(); + auto vecTy = dyn_cast<VectorType>(op.getType()); if (!vecTy || vecTy.getNumElements() < 2) return rewriter.notifyMatchFailure(op, "not a multi-element reduction"); @@ -95,7 +97,7 @@ struct BreakDownSubgroupReduce final : OpRewritePattern<gpu::SubgroupReduceOp> { } Value reduce = rewriter.create<gpu::SubgroupReduceOp>( - loc, extracted, op.getOp(), op.getUniform()); + loc, extracted, op.getOp(), op.getUniform(), clusterSize); if (numElems == 1) { res = rewriter.create<vector::InsertOp>(loc, reduce, res, startIdx); continue; @@ -127,6 +129,8 @@ struct ScalarizeSingleElementReduce final LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, PatternRewriter &rewriter) const override { + std::optional<uint32_t> clusterSize = op.getClusterSize(); + auto vecTy = dyn_cast<VectorType>(op.getType()); if (!vecTy || vecTy.getNumElements() != 1) return rewriter.notifyMatchFailure(op, "not a single-element reduction"); @@ -136,7 +140,7 @@ struct ScalarizeSingleElementReduce final Location loc = op.getLoc(); Value extracted = rewriter.create<vector::ExtractOp>(loc, op.getValue(), 0); Value reduce = rewriter.create<gpu::SubgroupReduceOp>( - loc, extracted, op.getOp(), op.getUniform()); + loc, extracted, op.getOp(), op.getUniform(), clusterSize); rewriter.replaceOpWithNewOp<vector::BroadcastOp>(op, vecTy, reduce); return success(); } @@ -147,17 +151,20 @@ struct ScalarizeSingleElementReduce final /// type, respectively. For example, with `input` of type `f16`, `packFn` could /// build ops to cast the value to `i32` to perform shuffles, while `unpackFn` /// would cast it back to `f16` to perform arithmetic reduction on. Assumes that -/// the subgroup is `subgroupSize` lanes wide and reduces across all of them. +/// the subgroup is `subgroupSize` lanes wide and divides it into clusters of +/// `clusterSize` lanes, reducing all lanes in each cluster in parallel. static Value createSubgroupShuffleReduction( OpBuilder &builder, Location loc, Value input, gpu::AllReduceOperation mode, - unsigned subgroupSize, function_ref<Value(Value)> packFn, - function_ref<Value(Value)> unpackFn) { + unsigned clusterSize, unsigned subgroupSize, + function_ref<Value(Value)> packFn, function_ref<Value(Value)> unpackFn) { + assert(llvm::isPowerOf2_32(clusterSize)); assert(llvm::isPowerOf2_32(subgroupSize)); + assert(clusterSize <= subgroupSize); // Lane value always stays in the original type. We use it to perform arith // reductions. Value laneVal = input; // Parallel reduction using butterfly shuffles. - for (unsigned i = 1; i < subgroupSize; i <<= 1) { + for (unsigned i = 1; i < clusterSize; i <<= 1) { Value shuffled = builder .create<gpu::ShuffleOp>(loc, packFn(laneVal), i, /*width=*/subgroupSize, @@ -183,6 +190,13 @@ struct ScalarSubgroupReduceToShuffles final LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, PatternRewriter &rewriter) const override { + std::optional<uint32_t> clusterSize = op.getClusterSize(); + if (clusterSize && *clusterSize > subgroupSize) + return op.emitOpError() + << "cluster size " << *clusterSize + << " is greater than subgroup size " << subgroupSize; + unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize); + Type valueTy = op.getType(); unsigned elemBitwidth = getElementTypeOrSelf(valueTy).getIntOrFloatBitWidth(); @@ -196,7 +210,8 @@ struct ScalarSubgroupReduceToShuffles final auto identityFn = [](Value v) { return v; }; rewriter.replaceOp(op, createSubgroupShuffleReduction( rewriter, loc, op.getValue(), op.getOp(), - subgroupSize, identityFn, identityFn)); + effectiveClusterSize, subgroupSize, identityFn, + identityFn)); return success(); } @@ -215,9 +230,10 @@ struct ScalarSubgroupReduceToShuffles final return rewriter.create<arith::BitcastOp>(loc, valueTy, asInt); }; - rewriter.replaceOp(op, createSubgroupShuffleReduction( - rewriter, loc, op.getValue(), op.getOp(), - subgroupSize, packFn, unpackFn)); + rewriter.replaceOp( + op, createSubgroupShuffleReduction(rewriter, loc, op.getValue(), + op.getOp(), effectiveClusterSize, + subgroupSize, packFn, unpackFn)); return success(); } @@ -237,6 +253,13 @@ struct VectorSubgroupReduceToShuffles final LogicalResult matchAndRewrite(gpu::SubgroupReduceOp op, PatternRewriter &rewriter) const override { + std::optional<uint32_t> clusterSize = op.getClusterSize(); + if (clusterSize && *clusterSize > subgroupSize) + return op.emitOpError() + << "cluster size " << *clusterSize + << " is greater than subgroup size " << subgroupSize; + unsigned effectiveClusterSize = clusterSize.value_or(subgroupSize); + auto vecTy = dyn_cast<VectorType>(op.getType()); if (!vecTy) return rewriter.notifyMatchFailure(op, "value type is not a vector"); @@ -285,9 +308,9 @@ struct VectorSubgroupReduceToShuffles final return rewriter.create<vector::BitCastOp>(loc, extendedVecTy, asIntVec); }; - Value res = - createSubgroupShuffleReduction(rewriter, loc, extendedInput, op.getOp(), - subgroupSize, packFn, unpackFn); + Value res = createSubgroupShuffleReduction(rewriter, loc, extendedInput, + op.getOp(), effectiveClusterSize, + subgroupSize, packFn, unpackFn); if (vecBitwidth < shuffleBitwidth) { res = rewriter.create<vector::ExtractStridedSliceOp>( |
