diff options
| author | Nishant Patel <nishant.b.patel@intel.com> | 2025-11-20 15:00:57 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-20 15:00:57 -0800 |
| commit | 310abe0e4b4ebb57976928cc0b520f9e878b54a7 (patch) | |
| tree | 516d6026d6e836e990663334dd2eb081158755d8 /mlir | |
| parent | fbc093588f654ba771dfc055687676edf4d76884 (diff) | |
[MLIR] [XeGPU] Add distribution pattern for vector.constant_mask from Wg To Sg (#168118)
Diffstat (limited to 'mlir')
3 files changed, 113 insertions, 5 deletions
diff --git a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp index 33d4b0457e5d..c6ace1802bc4 100644 --- a/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp +++ b/mlir/lib/Dialect/XeGPU/Transforms/XeGPUWgToSgDistribute.cpp @@ -1285,6 +1285,71 @@ struct WgToSgVectorTransposeOp } }; +// This pattern distributes the vector.constant_mask ops to work at subgroup +// level. +struct WgToSgVectorConstantMaskOp + : public OpConversionPattern<vector::ConstantMaskOp> { + using OpConversionPattern<vector::ConstantMaskOp>::OpConversionPattern; + + LogicalResult + matchAndRewrite(vector::ConstantMaskOp op, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { + xegpu::DistributeLayoutAttr layout = + xegpu::getDistributeLayoutAttr(op.getResult()); + if (!layout || !layout.isForWorkgroup()) + return failure(); + + Location loc = op.getLoc(); + VectorType type = op.getResult().getType(); + auto wgShape = type.getShape(); + + ArrayRef<int64_t> wgMaskDimSizes = op.getMaskDimSizes(); + + // Get subgroup ID. + Value sgId = + gpu::SubgroupIdOp::create(rewriter, loc, /*upper_bound=*/nullptr); + auto sgOffsets = + layout.computeDistributedCoords(rewriter, loc, sgId, wgShape); + if (failed(sgOffsets)) + return failure(); + + SmallVector<int64_t> sgShape = getSgShapeAndCount(wgShape, layout).first; + VectorType resultType = VectorType::get(sgShape, type.getElementType()); + + // In each dimension, each subgroup computes its local mask size as: + // min(max(wgMaskSize[d] - offset[d], 0), sgDimSize[d]) + SmallVector<Value> newCreateMaskOps; + for (auto offsetSet : *sgOffsets) { + SmallVector<Value> maskOperands; + + for (auto [i, wgMaskSize] : llvm::enumerate(wgMaskDimSizes)) { + Value wgMaskSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, wgMaskSize); + Value dimSizeVal = + arith::ConstantIndexOp::create(rewriter, loc, sgShape[i]); + Value offset = offsetSet[i]; + Value adjustedMaskSize = + arith::SubIOp::create(rewriter, loc, wgMaskSizeVal, offset); + Value zero = arith::ConstantIndexOp::create(rewriter, loc, 0); + Value nonNegative = + arith::MaxSIOp::create(rewriter, loc, adjustedMaskSize, zero); + Value sgMaskSize = + arith::MinSIOp::create(rewriter, loc, nonNegative, dimSizeVal); + maskOperands.push_back(sgMaskSize); + } + + auto newCreateMaskOp = + vector::CreateMaskOp::create(rewriter, loc, resultType, maskOperands); + xegpu::setDistributeLayoutAttr(newCreateMaskOp->getResult(0), + layout.dropSgLayoutAndData()); + newCreateMaskOps.push_back(newCreateMaskOp.getResult()); + } + + rewriter.replaceOpWithMultiple(op, {newCreateMaskOps}); + return success(); + } +}; + } // namespace namespace mlir { @@ -1299,8 +1364,8 @@ void populateXeGPUWgToSgDistributePatterns(RewritePatternSet &patterns) { WgToSgArithConstantOp, WgToSgLoadGatherOpWithOffset, WgToSgStoreScatterOpWithOffset, WgToSgLoadMatrixOp, WgToSgStoreMatrixOp, WgToSgVectorStepOp, WgToSgVectorShapeCastOp, - WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp>( - patterns.getContext()); + WgToSgMultiDimReductionOp, WgToSgVectorTransposeOp, + WgToSgVectorConstantMaskOp>(patterns.getContext()); } } // namespace xegpu } // namespace mlir @@ -1427,9 +1492,9 @@ void XeGPUWgToSgDistributePass::runOnOperation() { return isLegal(layout); }); - target.addDynamicallyLegalOp<vector::ShapeCastOp, vector::StepOp, - vector::TransposeOp, vector::BroadcastOp, - vector::MultiDimReductionOp>( + target.addDynamicallyLegalOp< + vector::ShapeCastOp, vector::StepOp, vector::TransposeOp, + vector::BroadcastOp, vector::MultiDimReductionOp, vector::ConstantMaskOp>( [=](Operation *op) -> bool { // Check for either a SliceAttr or LayoutAttr on the result. auto layout = xegpu::getDistributeLayoutAttr(op->getResult(0)); diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir index 84ce80f477a5..1cddccb5fbbd 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops-rr.mlir @@ -130,5 +130,13 @@ gpu.module @test_distribution { %trans = vector.transpose %load, [1, 0] {layout_result_0 = #xegpu.layout<sg_layout = [4, 8], sg_data = [16, 32], lane_layout = [1, 16], lane_data = [1, 1], order =[1, 0]>} : vector<256x128xf32> to vector<128x256xf32> gpu.return } + + // CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-COUNT-4: vector.create_mask {{.*}}, {{.*}} : vector<16x16xi1> + // CHECK-NOT: vector.create_mask + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [16, 16]>} : vector<256x128xi1> + gpu.return + } } diff --git a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir index 5dde84e8e0bc..574b365443a0 100644 --- a/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir +++ b/mlir/test/Dialect/XeGPU/xegpu-wg-to-sg-unify-ops.mlir @@ -548,6 +548,41 @@ gpu.module @test_distribution { gpu.return } + // CHECK-LABEL: vector_mask_1D + gpu.func @vector_mask_1D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[REMU:.*]] = index.remu %[[SGID]], %[[C2:.*]] + // CHECK-DAG: %[[MUL:.*]] = index.mul %[[REMU]], %[[C16:.*]] + // CHECK-DAG: %[[REMU2:.*]] = index.remu %[[MUL]], %[[C32:.*]] + // CHECK-DAG: %[[SUB:.*]] = arith.subi %[[C8:.*]], %[[REMU2]] : index + // CHECK-DAG: %[[MAX:.*]] = arith.maxsi %[[SUB]], %[[C0:.*]] : index + // CHECK-DAG: %[[MIN:.*]] = arith.minsi %[[MAX]], %[[C16:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MIN]] : vector<16xi1> + %constant_mask = vector.constant_mask [8] {layout_result_0 = #xegpu.layout<sg_layout = [2], sg_data = [16]>} : vector<32xi1> + gpu.return + } + + // CHECK-LABEL: vector_mask_2D + gpu.func @vector_mask_2D() { + // CHECK-DAG: %[[SGID:.*]] = gpu.subgroup_id : index + // CHECK-DAG: %[[SGIDX:.*]] = index.remu %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY_TMP:.*]] = index.divu %[[SGID]], %[[C4:.*]] + // CHECK-DAG: %[[SGIDY:.*]] = index.remu %[[SGIDY_TMP]], %[[C8:.*]] + // CHECK-DAG: %[[ROW:.*]] = index.mul %[[SGIDY]], %[[C32:.*]] + // CHECK-DAG: %[[COL:.*]] = index.mul %[[SGIDX]], %[[C32:.*]] + // CHECK-DAG: %[[MODROW:.*]] = index.remu %[[ROW]], %[[C256:.*]] + // CHECK-DAG: %[[MODCOL:.*]] = index.remu %[[COL]], %[[C128:.*]] + // CHECK-DAG: %[[SUBROW:.*]] = arith.subi %[[C16:.*]], %[[MODROW]] : index + // CHECK-DAG: %[[MAXROW:.*]] = arith.maxsi %[[SUBROW]], %[[C4:.*]] : index + // CHECK-DAG: %[[MINROW:.*]] = arith.minsi %[[MAXROW]], %[[C32:.*]] : index + // CHECK-DAG: %[[SUBCOL:.*]] = arith.subi %[[C16:.*]], %[[MODCOL]] : index + // CHECK-DAG: %[[MAXCOL:.*]] = arith.maxsi %[[SUBCOL]], %[[C7:.*]] : index + // CHECK-DAG: %[[MINCOL:.*]] = arith.minsi %[[MAXCOL]], %[[C32:.*]] : index + // CHECK-DAG: %[[MASK:.*]] = vector.create_mask %[[MINROW]], %[[MINCOL]] : vector<32x32xi1> + %constant_mask = vector.constant_mask [16, 16] {layout_result_0 = #xegpu.layout<sg_layout = [8, 4], sg_data = [32, 32]>} : vector<256x128xi1> + gpu.return + } + // CHECK-LABEL: distribute_load_slice_attr gpu.func @distribute_load_slice_attr() { %2 = memref.alloca() {alignment = 1024} : memref<4096xf32> |
