summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Test/TestPatterns.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/lib/Dialect/Test/TestPatterns.cpp')
-rw-r--r--mlir/test/lib/Dialect/Test/TestPatterns.cpp9
1 files changed, 8 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Test/TestPatterns.cpp b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
index ff958d9a3d2b..657dfd2bac6e 100644
--- a/mlir/test/lib/Dialect/Test/TestPatterns.cpp
+++ b/mlir/test/lib/Dialect/Test/TestPatterns.cpp
@@ -1177,8 +1177,8 @@ struct TestNonRootReplacement : public RewritePattern {
auto illegalOp = ILLegalOpF::create(rewriter, op->getLoc(), resultType);
auto legalOp = LegalOpB::create(rewriter, op->getLoc(), resultType);
- rewriter.replaceOp(illegalOp, legalOp);
rewriter.replaceOp(op, illegalOp);
+ rewriter.replaceOp(illegalOp, legalOp);
return success();
}
};
@@ -1362,6 +1362,7 @@ public:
// Helper function that replaces the given op with a new op of the given
// name and doubles each result (1 -> 2 replacement of each result).
auto replaceWithDoubleResults = [&](Operation *op, StringRef name) {
+ rewriter.setInsertionPointAfter(op);
SmallVector<Type> types;
for (Type t : op->getResultTypes()) {
types.push_back(t);
@@ -1560,6 +1561,7 @@ struct TestLegalizePatternDriver
if (mode == ConversionMode::Partial) {
DenseSet<Operation *> unlegalizedOps;
ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
DumpNotifications dumpNotifications;
config.listener = &dumpNotifications;
config.unlegalizedOps = &unlegalizedOps;
@@ -1582,6 +1584,7 @@ struct TestLegalizePatternDriver
});
ConversionConfig config;
+ config.allowPatternRollback = allowPatternRollback;
DumpNotifications dumpNotifications;
config.foldingMode = foldingMode;
config.listener = &dumpNotifications;
@@ -1599,6 +1602,7 @@ struct TestLegalizePatternDriver
DenseSet<Operation *> legalizedOps;
ConversionConfig config;
config.foldingMode = foldingMode;
+ config.allowPatternRollback = allowPatternRollback;
config.legalizableOps = &legalizedOps;
if (failed(applyAnalysisConversion(getOperation(), target,
std::move(patterns), config)))
@@ -1634,6 +1638,9 @@ struct TestLegalizePatternDriver
"after-patterns",
"Only attempt to fold not legal operations "
"after applying patterns"))};
+ Option<bool> allowPatternRollback{*this, "allow-pattern-rollback",
+ llvm::cl::desc("Allow pattern rollback"),
+ llvm::cl::init(true)};
};
} // namespace