diff options
| -rw-r--r-- | mlir/include/mlir/IR/PatternMatch.h | 10 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/DialectConversion.h | 21 | ||||
| -rw-r--r-- | mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp | 15 |
3 files changed, 34 insertions, 12 deletions
diff --git a/mlir/include/mlir/IR/PatternMatch.h b/mlir/include/mlir/IR/PatternMatch.h index 57e73c1d8c7c..b7291653b70b 100644 --- a/mlir/include/mlir/IR/PatternMatch.h +++ b/mlir/include/mlir/IR/PatternMatch.h @@ -633,13 +633,13 @@ public: /// Find uses of `from` and replace them with `to`. Also notify the listener /// about every in-place op modification (for every use that was replaced). - void replaceAllUsesWith(Value from, Value to) { + virtual void replaceAllUsesWith(Value from, Value to) { for (OpOperand &operand : llvm::make_early_inc_range(from.getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); } } - void replaceAllUsesWith(Block *from, Block *to) { + virtual void replaceAllUsesWith(Block *from, Block *to) { for (BlockOperand &operand : llvm::make_early_inc_range(from->getUses())) { Operation *op = operand.getOwner(); modifyOpInPlace(op, [&]() { operand.set(to); }); @@ -665,9 +665,9 @@ public: /// true. Also notify the listener about every in-place op modification (for /// every use that was replaced). The optional `allUsesReplaced` flag is set /// to "true" if all uses were replaced. - void replaceUsesWithIf(Value from, Value to, - function_ref<bool(OpOperand &)> functor, - bool *allUsesReplaced = nullptr); + virtual void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr); void replaceUsesWithIf(ValueRange from, ValueRange to, function_ref<bool(OpOperand &)> functor, bool *allUsesReplaced = nullptr); diff --git a/mlir/include/mlir/Transforms/DialectConversion.h b/mlir/include/mlir/Transforms/DialectConversion.h index 220431e6ee2f..9341da19905a 100644 --- a/mlir/include/mlir/Transforms/DialectConversion.h +++ b/mlir/include/mlir/Transforms/DialectConversion.h @@ -784,6 +784,27 @@ public: /// function supports both 1:1 and 1:N replacements. void replaceUsesOfBlockArgument(BlockArgument from, ValueRange to); + /// Replace all the uses of the value `from` with `to`. + /// TODO: Currently not supported in a dialect conversion. + void replaceAllUsesWith(Value from, Value to) override { + llvm::report_fatal_error("replaceAllUsesWith is not supported yet"); + } + + /// Replace all the uses of the block `from` with `to`. + /// TODO: Currently not supported in a dialect conversion. + void replaceAllUsesWith(Block *from, Block *to) override { + llvm::report_fatal_error("replaceAllUsesWith is not supported yet"); + } + + /// Replace all the uses of the value `from` with `to` if the `functor` + /// returns "true". + /// TODO: Currently not supported in a dialect conversion. + void replaceUsesWithIf(Value from, Value to, + function_ref<bool(OpOperand &)> functor, + bool *allUsesReplaced = nullptr) override { + llvm::report_fatal_error("replaceUsesWithIf is not supported yet"); + } + /// Return the converted value of 'key' with a type defined by the type /// converter of the currently executing pattern. Return nullptr in the case /// of failure, the remapped value otherwise. diff --git a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp index 34f372af1e4b..c90301661142 100644 --- a/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp +++ b/mlir/lib/Conversion/SCFToOpenMP/SCFToOpenMP.cpp @@ -22,7 +22,7 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" -#include "mlir/Transforms/DialectConversion.h" +#include "mlir/Transforms/WalkPatternRewriteDriver.h" namespace mlir { #define GEN_PASS_DEF_CONVERTSCFTOOPENMPPASS @@ -538,15 +538,16 @@ struct ParallelOpLowering : public OpRewritePattern<scf::ParallelOp> { /// Applies the conversion patterns in the given function. static LogicalResult applyPatterns(ModuleOp module, unsigned numThreads) { - ConversionTarget target(*module.getContext()); - target.addIllegalOp<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(); - target.addLegalDialect<omp::OpenMPDialect, LLVM::LLVMDialect, - memref::MemRefDialect>(); - RewritePatternSet patterns(module.getContext()); patterns.add<ParallelOpLowering>(module.getContext(), numThreads); FrozenRewritePatternSet frozen(std::move(patterns)); - return applyPartialConversion(module, target, frozen); + walkAndApplyPatterns(module, frozen); + auto status = module.walk([](Operation *op) { + if (isa<scf::ReduceOp, scf::ReduceReturnOp, scf::ParallelOp>(op)) + return WalkResult::interrupt(); + return WalkResult::advance(); + }); + return failure(status.wasInterrupted()); } /// A pass converting SCF operations to OpenMP operations. |
