summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/IR/PatternMatch.h10
-rw-r--r--mlir/include/mlir/Transforms/DialectConversion.h21
-rw-r--r--mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp15
3 files changed, 34 insertions, 12 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h
index 57e73c1d8c7c..b7291653b70b 100644
--- a/mlir/include/mlir/IR/PatternMatch.h
+++ b/mlir/include/mlir/IR/PatternMatch.h
@@ -633,13 +633,13 @@ public:
/// Find uses of `from` and replace them with `to`. Also notify the listener
/// about every in-place op modification (for every use that was replaced).
- void replaceAllUsesWith(Value from, Value to) {
+ virtual void replaceAllUsesWith(Value from, Value to) {
for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
}
}
- void replaceAllUsesWith(Block *from, Block *to) {
+ virtual void replaceAllUsesWith(Block *from, Block *to) {
for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) {
Operation *op = operand.getOwner();
modifyOpInPlace(op, [&]() { operand.set(to); });
@@ -665,9 +665,9 @@ public:
/// true. Also notify the listener about every in-place op modification (for
/// every use that was replaced). The optional `allUsesReplaced` flag is set
/// to "true" if all uses were replaced.
- void replaceUsesWithIf(Value from, Value to,
- function_ref<bool(OpOperand &)> functor,
- bool *allUsesReplaced = nullptr);
+ virtual void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr);
void replaceUsesWithIf(ValueRange from, ValueRange to,
function_ref<bool(OpOperand &)> functor,
bool *allUsesReplaced = nullptr);
diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h
index 220431e6ee2f..9341da19905a 100644
--- a/mlir/include/mlir/Transforms/DialectConversion.h
+++ b/mlir/include/mlir/Transforms/DialectConversion.h
@@ -784,6 +784,27 @@ public:
/// function supports both 1:1 and 1:N replacements.
void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to);
+ /// Replace all the uses of the value `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Value from, Value to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the block `from` with `to`.
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceAllUsesWith(Block *from, Block *to) override {
+ llvm::report_fatal_error("replaceAllUsesWith is not supported yet");
+ }
+
+ /// Replace all the uses of the value `from` with `to` if the `functor`
+ /// returns "true".
+ /// TODO: Currently not supported in a dialect conversion.
+ void replaceUsesWithIf(Value from, Value to,
+ function_ref<bool(OpOperand &)> functor,
+ bool *allUsesReplaced = nullptr) override {
+ llvm::report_fatal_error("replaceUsesWithIf is not supported yet");
+ }
+
/// Return the converted value of 'key' with a type defined by the type
/// converter of the currently executing pattern. Return nullptr in the case
/// of failure, the remapped value otherwise.
diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
index 34f372af1e4b..c90301661142 100644
--- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
+++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp
@@ -22,7 +22,7 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/SymbolTable.h"
#include "mlir/Pass/Pass.h"
-#include "mlir/Transforms/DialectConversion.h"
+#include "mlir/Transforms/WalkPatternRewriteDriver.h"
namespace mlir {
#define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS
@@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> {
/// Applies the conversion patterns in the given function.
static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) {
- ConversionTarget target(*module.getContext());
- target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>();
- target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect,
- memref::MemRefDialect>();
-
RewritePatternSet patterns(module.getContext());
patterns.add<ParallelOpLowering>(module.getContext(), numThreads);
FrozenRewritePatternSet frozen(std::move(patterns));
- return applyPartialConversion(module, target, frozen);
+ walkAndApplyPatterns(module, frozen);
+ auto status = module.walk([](Operation *op) {
+ if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op))
+ return WalkResult::interrupt();
+ return WalkResult::advance();
+ });
+ return failure(status.wasInterrupted());
}
/// A pass converting SCF operations to OpenMP operations.