summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
diff options
context:
space:
mode:
authorMatthias Springer <springerm@google.com>2024-04-04 03:07:43 +0000
committerMatthias Springer <springerm@google.com>2024-04-04 03:07:45 +0000
commit4c819864c5edbdb8137451c0a0b6a97240a17008 (patch)
tree3617ed9d8ff7712b6c445d3629bdeb9c0d92d3d3 /mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
parent9df19ce40281551bd348b262a131085cf98dadf5 (diff)
[mlir][SCF] Add `scf.for` bufferization preprocessing passusers/matthias-springer/scf_bufferization_preprocessing
Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to the equivalent corresponding iter_arg buffer. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors. This preprocessing can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message: ``` error: Yield operand #0 is not equivalent to the corresponding iter bbArg ```
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp33
1 files changed, 33 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
index 21c618ab633f..727c4fc7c639 100644
--- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp
@@ -17,6 +17,7 @@
namespace mlir {
#define GEN_PASS_DEF_SCFBUFFERIZE
+#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING
#include "mlir/Dialect/SCF/Transforms/Passes.h.inc"
} // namespace mlir
@@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> {
return signalPassFailure();
};
};
+
+struct SCFLoopBufferizationPreprocessingPass
+ : public impl::SCFLoopBufferizationPreprocessingBase<
+ SCFLoopBufferizationPreprocessingPass> {
+ void runOnOperation() override {
+ OpBuilder builder(getOperation()->getContext());
+ getOperation()->walk([&](scf::YieldOp yieldOp) {
+ builder.setInsertionPoint(yieldOp);
+ // TODO: Support scf.while.
+ auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp());
+ if (!forOp)
+ return WalkResult::skip();
+ for (OpOperand &operand : yieldOp->getOpOperands()) {
+ auto tensorType = dyn_cast<TensorType>(operand.get().getType());
+ if (!tensorType)
+ continue;
+ auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()];
+ Value materialized =
+ builder
+ .create<bufferization::MaterializeInDestinationOp>(
+ yieldOp.getLoc(), tensorType, operand.get(), bbArg)
+ .getResult();
+ operand.set(materialized);
+ }
+ return WalkResult::advance();
+ });
+ }
+};
} // namespace
std::unique_ptr<Pass> mlir::createSCFBufferizePass() {
return std::make_unique<SCFBufferizePass>();
}
+
+std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() {
+ return std::make_unique<SCFLoopBufferizationPreprocessingPass>();
+}