diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp index 21c618ab633f..727c4fc7c639 100644 --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -17,6 +17,7 @@ namespace mlir { #define GEN_PASS_DEF_SCFBUFFERIZE +#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir @@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> { return signalPassFailure(); }; }; + +struct SCFLoopBufferizationPreprocessingPass + : public impl::SCFLoopBufferizationPreprocessingBase< + SCFLoopBufferizationPreprocessingPass> { + void runOnOperation() override { + OpBuilder builder(getOperation()->getContext()); + getOperation()->walk([&](scf::YieldOp yieldOp) { + builder.setInsertionPoint(yieldOp); + // TODO: Support scf.while. + auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()); + if (!forOp) + return WalkResult::skip(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = dyn_cast<TensorType>(operand.get().getType()); + if (!tensorType) + continue; + auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()]; + Value materialized = + builder + .create<bufferization::MaterializeInDestinationOp>( + yieldOp.getLoc(), tensorType, operand.get(), bbArg) + .getResult(); + operand.set(materialized); + } + return WalkResult::advance(); + }); + } +}; } // namespace std::unique_ptr<Pass> mlir::createSCFBufferizePass() { return std::make_unique<SCFBufferizePass>(); } + +std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() { + return std::make_unique<SCFLoopBufferizationPreprocessingPass>(); +} |
