diff options
| author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:39:43 +0900 |
|---|---|---|
| committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:39:43 +0900 |
| commit | c36c84047e92587931e74aea1b3d91342617400b (patch) | |
| tree | 3d25b78796205b1f3f1ee5f9c55da298f6449ce8 /mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | |
| parent | 122393694892e7a718e8c612b5650388075e2833 (diff) | |
| parent | bdcf47e4bcb92889665825654bb80a8bbe30379e (diff) | |
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/condopusers/chapuni/cov/single/condop
Conflicts:
clang/lib/CodeGen/CoverageMappingGen.cpp
Diffstat (limited to 'mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp')
| -rw-r--r-- | mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp | 94 |
1 files changed, 56 insertions, 38 deletions
diff --git a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp index 61767f3b21c9..12c65a72babc 100644 --- a/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp +++ b/mlir/lib/Dialect/ArmSME/Transforms/VectorLegalization.cpp @@ -17,7 +17,7 @@ #include "mlir/Dialect/ArmSME/Transforms/Passes.h" #include "mlir/Dialect/ArmSME/Utils/Utils.h" #include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/Func/Transforms/OneToNFuncConversions.h" +#include "mlir/Dialect/Func/Transforms/FuncConversions.h" #include "mlir/Dialect/Index/IR/IndexDialect.h" #include "mlir/Dialect/Index/IR/IndexOps.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" @@ -25,7 +25,8 @@ #include "mlir/Dialect/SCF/Transforms/Patterns.h" #include "mlir/Dialect/Utils/IndexingUtils.h" #include "mlir/Dialect/Vector/Utils/VectorUtils.h" -#include "mlir/Transforms/OneToNTypeConversion.h" +#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/GreedyPatternRewriteDriver.h" #define DEBUG_TYPE "arm-sme-vector-legalization" @@ -172,12 +173,12 @@ int getNumberOfSMETilesForVectorType(VectorType type) { /// Legalize `arith.constant dense<value>` splat operations to fit within SME /// tiles by decomposing them into tile-sized operations. struct LegalizeArithConstantOpsByDecomposition - : public OneToNOpConversionPattern<arith::ConstantOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<arith::ConstantOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult matchAndRewrite(arith::ConstantOp constantOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + ConversionPatternRewriter &rewriter) const override { auto vectorType = dyn_cast<VectorType>(constantOp.getType()); auto denseAttr = dyn_cast<DenseElementsAttr>(constantOp.getValueAttr()); if (!vectorType || !denseAttr || !denseAttr.isSplat()) @@ -191,8 +192,8 @@ struct LegalizeArithConstantOpsByDecomposition auto tileCount = getNumberOfSMETilesForVectorType(vectorType); auto tileSplat = rewriter.create<arith::ConstantOp>( constantOp.getLoc(), denseAttr.resizeSplat(smeTileType)); - rewriter.replaceOp(constantOp, SmallVector<Value>(tileCount, tileSplat), - adaptor.getResultMapping()); + SmallVector<Value> repl(tileCount, tileSplat); + rewriter.replaceOpWithMultiple(constantOp, {repl}); return success(); } @@ -201,12 +202,13 @@ struct LegalizeArithConstantOpsByDecomposition /// Legalize `vector.outerproduct` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern<vector::OuterProductOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::OuterProductOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::OuterProductOp outerProductOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::OuterProductOp outerProductOp, + OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = outerProductOp.getResultVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(outerProductOp, @@ -219,6 +221,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition auto maskOp = outerProductOp.getMaskingOp(); mask = maskOp.getMask(); rootOp = maskOp; + rewriter.setInsertionPoint(rootOp); } if (!isSupportedMaskOp(mask)) @@ -248,7 +251,7 @@ struct LegalizeVectorOuterProductOpsByDecomposition resultSMETiles.push_back(maskedOuterProduct->getResult(0)); } - rewriter.replaceOp(rootOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(rootOp, {resultSMETiles}); return success(); } }; @@ -259,12 +262,12 @@ struct LegalizeVectorOuterProductOpsByDecomposition // (invalid). This pattern matches on `vector.mask` then calls into the // `vector.outerproduct` pattern to work around this issue. struct LegalizeMaskedVectorOuterProductOpsByDecomposition - : public OneToNOpConversionPattern<vector::MaskOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::MaskOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::MaskOp maskOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::MaskOp maskOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (auto outerProductOp = llvm::dyn_cast_or_null<vector::OuterProductOp>( maskOp.getMaskableOp())) { LegalizeVectorOuterProductOpsByDecomposition pattern(*getTypeConverter(), @@ -279,12 +282,12 @@ struct LegalizeMaskedVectorOuterProductOpsByDecomposition /// Legalize `vector.transfer_read` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeTransferReadOpsByDecomposition - : public OneToNOpConversionPattern<vector::TransferReadOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferReadOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferReadOp readOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferReadOp readOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = readOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(readOp, @@ -319,7 +322,7 @@ struct LegalizeTransferReadOpsByDecomposition resultSMETiles.push_back(smeRead); } - rewriter.replaceOp(readOp, resultSMETiles, adaptor.getResultMapping()); + rewriter.replaceOpWithMultiple(readOp, {resultSMETiles}); return success(); } }; @@ -327,12 +330,12 @@ struct LegalizeTransferReadOpsByDecomposition /// Legalize `vector.transfer_write` operations to fit within SME tiles by /// decomposing them into tile-sized operations. struct LegalizeTransferWriteOpsByDecomposition - : public OneToNOpConversionPattern<vector::TransferWriteOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferWriteOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { auto vectorType = writeOp.getVectorType(); if (!isMultipleOfSMETileVectorType(vectorType)) return rewriter.notifyMatchFailure(writeOp, @@ -409,12 +412,12 @@ struct LegalizeTransferWriteOpsByDecomposition /// } /// ``` struct LegalizeMultiTileTransferWriteAsStoreLoop - : public OneToNOpConversionPattern<vector::TransferWriteOp> { - using OneToNOpConversionPattern::OneToNOpConversionPattern; + : public OpConversionPattern<vector::TransferWriteOp> { + using OpConversionPattern::OpConversionPattern; LogicalResult - matchAndRewrite(vector::TransferWriteOp writeOp, OpAdaptor adaptor, - OneToNPatternRewriter &rewriter) const override { + matchAndRewrite(vector::TransferWriteOp writeOp, OneToNOpAdaptor adaptor, + ConversionPatternRewriter &rewriter) const override { if (writeOp.hasPureTensorSemantics()) return rewriter.notifyMatchFailure( writeOp, "TODO: tensor semantics are unsupported"); @@ -936,10 +939,16 @@ struct VectorLegalizationPass return success(); }); - patterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, - LiftIllegalVectorTransposeToMemory, - ConvertIllegalShapeCastOpsToTransposes, - LowerIllegalTransposeStoreViaZA>(context); + // Apply preprocessing patterns. + RewritePatternSet rewritePatterns(context); + rewritePatterns.add<FoldExtractFromVectorOfSMELikeCreateMasks, + LiftIllegalVectorTransposeToMemory, + ConvertIllegalShapeCastOpsToTransposes, + LowerIllegalTransposeStoreViaZA>(context); + if (failed( + applyPatternsGreedily(getOperation(), std::move(rewritePatterns)))) + return signalPassFailure(); + // Note: These two patterns are added with a high benefit to ensure: // - Masked outer products are handled before unmasked ones // - Multi-tile writes are lowered as a store loop (if possible) @@ -950,11 +959,20 @@ struct VectorLegalizationPass LegalizeVectorOuterProductOpsByDecomposition, LegalizeTransferReadOpsByDecomposition, LegalizeTransferWriteOpsByDecomposition>(converter, context); - populateFuncTypeConversionPatterns(converter, patterns); - scf::populateSCFStructuralOneToNTypeConversions(converter, patterns); - - if (failed(applyPartialOneToNConversion(getOperation(), converter, - std::move(patterns)))) + populateFunctionOpInterfaceTypeConversionPattern<func::FuncOp>(patterns, + converter); + populateCallOpTypeConversionPattern(patterns, converter); + populateReturnOpTypeConversionPattern(patterns, converter); + scf::populateSCFStructuralTypeConversions(converter, patterns); + + ConversionTarget target(getContext()); + target.markUnknownOpDynamicallyLegal( + [&](Operation *op) { return converter.isLegal(op); }); + target.addDynamicallyLegalOp<func::FuncOp>([&](func::FuncOp op) { + return converter.isSignatureLegal(op.getFunctionType()); + }); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) return signalPassFailure(); } }; |
