summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/FIRToSCF.cpp')
-rw-r--r--flang/lib/Optimizer/Transforms/FIRToSCF.cpp74
1 files changed, 71 insertions, 3 deletions
diff --git a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
index 79ed85fa6060..2bca0d98ec68 100644
--- a/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
+++ b/flang/lib/Optimizer/Transforms/FIRToSCF.cpp
@@ -36,7 +36,7 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
mlir::Value high = doLoopOp.getUpperBound();
assert(low && high && "must be a Value");
mlir::Value step = doLoopOp.getStep();
- llvm::SmallVector<mlir::Value> iterArgs;
+ mlir::SmallVector<mlir::Value> iterArgs;
if (hasFinalValue)
iterArgs.push_back(low);
iterArgs.append(doLoopOp.getIterOperands().begin(),
@@ -88,6 +88,73 @@ struct DoLoopConversion : public mlir::OpRewritePattern<fir::DoLoopOp> {
}
};
+struct IterWhileConversion : public mlir::OpRewritePattern<fir::IterWhileOp> {
+ using OpRewritePattern<fir::IterWhileOp>::OpRewritePattern;
+
+ mlir::LogicalResult
+ matchAndRewrite(fir::IterWhileOp iterWhileOp,
+ mlir::PatternRewriter &rewriter) const override {
+
+ mlir::Location loc = iterWhileOp.getLoc();
+ mlir::Value lowerBound = iterWhileOp.getLowerBound();
+ mlir::Value upperBound = iterWhileOp.getUpperBound();
+ mlir::Value step = iterWhileOp.getStep();
+
+ mlir::Value okInit = iterWhileOp.getIterateIn();
+ mlir::ValueRange iterArgs = iterWhileOp.getInitArgs();
+
+ mlir::SmallVector<mlir::Value> initVals;
+ initVals.push_back(lowerBound);
+ initVals.push_back(okInit);
+ initVals.append(iterArgs.begin(), iterArgs.end());
+
+ mlir::SmallVector<mlir::Type> loopTypes;
+ loopTypes.push_back(lowerBound.getType());
+ loopTypes.push_back(okInit.getType());
+ for (auto val : iterArgs)
+ loopTypes.push_back(val.getType());
+
+ auto scfWhileOp =
+ mlir::scf::WhileOp::create(rewriter, loc, loopTypes, initVals);
+
+ auto &beforeBlock = *rewriter.createBlock(
+ &scfWhileOp.getBefore(), scfWhileOp.getBefore().end(), loopTypes,
+ mlir::SmallVector<mlir::Location>(loopTypes.size(), loc));
+
+ mlir::Region::BlockArgListType argsInBefore =
+ scfWhileOp.getBefore().getArguments();
+ auto ivInBefore = argsInBefore[0];
+ auto earlyExitInBefore = argsInBefore[1];
+
+ rewriter.setInsertionPointToStart(&beforeBlock);
+
+ mlir::Value inductionCmp = mlir::arith::CmpIOp::create(
+ rewriter, loc, mlir::arith::CmpIPredicate::sle, ivInBefore, upperBound);
+ mlir::Value cond = mlir::arith::AndIOp::create(rewriter, loc, inductionCmp,
+ earlyExitInBefore);
+
+ mlir::scf::ConditionOp::create(rewriter, loc, cond, argsInBefore);
+
+ rewriter.moveBlockBefore(iterWhileOp.getBody(), &scfWhileOp.getAfter(),
+ scfWhileOp.getAfter().begin());
+
+ auto *afterBody = scfWhileOp.getAfterBody();
+ auto resultOp = mlir::cast<fir::ResultOp>(afterBody->getTerminator());
+ mlir::SmallVector<mlir::Value> results(resultOp->getOperands());
+ mlir::Value ivInAfter = scfWhileOp.getAfterArguments()[0];
+
+ rewriter.setInsertionPointToStart(afterBody);
+ results[0] = mlir::arith::AddIOp::create(rewriter, loc, ivInAfter, step);
+
+ rewriter.setInsertionPointToEnd(afterBody);
+ rewriter.replaceOpWithNewOp<mlir::scf::YieldOp>(resultOp, results);
+
+ scfWhileOp->setAttrs(iterWhileOp->getAttrs());
+ rewriter.replaceOp(iterWhileOp, scfWhileOp);
+ return mlir::success();
+ }
+};
+
void copyBlockAndTransformResult(mlir::PatternRewriter &rewriter,
mlir::Block &srcBlock, mlir::Block &dstBlock) {
mlir::Operation *srcTerminator = srcBlock.getTerminator();
@@ -132,9 +199,10 @@ struct IfConversion : public mlir::OpRewritePattern<fir::IfOp> {
void FIRToSCFPass::runOnOperation() {
mlir::RewritePatternSet patterns(&getContext());
- patterns.add<DoLoopConversion, IfConversion>(patterns.getContext());
+ patterns.add<DoLoopConversion, IterWhileConversion, IfConversion>(
+ patterns.getContext());
mlir::ConversionTarget target(getContext());
- target.addIllegalOp<fir::DoLoopOp, fir::IfOp>();
+ target.addIllegalOp<fir::DoLoopOp, fir::IterWhileOp, fir::IfOp>();
target.markUnknownOpDynamicallyLegal([](mlir::Operation *) { return true; });
if (failed(
applyPartialConversion(getOperation(), target, std::move(patterns))))