diff options
| author | Matthias Springer <mspringer@nvidia.com> | 2025-03-05 18:55:34 +0100 |
|---|---|---|
| committer | Matthias Springer <mspringer@nvidia.com> | 2025-03-05 18:55:34 +0100 |
| commit | ae8fa931026b97b184d51e3c22d609a97c4c944c (patch) | |
| tree | fb5db0f5a8726991a2cd7c52e2c128b879388290 | |
| parent | c3ec3bdb9efdc71d7778e6b2ecb42c65bfc90a0a (diff) | |
listener based approachusers/matthias-springer/attribute-converter
| -rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 49 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h | 4 | ||||
| -rw-r--r-- | mlir/lib/IR/PatternMatch.cpp | 33 | ||||
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 3 |
4 files changed, 36 insertions, 53 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 012a8f1ec559..91ccda3a011f 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -761,6 +761,26 @@ private: RewriterBase(const RewriterBase &) = delete; }; +class DiscardableAttributeConverter : public RewriterBase::Listener { +public: + using DiscardableAttributeConverterFn = + std::function<LogicalResult(Operation *, Operation *)>; + + DiscardableAttributeConverter( + RewriterBase &rewriter, + ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters) + : rewriter(rewriter), + dicardableAttributeConverters(dicardableAttributeConverters) {} + +protected: + void notifyOperationErased(Operation *op) override; + + void notifyOperationReplaced(Operation *op, Operation *replacement) override; + + RewriterBase &rewriter; + ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters; +}; + //===----------------------------------------------------------------------===// // IRRewriter //===----------------------------------------------------------------------===// @@ -790,15 +810,7 @@ public: /// place. class PatternRewriter : public RewriterBase { public: - using DiscardableAttributeConverterFn = - std::function<LogicalResult(Operation *, Operation *)>; - explicit PatternRewriter(MLIRContext *ctx) : RewriterBase(ctx) {} - PatternRewriter( - MLIRContext *ctx, - ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters) - : RewriterBase(ctx), - dicardableAttributeConverters(dicardableAttributeConverters) {} using RewriterBase::RewriterBase; /// A hook used to indicate if the pattern rewriter can recover from failure @@ -806,27 +818,6 @@ public: /// rewriter supports rollback, it may progress smoothly even if IR was /// changed during the rewrite. virtual bool canRecoverFromRewriteFailure() const { return false; } - - /// Erase an operation that is known to have no uses. If this pattern - /// rewriter has attribute converters, asserts the op (and its nested ops) - /// has no discardable attributes. - void eraseOp(Operation *op) override; - - /// Replace the results of the given (original) operation with the specified - /// new op (replacement). The result types of the two ops must match. The - /// original op is erased. - /// - /// If the original op has discardable attributes, try to run an attribute - /// converter. - void replaceOp(Operation *op, Operation *newOp) override; - using RewriterBase::replaceOp; - -protected: - ArrayRef<DiscardableAttributeConverterFn> dicardableAttributeConverters; - - bool hasAttributeConverter() const { - return !dicardableAttributeConverters.empty(); - } }; } // namespace mlir diff --git a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h index 2e3aed902802..110b4f64856e 100644 --- a/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h +++ b/mlir/include/mlir/Transforms/GreedyPatternRewriteDriver.h @@ -14,7 +14,6 @@ #ifndef MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ #define MLIR_TRANSFORMS_GREEDYPATTERNREWRITEDRIVER_H_ -#include "mlir/IR/PatternMatch.h" #include "mlir/Rewrite/FrozenRewritePatternSet.h" namespace mlir { @@ -99,9 +98,6 @@ public: /// If set to "true", constants are CSE'd (even across multiple regions that /// are in a parent-ancestor relationship). bool cseConstants = true; - - SmallVector<PatternRewriter::DiscardableAttributeConverterFn> - dicardableAttributeConverters; }; //===----------------------------------------------------------------------===// diff --git a/mlir/lib/IR/PatternMatch.cpp b/mlir/lib/IR/PatternMatch.cpp index 6826c3add18b..9507e7af5a1a 100644 --- a/mlir/lib/IR/PatternMatch.cpp +++ b/mlir/lib/IR/PatternMatch.cpp @@ -425,37 +425,34 @@ void RewriterBase::moveOpAfter(Operation *op, Block *block, } //===----------------------------------------------------------------------===// -// PatternRewriter +// DiscardableAttributeConverter //===----------------------------------------------------------------------===// -void PatternRewriter::eraseOp(Operation *op) { - if (hasAttributeConverter()) { - op->walk([](Operation *op) { - assert(op->getDiscardableAttrs().empty() && - "attempting to drop discardable attribute"); - }); - } - RewriterBase::eraseOp(op); +void DiscardableAttributeConverter::notifyOperationErased(Operation *op) { + op->walk([](Operation *op) { + assert(op->getDiscardableAttrs().empty() && + "attempting to drop discardable attribute"); + }); } -void PatternRewriter::replaceOp(Operation *oldOp, Operation *newOp) { - if (hasAttributeConverter() && !oldOp->getDiscardableAttrs().empty()) { - startOpModification(oldOp); - startOpModification(newOp); +void DiscardableAttributeConverter::notifyOperationReplaced(Operation *oldOp, + Operation *newOp) { + if (!oldOp->getDiscardableAttrs().empty()) { + rewriter.startOpModification(oldOp); + rewriter.startOpModification(newOp); bool success = false; for (DiscardableAttributeConverterFn fn : llvm::reverse(dicardableAttributeConverters)) { if (succeeded(fn(oldOp, newOp))) { success = true; - finalizeOpModification(oldOp); - finalizeOpModification(newOp); + rewriter.finalizeOpModification(oldOp); + rewriter.finalizeOpModification(newOp); break; } } if (!success) { - cancelOpModification(oldOp); - cancelOpModification(newOp); + rewriter.cancelOpModification(oldOp); + rewriter.cancelOpModification(newOp); } } - RewriterBase::replaceOp(oldOp, newOp); } diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 56311ed560b3..fe84c6130064 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -411,8 +411,7 @@ private: GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( MLIRContext *ctx, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : rewriter(ctx, config.dicardableAttributeConverters), config(config), - matcher(patterns) + : rewriter(ctx), config(config), matcher(patterns) #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // clang-format off , expensiveChecks( |
