diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/IR/SCF.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/IR/SCF.cpp | 31 |
1 files changed, 14 insertions, 17 deletions
diff --git a/mlir/lib/Dialect/SCF/IR/SCF.cpp b/mlir/lib/Dialect/SCF/IR/SCF.cpp index eded1c394f12..83ae79ce4826 100644 --- a/mlir/lib/Dialect/SCF/IR/SCF.cpp +++ b/mlir/lib/Dialect/SCF/IR/SCF.cpp @@ -839,8 +839,7 @@ mlir::scf::replaceAndCastForOpIterArg(RewriterBase &rewriter, scf::ForOp forOp, namespace { // Fold away ForOp iter arguments when: // 1) The op yields the iter arguments. -// 2) The iter arguments have no use and the corresponding outer region -// iterators (inputs) are yielded. +// 2) The argument's corresponding outer region iterators (inputs) are yielded. // 3) The iter arguments have no use and the corresponding (operation) results // have no use. // @@ -872,30 +871,28 @@ struct ForOpIterArgsFolder : public OpRewritePattern<scf::ForOp> { newIterArgs.reserve(forOp.getInitArgs().size()); newYieldValues.reserve(numResults); newResultValues.reserve(numResults); - for (auto it : llvm::zip(forOp.getInitArgs(), // iter from outside - forOp.getRegionIterArgs(), // iter inside region - forOp.getResults(), // op results - forOp.getYieldedValues() // iter yield - )) { + for (auto [init, arg, result, yielded] : + llvm::zip(forOp.getInitArgs(), // iter from outside + forOp.getRegionIterArgs(), // iter inside region + forOp.getResults(), // op results + forOp.getYieldedValues() // iter yield + )) { // Forwarded is `true` when: // 1) The region `iter` argument is yielded. - // 2) The region `iter` argument has no use, and the corresponding iter - // operand (input) is yielded. + // 2) The region `iter` argument the corresponding input is yielded. // 3) The region `iter` argument has no use, and the corresponding op // result has no use. - bool forwarded = ((std::get<1>(it) == std::get<3>(it)) || - (std::get<1>(it).use_empty() && - (std::get<0>(it) == std::get<3>(it) || - std::get<2>(it).use_empty()))); + bool forwarded = (arg == yielded) || (init == yielded) || + (arg.use_empty() && result.use_empty()); keepMask.push_back(!forwarded); canonicalize |= forwarded; if (forwarded) { - newBlockTransferArgs.push_back(std::get<0>(it)); - newResultValues.push_back(std::get<0>(it)); + newBlockTransferArgs.push_back(init); + newResultValues.push_back(init); continue; } - newIterArgs.push_back(std::get<0>(it)); - newYieldValues.push_back(std::get<3>(it)); + newIterArgs.push_back(init); + newYieldValues.push_back(yielded); newBlockTransferArgs.push_back(Value()); // placeholder with null value newResultValues.push_back(Value()); // placeholder with null value } |
