diff options
| author | Matthias Springer <springerm@google.com> | 2024-04-04 03:07:43 +0000 |
|---|---|---|
| committer | Matthias Springer <springerm@google.com> | 2024-04-04 03:07:45 +0000 |
| commit | 4c819864c5edbdb8137451c0a0b6a97240a17008 (patch) | |
| tree | 3617ed9d8ff7712b6c445d3629bdeb9c0d92d3d3 /mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp | |
| parent | 9df19ce40281551bd348b262a131085cf98dadf5 (diff) | |
[mlir][SCF] Add `scf.for` bufferization preprocessing passusers/matthias-springer/scf_bufferization_preprocessing
Add a bufferization preprocessing pass for `scf.for` loops to support loops where a yielded tensor value does not bufferize to the equivalent corresponding iter_arg buffer. This preprocessing works around a limitation of `scf.for` bufferization by inserting additional buffer copies for yielded tensors.
This preprocessing can be used to support most cases where One-Shot Bufferize fails to bufferize the IR with the following error message:
```
error: Yield operand #0 is not equivalent to the corresponding iter bbArg
```
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp | 33 |
1 files changed, 33 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp index 21c618ab633f..727c4fc7c639 100644 --- a/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/Bufferize.cpp @@ -17,6 +17,7 @@ namespace mlir { #define GEN_PASS_DEF_SCFBUFFERIZE +#define GEN_PASS_DEF_SCFLOOPBUFFERIZATIONPREPROCESSING #include "mlir/Dialect/SCF/Transforms/Passes.h.inc" } // namespace mlir @@ -40,8 +41,40 @@ struct SCFBufferizePass : public impl::SCFBufferizeBase<SCFBufferizePass> { return signalPassFailure(); }; }; + +struct SCFLoopBufferizationPreprocessingPass + : public impl::SCFLoopBufferizationPreprocessingBase< + SCFLoopBufferizationPreprocessingPass> { + void runOnOperation() override { + OpBuilder builder(getOperation()->getContext()); + getOperation()->walk([&](scf::YieldOp yieldOp) { + builder.setInsertionPoint(yieldOp); + // TODO: Support scf.while. + auto forOp = dyn_cast<scf::ForOp>(yieldOp->getParentOp()); + if (!forOp) + return WalkResult::skip(); + for (OpOperand &operand : yieldOp->getOpOperands()) { + auto tensorType = dyn_cast<TensorType>(operand.get().getType()); + if (!tensorType) + continue; + auto bbArg = forOp.getRegionIterArgs()[operand.getOperandNumber()]; + Value materialized = + builder + .create<bufferization::MaterializeInDestinationOp>( + yieldOp.getLoc(), tensorType, operand.get(), bbArg) + .getResult(); + operand.set(materialized); + } + return WalkResult::advance(); + }); + } +}; } // namespace std::unique_ptr<Pass> mlir::createSCFBufferizePass() { return std::make_unique<SCFBufferizePass>(); } + +std::unique_ptr<Pass> mlir::createSCFLoopBufferizationPreprocessingPass() { + return std::make_unique<SCFLoopBufferizationPreprocessingPass>(); +} |
