summaryrefslogtreecommitdiff
path: root/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp')
-rw-r--r--mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp283
1 files changed, 262 insertions, 21 deletions
diff --git a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
index 597cb29ce911..99e82827cdef 100644
--- a/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
+++ b/mlir/lib/Transforms/Utils/GreedyPatternRewriteDriver.cpp
@@ -14,10 +14,12 @@
#include "mlir/Config/mlir-config.h"
#include "mlir/IR/Action.h"
+#include "mlir/IR/Iterators.h"
#include "mlir/IR/Matchers.h"
#include "mlir/IR/Verifier.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
#include "mlir/Rewrite/PatternApplicator.h"
+#include "mlir/Transforms/DialectConversion.h"
#include "mlir/Transforms/FoldUtils.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/BitVector.h"
@@ -321,7 +323,7 @@ private:
/// to the worklist in the beginning.
class GreedyPatternRewriteDriver : public RewriterBase::Listener {
protected:
- explicit GreedyPatternRewriteDriver(MLIRContext *ctx,
+ explicit GreedyPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config);
@@ -329,7 +331,7 @@ protected:
void addSingleOpToWorklist(Operation *op);
/// Add the given operation and its ancestors to the worklist.
- void addToWorklist(Operation *op);
+ virtual void addToWorklist(Operation *op);
/// Notify the driver that the specified operation may have been modified
/// in-place. The operation is added to the worklist.
@@ -356,7 +358,7 @@ protected:
/// The pattern rewriter that is used for making IR modifications and is
/// passed to rewrite patterns.
- PatternRewriter rewriter;
+ PatternRewriter &rewriter;
/// The worklist for this transformation keeps track of the operations that
/// need to be (re)visited.
@@ -375,6 +377,11 @@ protected:
/// `config.strictMode` is GreedyRewriteStrictness::AnyOp.
llvm::SmallDenseSet<Operation *, 4> strictModeFilteredOps;
+#ifndef NDEBUG
+ /// A logger used to emit information during the application process.
+ llvm::ScopedPrinter logger{llvm::dbgs()};
+#endif
+
private:
/// Look over the provided operands for any defining operations that should
/// be re-added to the worklist. This function should be called when an
@@ -394,11 +401,6 @@ private:
notifyMatchFailure(Location loc,
function_ref<void(Diagnostic &)> reasonCallback) override;
-#ifndef NDEBUG
- /// A logger used to emit information during the application process.
- llvm::ScopedPrinter logger{llvm::dbgs()};
-#endif
-
/// The low-level pattern applicator.
PatternApplicator matcher;
@@ -409,9 +411,9 @@ private:
} // namespace
GreedyPatternRewriteDriver::GreedyPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config)
- : rewriter(ctx), config(config), matcher(patterns)
+ : rewriter(rewriter), config(config), matcher(patterns)
#if MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// clang-format off
, expensiveChecks(
@@ -476,7 +478,7 @@ bool GreedyPatternRewriteDriver::processWorklist() {
});
// If the operation is trivially dead - remove it.
- if (isOpTriviallyDead(op)) {
+ if (config.enableOperationDce && isOpTriviallyDead(op)) {
rewriter.eraseOp(op);
changed = true;
@@ -780,7 +782,7 @@ namespace {
/// This driver simplfies all ops in a region.
class RegionPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
- explicit RegionPatternRewriteDriver(MLIRContext *ctx,
+ explicit RegionPatternRewriteDriver(PatternRewriter &rewriter,
const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config,
Region &regions);
@@ -796,9 +798,9 @@ private:
} // namespace
RegionPatternRewriteDriver::RegionPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, Region &region)
- : GreedyPatternRewriteDriver(ctx, patterns, config), region(region) {
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), region(region) {
// Populate strict mode ops.
if (config.strictMode != GreedyRewriteStrictness::AnyOp) {
region.walk([&](Operation *op) { strictModeFilteredOps.insert(op); });
@@ -909,8 +911,8 @@ mlir::applyPatternsAndFoldGreedily(Region &region,
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
- RegionPatternRewriteDriver driver(region.getContext(), patterns, config,
- region);
+ PatternRewriter rewriter(region.getContext());
+ RegionPatternRewriteDriver driver(rewriter, patterns, config, region);
LogicalResult converged = std::move(driver).simplify(changed);
LLVM_DEBUG(if (failed(converged)) {
llvm::dbgs() << "The pattern rewrite did not converge after scanning "
@@ -928,7 +930,7 @@ namespace {
class MultiOpPatternRewriteDriver : public GreedyPatternRewriteDriver {
public:
explicit MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps = nullptr);
@@ -950,10 +952,10 @@ private:
} // namespace
MultiOpPatternRewriteDriver::MultiOpPatternRewriteDriver(
- MLIRContext *ctx, const FrozenRewritePatternSet &patterns,
+ PatternRewriter &rewriter, const FrozenRewritePatternSet &patterns,
const GreedyRewriteConfig &config, ArrayRef<Operation *> ops,
llvm::SmallDenseSet<Operation *, 4> *survivingOps)
- : GreedyPatternRewriteDriver(ctx, patterns, config),
+ : GreedyPatternRewriteDriver(rewriter, patterns, config),
survivingOps(survivingOps) {
if (config.strictMode != GreedyRewriteStrictness::AnyOp)
strictModeFilteredOps.insert(ops.begin(), ops.end());
@@ -1040,9 +1042,9 @@ LogicalResult mlir::applyOpPatternsAndFold(
#endif // MLIR_ENABLE_EXPENSIVE_PATTERN_API_CHECKS
// Start the pattern driver.
+ PatternRewriter rewriter(ops.front()->getContext());
llvm::SmallDenseSet<Operation *, 4> surviving;
- MultiOpPatternRewriteDriver driver(ops.front()->getContext(), patterns,
- config, ops,
+ MultiOpPatternRewriteDriver driver(rewriter, patterns, config, ops,
allErased ? &surviving : nullptr);
LogicalResult converged = std::move(driver).simplify(ops, changed);
if (allErased)
@@ -1053,3 +1055,242 @@ LogicalResult mlir::applyOpPatternsAndFold(
});
return converged;
}
+
+//===----------------------------------------------------------------------===//
+// One-Shot Dialect Conversion Infrastructure
+//===----------------------------------------------------------------------===//
+
+namespace {
+/// A conversion rewriter for the One-Shot Dialect Conversion. This rewriter
+/// immediately materializes all IR changes. It derives from
+/// `ConversionPatternRewriter` so that the existing conversion patterns can
+/// be used with the One-Shot Dialect Conversion.
+class OneShotConversionPatternRewriter : public ConversionPatternRewriter {
+public:
+ OneShotConversionPatternRewriter(MLIRContext *ctx)
+ : ConversionPatternRewriter(ctx) {}
+
+ bool canRecoverFromRewriteFailure() const override { return false; }
+
+ void replaceOp(Operation *op, ValueRange newValues) override;
+
+ void replaceOp(Operation *op, Operation *newOp) override {
+ replaceOp(op, newOp->getResults());
+ }
+
+ void eraseOp(Operation *op) override { PatternRewriter::eraseOp(op); }
+
+ void eraseBlock(Block *block) override { PatternRewriter::eraseBlock(block); }
+
+ void inlineBlockBefore(Block *source, Block *dest, Block::iterator before,
+ ValueRange argValues = std::nullopt) override {
+ PatternRewriter::inlineBlockBefore(source, dest, before, argValues);
+ }
+ using PatternRewriter::inlineBlockBefore;
+
+ void startOpModification(Operation *op) override {
+ PatternRewriter::startOpModification(op);
+ }
+
+ void finalizeOpModification(Operation *op) override {
+ PatternRewriter::finalizeOpModification(op);
+ }
+
+ void cancelOpModification(Operation *op) override {
+ PatternRewriter::cancelOpModification(op);
+ }
+
+ void setCurrentTypeConverter(const TypeConverter *converter) override {
+ typeConverter = converter;
+ }
+
+ const TypeConverter *getCurrentTypeConverter() const override {
+ return typeConverter;
+ }
+
+ LogicalResult getAdapterOperands(StringRef valueDiagTag,
+ std::optional<Location> inputLoc,
+ ValueRange values,
+ SmallVector<Value> &remapped) override;
+
+private:
+ /// Build an unrealized_conversion_cast op or look it up in the cache.
+ Value buildUnrealizedConversionCast(Location loc, Type type, Value value);
+
+ /// The current type converter.
+ const TypeConverter *typeConverter;
+
+ /// A cache for unrealized_conversion_casts. To ensure that identical casts
+ /// are not built multiple times.
+ DenseMap<std::pair<Value, Type>, Value> castCache;
+};
+
+void OneShotConversionPatternRewriter::replaceOp(Operation *op,
+ ValueRange newValues) {
+ assert(op->getNumResults() == newValues.size());
+ for (auto [orig, repl] : llvm::zip_equal(op->getResults(), newValues)) {
+ if (orig.getType() != repl.getType()) {
+ // Type mismatch: insert unrealized_conversion cast.
+ replaceAllUsesWith(orig, buildUnrealizedConversionCast(
+ op->getLoc(), orig.getType(), repl));
+ } else {
+ // Same type: use replacement value directly.
+ replaceAllUsesWith(orig, repl);
+ }
+ }
+ eraseOp(op);
+}
+
+Value OneShotConversionPatternRewriter::buildUnrealizedConversionCast(
+ Location loc, Type type, Value value) {
+ auto it = castCache.find(std::make_pair(value, type));
+ if (it != castCache.end())
+ return it->second;
+
+ // Insert cast at the beginning of the block (for block arguments) or right
+ // after the defining op.
+ OpBuilder::InsertionGuard g(*this);
+ Block *insertBlock = value.getParentBlock();
+ Block::iterator insertPt = insertBlock->begin();
+ if (OpResult inputRes = dyn_cast<OpResult>(value))
+ insertPt = ++inputRes.getOwner()->getIterator();
+ setInsertionPoint(insertBlock, insertPt);
+ auto castOp = create<UnrealizedConversionCastOp>(loc, type, value);
+ castCache[std::make_pair(value, type)] = castOp.getOutputs()[0];
+ return castOp.getOutputs()[0];
+}
+
+class ConversionPatternRewriteDriver : public GreedyPatternRewriteDriver {
+public:
+ ConversionPatternRewriteDriver(PatternRewriter &rewriter,
+ const FrozenRewritePatternSet &patterns,
+ const GreedyRewriteConfig &config,
+ const ConversionTarget &target)
+ : GreedyPatternRewriteDriver(rewriter, patterns, config), target(target) {
+ }
+
+ /// Populate the worklist with all illegal ops and start the conversion
+ /// process.
+ LogicalResult convert(Operation *op) &&;
+
+protected:
+ void addToWorklist(Operation *op) override;
+
+ /// Notify the driver that the specified operation was removed. Update the
+ /// worklist as needed: The operation and its children are removed from the
+ /// worklist.
+ void notifyOperationErased(Operation *op) override;
+
+private:
+ const ConversionTarget &target;
+};
+} // namespace
+
+LogicalResult ConversionPatternRewriteDriver::convert(Operation *op) && {
+ op->walk<WalkOrder::PreOrder, ForwardDominanceIterator<>>([&](Operation *op) {
+ auto legalityInfo = target.isLegal(op);
+ if (!legalityInfo) {
+ addSingleOpToWorklist(op);
+ return WalkResult::advance();
+ }
+ if (legalityInfo->isRecursivelyLegal) {
+ // Don't check this operation's children for conversion if the
+ // operation is recursively legal.
+ return WalkResult::skip();
+ }
+ return WalkResult::advance();
+ });
+
+ // Reverse the list so our pop-back loop processes them in-order.
+ // TODO: newly enqueued ops must also be reversed
+ worklist.reverse();
+
+ processWorklist();
+
+ return success();
+}
+
+void ConversionPatternRewriteDriver::addToWorklist(Operation *op) {
+ if (!target.isLegal(op))
+ addSingleOpToWorklist(op);
+}
+
+// TODO: Refactor. This is the same as
+// `GreedyPatternRewriteDriver::notifyOperationErased`, but does not add ops to
+// the worklist.
+void ConversionPatternRewriteDriver::notifyOperationErased(Operation *op) {
+ LLVM_DEBUG({
+ logger.startLine() << "** Erase : '" << op->getName() << "'(" << op
+ << ")\n";
+ });
+
+#ifndef NDEBUG
+ // Only ops that are within the configured scope are added to the worklist of
+ // the greedy pattern rewriter. Moreover, the parent op of the scope region is
+ // the part of the IR that is taken into account for the "expensive checks".
+ // A greedy pattern rewrite is not allowed to erase the parent op of the scope
+ // region, as that would break the worklist handling and the expensive checks.
+ if (config.scope && config.scope->getParentOp() == op)
+ llvm_unreachable(
+ "scope region must not be erased during greedy pattern rewrite");
+#endif // NDEBUG
+
+ if (config.listener)
+ config.listener->notifyOperationErased(op);
+
+ worklist.remove(op);
+
+ if (config.strictMode != GreedyRewriteStrictness::AnyOp)
+ strictModeFilteredOps.erase(op);
+}
+
+/// Populate the converted operands in `remapped`. (Based on the currently set
+/// type converter.)
+LogicalResult OneShotConversionPatternRewriter::getAdapterOperands(
+ StringRef valueDiagTag, std::optional<Location> inputLoc, ValueRange values,
+ SmallVector<Value> &remapped) {
+ // TODO: Refactor. This is mostly copied from the current dialect conversion.
+ for (Value v : values) {
+ // Skip all unrealized_conversion_casts in the chain of defining ops.
+ Value vBase = v;
+ while (auto castOp = vBase.getDefiningOp<UnrealizedConversionCastOp>())
+ vBase = castOp.getInputs()[0];
+
+ if (!getCurrentTypeConverter()) {
+ // No type converter set. Just replicate what the current type conversion
+ // is doing.
+ // TODO: We may have to distinguish between newly-inserted an
+ // pre-existing unrealized_conversion_casts.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Type desiredType;
+ SmallVector<Type, 1> legalTypes;
+ if (failed(getCurrentTypeConverter()->convertType(v.getType(), legalTypes)))
+ return failure();
+ assert(legalTypes.size() == 1 && "1:N conversion not supported yet");
+ desiredType = legalTypes.front();
+ if (desiredType == vBase.getType()) {
+ // Type already matches. No need to convert anything.
+ remapped.push_back(vBase);
+ continue;
+ }
+
+ Location operandLoc = inputLoc ? *inputLoc : v.getLoc();
+ remapped.push_back(
+ buildUnrealizedConversionCast(operandLoc, desiredType, vBase));
+ }
+ return success();
+}
+
+LogicalResult
+mlir::applyPartialOneShotConversion(Operation *op,
+ const ConversionTarget &target,
+ const FrozenRewritePatternSet &patterns) {
+ GreedyRewriteConfig config;
+ config.enableOperationDce = false;
+ OneShotConversionPatternRewriter rewriter(op->getContext());
+ ConversionPatternRewriteDriver driver(rewriter, patterns, config, target);
+ return std::move(driver).convert(op);
+}