summaryrefslogtreecommitdiff
path: root/flang/lib/Optimizer/Transforms/AbstractResult.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/AbstractResult.cpp')
-rw-r--r--flang/lib/Optimizer/Transforms/AbstractResult.cpp138
1 files changed, 78 insertions, 60 deletions
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
index 2eca349110f3..b0327cc10e9d 100644
--- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp
+++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp
@@ -234,6 +234,60 @@ public:
}
};
+template <typename OpTy>
+static mlir::LogicalResult
+processReturnLikeOp(OpTy ret, mlir::Value newArg,
+ mlir::PatternRewriter &rewriter) {
+ auto loc = ret.getLoc();
+ rewriter.setInsertionPoint(ret);
+ mlir::Value resultValue = ret.getOperand(0);
+ fir::LoadOp resultLoad;
+ mlir::Value resultStorage;
+ // Identify result local storage.
+ if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
+ resultLoad = load;
+ resultStorage = load.getMemref();
+ // The result alloca may be behind a fir.declare, if any.
+ if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
+ resultStorage = declare.getMemref();
+ }
+ // Replace old local storage with new storage argument, unless
+ // the derived type is C_PTR/C_FUN_PTR, in which case the return
+ // type is updated to return void* (no new argument is passed).
+ if (fir::isa_builtin_cptr_type(resultValue.getType())) {
+ auto module = ret->template getParentOfType<mlir::ModuleOp>();
+ FirOpBuilder builder(rewriter, module);
+ mlir::Value cptr = resultValue;
+ if (resultLoad) {
+ // Replace whole derived type load by component load.
+ cptr = resultLoad.getMemref();
+ rewriter.setInsertionPoint(resultLoad);
+ }
+ mlir::Value newResultValue =
+ fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
+ newResultValue = builder.createConvert(
+ loc, getVoidPtrType(ret.getContext()), newResultValue);
+ rewriter.setInsertionPoint(ret);
+ rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue});
+ } else if (resultStorage) {
+ resultStorage.replaceAllUsesWith(newArg);
+ rewriter.replaceOpWithNewOp<OpTy>(ret);
+ } else {
+ // The result storage may have been optimized out by a memory to
+ // register pass, this is possible for fir.box results, or fir.record
+ // with no length parameters. Simply store the result in the result
+ // storage. at the return point.
+ rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
+ rewriter.replaceOpWithNewOp<OpTy>(ret);
+ }
+ // Delete result old local storage if unused.
+ if (resultStorage)
+ if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
+ if (alloc->use_empty())
+ rewriter.eraseOp(alloc);
+ return mlir::success();
+}
+
class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> {
public:
using OpRewritePattern::OpRewritePattern;
@@ -242,55 +296,23 @@ public:
llvm::LogicalResult
matchAndRewrite(mlir::func::ReturnOp ret,
mlir::PatternRewriter &rewriter) const override {
- auto loc = ret.getLoc();
- rewriter.setInsertionPoint(ret);
- mlir::Value resultValue = ret.getOperand(0);
- fir::LoadOp resultLoad;
- mlir::Value resultStorage;
- // Identify result local storage.
- if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) {
- resultLoad = load;
- resultStorage = load.getMemref();
- // The result alloca may be behind a fir.declare, if any.
- if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>())
- resultStorage = declare.getMemref();
- }
- // Replace old local storage with new storage argument, unless
- // the derived type is C_PTR/C_FUN_PTR, in which case the return
- // type is updated to return void* (no new argument is passed).
- if (fir::isa_builtin_cptr_type(resultValue.getType())) {
- auto module = ret->getParentOfType<mlir::ModuleOp>();
- FirOpBuilder builder(rewriter, module);
- mlir::Value cptr = resultValue;
- if (resultLoad) {
- // Replace whole derived type load by component load.
- cptr = resultLoad.getMemref();
- rewriter.setInsertionPoint(resultLoad);
- }
- mlir::Value newResultValue =
- fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr);
- newResultValue = builder.createConvert(
- loc, getVoidPtrType(ret.getContext()), newResultValue);
- rewriter.setInsertionPoint(ret);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(
- ret, mlir::ValueRange{newResultValue});
- } else if (resultStorage) {
- resultStorage.replaceAllUsesWith(newArg);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
- } else {
- // The result storage may have been optimized out by a memory to
- // register pass, this is possible for fir.box results, or fir.record
- // with no length parameters. Simply store the result in the result
- // storage. at the return point.
- rewriter.create<fir::StoreOp>(loc, resultValue, newArg);
- rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret);
- }
- // Delete result old local storage if unused.
- if (resultStorage)
- if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>())
- if (alloc->use_empty())
- rewriter.eraseOp(alloc);
- return mlir::success();
+ return processReturnLikeOp(ret, newArg, rewriter);
+ }
+
+private:
+ mlir::Value newArg;
+};
+
+class GPUReturnOpConversion
+ : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> {
+public:
+ using OpRewritePattern::OpRewritePattern;
+ GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg)
+ : OpRewritePattern(context), newArg{newArg} {}
+ llvm::LogicalResult
+ matchAndRewrite(mlir::gpu::ReturnOp ret,
+ mlir::PatternRewriter &rewriter) const override {
+ return processReturnLikeOp(ret, newArg, rewriter);
}
private:
@@ -373,6 +395,9 @@ public:
patterns.insert<ReturnOpConversion>(context, newArg);
target.addDynamicallyLegalOp<mlir::func::ReturnOp>(
[](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); });
+ patterns.insert<GPUReturnOpConversion>(context, newArg);
+ target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>(
+ [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); });
assert(func.getFunctionType() ==
getNewFunctionType(funcTy, shouldBoxResult));
} else {
@@ -460,17 +485,10 @@ public:
const bool shouldBoxResult = this->passResultAsBox.getValue();
mlir::TypeSwitch<mlir::Operation *, void>(op)
- .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) {
- runOnSpecificOperation(op, shouldBoxResult, patterns, target);
- })
- .Case<mlir::gpu::GPUModuleOp>([&](auto op) {
- auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(*op);
- for (auto funcOp : gpuMod.template getOps<mlir::func::FuncOp>())
- runOnSpecificOperation(funcOp, shouldBoxResult, patterns, target);
- for (auto gpuFuncOp : gpuMod.template getOps<mlir::gpu::GPUFuncOp>())
- runOnSpecificOperation(gpuFuncOp, shouldBoxResult, patterns,
- target);
- });
+ .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>(
+ [&](auto op) {
+ runOnSpecificOperation(op, shouldBoxResult, patterns, target);
+ });
// Convert the calls and, if needed, the ReturnOp in the function body.
target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect,