summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/IR/SCF.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/IR/SCF.cpp31
1 files changed, 14 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp
index eded1c394f12..83ae79ce4826 100644
--- a/mlir/lib/Dialect/SCF/IR/SCF.cpp
+++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp
@@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp,
namespace {
// Fold away ForOp iter arguments when:
// 1) The op yields the iter arguments.
-// 2) The iter arguments have no use and the corresponding outer region
-// iterators (inputs) are yielded.
+// 2) The argument's corresponding outer region iterators (inputs) are yielded.
// 3) The iter arguments have no use and the corresponding (operation) results
// have no use.
//
@@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> {
newIterArgs.reserve(forOp.getInitArgs().size());
newYieldValues.reserve(numResults);
newResultValues.reserve(numResults);
- for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside
- forOp.getRegionIterArgs(), // iter inside region
- forOp.getResults(), // op results
- forOp.getYieldedValues() // iter yield
- )) {
+ for (auto [init, arg, result, yielded] :
+ llvm::zip(forOp.getInitArgs(), // iter from outside
+ forOp.getRegionIterArgs(), // iter inside region
+ forOp.getResults(), // op results
+ forOp.getYieldedValues() // iter yield
+ )) {
// Forwarded is `true` when:
// 1) The region `iter` argument is yielded.
- // 2) The region `iter` argument has no use, and the corresponding iter
- // operand (input) is yielded.
+ // 2) The region `iter` argument the corresponding input is yielded.
// 3) The region `iter` argument has no use, and the corresponding op
// result has no use.
- bool forwarded = ((std::get<1>(it) == std::get<3>(it)) ||
- (std::get<1>(it).use_empty() &&
- (std::get<0>(it) == std::get<3>(it) ||
- std::get<2>(it).use_empty())));
+ bool forwarded = (arg == yielded) || (init == yielded) ||
+ (arg.use_empty() && result.use_empty());
keepMask.push_back(!forwarded);
canonicalize |= forwarded;
if (forwarded) {
- newBlockTransferArgs.push_back(std::get<0>(it));
- newResultValues.push_back(std::get<0>(it));
+ newBlockTransferArgs.push_back(init);
+ newResultValues.push_back(init);
continue;
}
- newIterArgs.push_back(std::get<0>(it));
- newYieldValues.push_back(std::get<3>(it));
+ newIterArgs.push_back(init);
+ newYieldValues.push_back(yielded);
newBlockTransferArgs.push_back(Value()); // placeholder with null value
newResultValues.push_back(Value()); // placeholder with null value
}