diff options
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
| -rw-r--r-- | mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp | 283 |
1 files changed, 262 insertions, 21 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp index 597cb29ce911..99e82827cdef 100644 --- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp +++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp @@ -14,10 +14,12 @@ #include "mlir/Config/mlir-config.h" #include "mlir/IR/Action.h" +#include "mlir/IR/Iterators.h" #include "mlir/IR/Matchers.h" #include "mlir/IR/Verifier.h" #include "mlir/Interfaces/SideEffectInterfaces.h" #include "mlir/Rewrite/PatternApplicator.h" +#include "mlir/Transforms/DialectConversion.h" #include "mlir/Transforms/FoldUtils.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/BitVector.h" @@ -321,7 +323,7 @@ private: /// to the worklist in the beginning. class GreedyPatternRewriteDriver : public RewriterBase::Listener { protected: - explicit GreedyPatternRewriteDriver(MLIRContext *ctx, + explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config); @@ -329,7 +331,7 @@ protected: void addSingleOpToWorklist(Operation *op); /// Add the given operation and its ancestors to the worklist. - void addToWorklist(Operation *op); + virtual void addToWorklist(Operation *op); /// Notify the driver that the specified operation may have been modified /// in-place. The operation is added to the worklist. @@ -356,7 +358,7 @@ protected: /// The pattern rewriter that is used for making IR modifications and is /// passed to rewrite patterns. - PatternRewriter rewriter; + PatternRewriter &rewriter; /// The worklist for this transformation keeps track of the operations that /// need to be (re)visited. @@ -375,6 +377,11 @@ protected: /// `config.strictMode` is GreedyRewriteStrictness::AnyOp. llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps; +#ifndef NDEBUG + /// A logger used to emit information during the application process. + llvm::ScopedPrinter logger{llvm::dbgs()}; +#endif + private: /// Look over the provided operands for any defining operations that should /// be re-added to the worklist. This function should be called when an @@ -394,11 +401,6 @@ private: notifyMatchFailure(Location loc, function_ref<void(Diagnostic &)> reasonCallback) override; -#ifndef NDEBUG - /// A logger used to emit information during the application process. - llvm::ScopedPrinter logger{llvm::dbgs()}; -#endif - /// The low-level pattern applicator. PatternApplicator matcher; @@ -409,9 +411,9 @@ private: } // namespace GreedyPatternRewriteDriver::GreedyPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config) - : rewriter(ctx), config(config), matcher(patterns) + : rewriter(rewriter), config(config), matcher(patterns) #if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // clang-format off , expensiveChecks( @@ -476,7 +478,7 @@ bool GreedyPatternRewriteDriver::processWorklist() { }); // If the operation is trivially dead - remove it. - if (isOpTriviallyDead(op)) { + if (config.enableOperationDce && isOpTriviallyDead(op)) { rewriter.eraseOp(op); changed = true; @@ -780,7 +782,7 @@ namespace { /// This driver simplfies all ops in a region. class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver { public: - explicit RegionPatternRewriteDriver(MLIRContext *ctx, + explicit RegionPatternRewriteDriver(PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, Region ®ions); @@ -796,9 +798,9 @@ private: } // namespace RegionPatternRewriteDriver::RegionPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, Region ®ion) - : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) { + : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) { // Populate strict mode ops. if (config.strictMode != GreedyRewriteStrictness::AnyOp) { region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); }); @@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region ®ion, #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Start the pattern driver. - RegionPatternRewriteDriver driver(region.getContext(), patterns, config, - region); + PatternRewriter rewriter(region.getContext()); + RegionPatternRewriteDriver driver(rewriter, patterns, config, region); LogicalResult converged = std::move(driver).simplify(changed); LLVM_DEBUG(if (failed(converged)) { llvm::dbgs() << "The pattern rewrite did not converge after scanning " @@ -928,7 +930,7 @@ namespace { class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver { public: explicit MultiOpPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr); @@ -950,10 +952,10 @@ private: } // namespace MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver( - MLIRContext *ctx, const FrozenRewritePatternSet &patterns, + PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns, const GreedyRewriteConfig &config, ArrayRef<Operation *> ops, llvm::SmallDenseSet<Operation *, 4> *survivingOps) - : GreedyPatternRewriteDriver(ctx, patterns, config), + : GreedyPatternRewriteDriver(rewriter, patterns, config), survivingOps(survivingOps) { if (config.strictMode != GreedyRewriteStrictness::AnyOp) strictModeFilteredOps.insert(ops.begin(), ops.end()); @@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold( #endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS // Start the pattern driver. + PatternRewriter rewriter(ops.front()->getContext()); llvm::SmallDenseSet<Operation *, 4> surviving; - MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns, - config, ops, + MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops, allErased ? &surviving : nullptr); LogicalResult converged = std::move(driver).simplify(ops, changed); if (allErased) @@ -1053,3 +1055,242 @@ LogicalResult mlir::applyOpPatternsAndFold( }); return converged; } + +//===----------------------------------------------------------------------===// +// One-Shot Dialect Conversion Infrastructure +//===----------------------------------------------------------------------===// + +namespace { +/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter +/// immediately materializes all IR changes. It derives from +/// `ConversionPatternRewriter` so that the existing conversion patterns can +/// be used with the One-Shot Dialect Conversion. +class OneShotConversionPatternRewriter : public ConversionPatternRewriter { +public: + OneShotConversionPatternRewriter(MLIRContext *ctx) + : ConversionPatternRewriter(ctx) {} + + bool canRecoverFromRewriteFailure() const override { return false; } + + void replaceOp(Operation *op, ValueRange newValues) override; + + void replaceOp(Operation *op, Operation *newOp) override { + replaceOp(op, newOp->getResults()); + } + + void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); } + + void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); } + + void inlineBlockBefore(Block *source, Block *dest, Block::iterator before, + ValueRange argValues = std::nullopt) override { + PatternRewriter::inlineBlockBefore(source, dest, before, argValues); + } + using PatternRewriter::inlineBlockBefore; + + void startOpModification(Operation *op) override { + PatternRewriter::startOpModification(op); + } + + void finalizeOpModification(Operation *op) override { + PatternRewriter::finalizeOpModification(op); + } + + void cancelOpModification(Operation *op) override { + PatternRewriter::cancelOpModification(op); + } + + void setCurrentTypeConverter(const TypeConverter *converter) override { + typeConverter = converter; + } + + const TypeConverter *getCurrentTypeConverter() const override { + return typeConverter; + } + + LogicalResult getAdapterOperands(StringRef valueDiagTag, + std::optional<Location> inputLoc, + ValueRange values, + SmallVector<Value> &remapped) override; + +private: + /// Build an unrealized_conversion_cast op or look it up in the cache. + Value buildUnrealizedConversionCast(Location loc, Type type, Value value); + + /// The current type converter. + const TypeConverter *typeConverter; + + /// A cache for unrealized_conversion_casts. To ensure that identical casts + /// are not built multiple times. + DenseMap<std::pair<Value, Type>, Value> castCache; +}; + +void OneShotConversionPatternRewriter::replaceOp(Operation *op, + ValueRange newValues) { + assert(op->getNumResults() == newValues.size()); + for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) { + if (orig.getType() != repl.getType()) { + // Type mismatch: insert unrealized_conversion cast. + replaceAllUsesWith(orig, buildUnrealizedConversionCast( + op->getLoc(), orig.getType(), repl)); + } else { + // Same type: use replacement value directly. + replaceAllUsesWith(orig, repl); + } + } + eraseOp(op); +} + +Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast( + Location loc, Type type, Value value) { + auto it = castCache.find(std::make_pair(value, type)); + if (it != castCache.end()) + return it->second; + + // Insert cast at the beginning of the block (for block arguments) or right + // after the defining op. + OpBuilder::InsertionGuard g(*this); + Block *insertBlock = value.getParentBlock(); + Block::iterator insertPt = insertBlock->begin(); + if (OpResult inputRes = dyn_cast<OpResult>(value)) + insertPt = ++inputRes.getOwner()->getIterator(); + setInsertionPoint(insertBlock, insertPt); + auto castOp = create<UnrealizedConversionCastOp>(loc, type, value); + castCache[std::make_pair(value, type)] = castOp.getOutputs()[0]; + return castOp.getOutputs()[0]; +} + +class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver { +public: + ConversionPatternRewriteDriver(PatternRewriter &rewriter, + const FrozenRewritePatternSet &patterns, + const GreedyRewriteConfig &config, + const ConversionTarget &target) + : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) { + } + + /// Populate the worklist with all illegal ops and start the conversion + /// process. + LogicalResult convert(Operation *op) &&; + +protected: + void addToWorklist(Operation *op) override; + + /// Notify the driver that the specified operation was removed. Update the + /// worklist as needed: The operation and its children are removed from the + /// worklist. + void notifyOperationErased(Operation *op) override; + +private: + const ConversionTarget ⌖ +}; +} // namespace + +LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && { + op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) { + auto legalityInfo = target.isLegal(op); + if (!legalityInfo) { + addSingleOpToWorklist(op); + return WalkResult::advance(); + } + if (legalityInfo->isRecursivelyLegal) { + // Don't check this operation's children for conversion if the + // operation is recursively legal. + return WalkResult::skip(); + } + return WalkResult::advance(); + }); + + // Reverse the list so our pop-back loop processes them in-order. + // TODO: newly enqueued ops must also be reversed + worklist.reverse(); + + processWorklist(); + + return success(); +} + +void ConversionPatternRewriteDriver::addToWorklist(Operation *op) { + if (!target.isLegal(op)) + addSingleOpToWorklist(op); +} + +// TODO: Refactor. This is the same as +// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to +// the worklist. +void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) { + LLVM_DEBUG({ + logger.startLine() << "** Erase : '" << op->getName() << "'(" << op + << ")\n"; + }); + +#ifndef NDEBUG + // Only ops that are within the configured scope are added to the worklist of + // the greedy pattern rewriter. Moreover, the parent op of the scope region is + // the part of the IR that is taken into account for the "expensive checks". + // A greedy pattern rewrite is not allowed to erase the parent op of the scope + // region, as that would break the worklist handling and the expensive checks. + if (config.scope && config.scope->getParentOp() == op) + llvm_unreachable( + "scope region must not be erased during greedy pattern rewrite"); +#endif // NDEBUG + + if (config.listener) + config.listener->notifyOperationErased(op); + + worklist.remove(op); + + if (config.strictMode != GreedyRewriteStrictness::AnyOp) + strictModeFilteredOps.erase(op); +} + +/// Populate the converted operands in `remapped`. (Based on the currently set +/// type converter.) +LogicalResult OneShotConversionPatternRewriter::getAdapterOperands( + StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values, + SmallVector<Value> &remapped) { + // TODO: Refactor. This is mostly copied from the current dialect conversion. + for (Value v : values) { + // Skip all unrealized_conversion_casts in the chain of defining ops. + Value vBase = v; + while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>()) + vBase = castOp.getInputs()[0]; + + if (!getCurrentTypeConverter()) { + // No type converter set. Just replicate what the current type conversion + // is doing. + // TODO: We may have to distinguish between newly-inserted an + // pre-existing unrealized_conversion_casts. + remapped.push_back(vBase); + continue; + } + + Type desiredType; + SmallVector<Type, 1> legalTypes; + if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes))) + return failure(); + assert(legalTypes.size() == 1 && "1:N conversion not supported yet"); + desiredType = legalTypes.front(); + if (desiredType == vBase.getType()) { + // Type already matches. No need to convert anything. + remapped.push_back(vBase); + continue; + } + + Location operandLoc = inputLoc ? *inputLoc : v.getLoc(); + remapped.push_back( + buildUnrealizedConversionCast(operandLoc, desiredType, vBase)); + } + return success(); +} + +LogicalResult +mlir::applyPartialOneShotConversion(Operation *op, + const ConversionTarget &target, + const FrozenRewritePatternSet &patterns) { + GreedyRewriteConfig config; + config.enableOperationDce = false; + OneShotConversionPatternRewriter rewriter(op->getContext()); + ConversionPatternRewriteDriver driver(rewriter, patterns, config, target); + return std::move(driver).convert(op); +} |
