diff options
Diffstat (limited to 'flang/lib/Optimizer/Transforms/FIRToSCF.cpp')
| -rw-r--r-- | flang/lib/Optimizer/Transforms/FIRToSCF.cpp | 74 |
1 files changed, 71 insertions, 3 deletions
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp index 79ed85fa6060..2bca0d98ec68 100644 --- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp +++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp @@ -36,7 +36,7 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { mlir::Value high = doLoopOp.getUpperBound(); assert(low && high && "must be a Value"); mlir::Value step = doLoopOp.getStep(); - llvm::SmallVector<mlir::Value> iterArgs; + mlir::SmallVector<mlir::Value> iterArgs; if (hasFinalValue) iterArgs.push_back(low); iterArgs.append(doLoopOp.getIterOperands().begin(), @@ -88,6 +88,73 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> { } }; +struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> { + using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(fir::IterWhileOp iterWhileOp, + mlir::PatternRewriter &rewriter) const override { + + mlir::Location loc = iterWhileOp.getLoc(); + mlir::Value lowerBound = iterWhileOp.getLowerBound(); + mlir::Value upperBound = iterWhileOp.getUpperBound(); + mlir::Value step = iterWhileOp.getStep(); + + mlir::Value okInit = iterWhileOp.getIterateIn(); + mlir::ValueRange iterArgs = iterWhileOp.getInitArgs(); + + mlir::SmallVector<mlir::Value> initVals; + initVals.push_back(lowerBound); + initVals.push_back(okInit); + initVals.append(iterArgs.begin(), iterArgs.end()); + + mlir::SmallVector<mlir::Type> loopTypes; + loopTypes.push_back(lowerBound.getType()); + loopTypes.push_back(okInit.getType()); + for (auto val : iterArgs) + loopTypes.push_back(val.getType()); + + auto scfWhileOp = + mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals); + + auto &beforeBlock = *rewriter.createBlock( + &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes, + mlir::SmallVector<mlir::Location>(loopTypes.size(), loc)); + + mlir::Region::BlockArgListType argsInBefore = + scfWhileOp.getBefore().getArguments(); + auto ivInBefore = argsInBefore[0]; + auto earlyExitInBefore = argsInBefore[1]; + + rewriter.setInsertionPointToStart(&beforeBlock); + + mlir::Value inductionCmp = mlir::arith::CmpIOp::create( + rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound); + mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp, + earlyExitInBefore); + + mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore); + + rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(), + scfWhileOp.getAfter().begin()); + + auto *afterBody = scfWhileOp.getAfterBody(); + auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator()); + mlir::SmallVector<mlir::Value> results(resultOp->getOperands()); + mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0]; + + rewriter.setInsertionPointToStart(afterBody); + results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step); + + rewriter.setInsertionPointToEnd(afterBody); + rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results); + + scfWhileOp->setAttrs(iterWhileOp->getAttrs()); + rewriter.replaceOp(iterWhileOp, scfWhileOp); + return mlir::success(); + } +}; + void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter, mlir::Block &srcBlock, mlir::Block &dstBlock) { mlir::Operation *srcTerminator = srcBlock.getTerminator(); @@ -132,9 +199,10 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> { void FIRToSCFPass::runOnOperation() { mlir::RewritePatternSet patterns(&getContext()); - patterns.add<DoLoopConversion, IfConversion>(patterns.getContext()); + patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>( + patterns.getContext()); mlir::ConversionTarget target(getContext()); - target.addIllegalOp<fir::DoLoopOp, fir::IfOp>(); + target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>(); target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; }); if (failed( applyPartialConversion(getOperation(), target, std::move(patterns)))) |
