diff options
| author | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
|---|---|---|
| committer | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
| commit | 38500d63e14ce340236840f60d356cdefb56a52c (patch) | |
| tree | 17edbec446ce9b50d2f215a483b83afb293a635d /flang/lib/Optimizer/Transforms/AbstractResult.cpp | |
| parent | 1a3d5daaef7a6a63448a497da3eff7fc9e23df26 (diff) | |
| parent | 27f30029741ecf023baece7b3dde1ff9011ffefc (diff) | |
Merge branch 'main' into users/meinersbur/flang_runtime_split-headers
Diffstat (limited to 'flang/lib/Optimizer/Transforms/AbstractResult.cpp')
| -rw-r--r-- | flang/lib/Optimizer/Transforms/AbstractResult.cpp | 138 |
1 files changed, 78 insertions, 60 deletions
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index 2eca349110f3..b0327cc10e9d 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -234,6 +234,60 @@ public: } }; +template <typename OpTy> +static mlir::LogicalResult +processReturnLikeOp(OpTy ret, mlir::Value newArg, + mlir::PatternRewriter &rewriter) { + auto loc = ret.getLoc(); + rewriter.setInsertionPoint(ret); + mlir::Value resultValue = ret.getOperand(0); + fir::LoadOp resultLoad; + mlir::Value resultStorage; + // Identify result local storage. + if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { + resultLoad = load; + resultStorage = load.getMemref(); + // The result alloca may be behind a fir.declare, if any. + if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) + resultStorage = declare.getMemref(); + } + // Replace old local storage with new storage argument, unless + // the derived type is C_PTR/C_FUN_PTR, in which case the return + // type is updated to return void* (no new argument is passed). + if (fir::isa_builtin_cptr_type(resultValue.getType())) { + auto module = ret->template getParentOfType<mlir::ModuleOp>(); + FirOpBuilder builder(rewriter, module); + mlir::Value cptr = resultValue; + if (resultLoad) { + // Replace whole derived type load by component load. 
+ cptr = resultLoad.getMemref(); + rewriter.setInsertionPoint(resultLoad); + } + mlir::Value newResultValue = + fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); + newResultValue = builder.createConvert( + loc, getVoidPtrType(ret.getContext()), newResultValue); + rewriter.setInsertionPoint(ret); + rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue}); + } else if (resultStorage) { + resultStorage.replaceAllUsesWith(newArg); + rewriter.replaceOpWithNewOp<OpTy>(ret); + } else { + // The result storage may have been optimized out by a memory to + // register pass, this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result + // storage. at the return point. + rewriter.create<fir::StoreOp>(loc, resultValue, newArg); + rewriter.replaceOpWithNewOp<OpTy>(ret); + } + // Delete result old local storage if unused. + if (resultStorage) + if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); + return mlir::success(); +} + class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { public: using OpRewritePattern::OpRewritePattern; @@ -242,55 +296,23 @@ public: llvm::LogicalResult matchAndRewrite(mlir::func::ReturnOp ret, mlir::PatternRewriter &rewriter) const override { - auto loc = ret.getLoc(); - rewriter.setInsertionPoint(ret); - mlir::Value resultValue = ret.getOperand(0); - fir::LoadOp resultLoad; - mlir::Value resultStorage; - // Identify result local storage. - if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { - resultLoad = load; - resultStorage = load.getMemref(); - // The result alloca may be behind a fir.declare, if any. 
- if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) - resultStorage = declare.getMemref(); - } - // Replace old local storage with new storage argument, unless - // the derived type is C_PTR/C_FUN_PTR, in which case the return - // type is updated to return void* (no new argument is passed). - if (fir::isa_builtin_cptr_type(resultValue.getType())) { - auto module = ret->getParentOfType<mlir::ModuleOp>(); - FirOpBuilder builder(rewriter, module); - mlir::Value cptr = resultValue; - if (resultLoad) { - // Replace whole derived type load by component load. - cptr = resultLoad.getMemref(); - rewriter.setInsertionPoint(resultLoad); - } - mlir::Value newResultValue = - fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); - newResultValue = builder.createConvert( - loc, getVoidPtrType(ret.getContext()), newResultValue); - rewriter.setInsertionPoint(ret); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( - ret, mlir::ValueRange{newResultValue}); - } else if (resultStorage) { - resultStorage.replaceAllUsesWith(newArg); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); - } else { - // The result storage may have been optimized out by a memory to - // register pass, this is possible for fir.box results, or fir.record - // with no length parameters. Simply store the result in the result - // storage. at the return point. - rewriter.create<fir::StoreOp>(loc, resultValue, newArg); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); - } - // Delete result old local storage if unused. 
- if (resultStorage) - if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) - if (alloc->use_empty()) - rewriter.eraseOp(alloc); - return mlir::success(); + return processReturnLikeOp(ret, newArg, rewriter); + } + +private: + mlir::Value newArg; +}; + +class GPUReturnOpConversion + : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> { +public: + using OpRewritePattern::OpRewritePattern; + GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) + : OpRewritePattern(context), newArg{newArg} {} + llvm::LogicalResult + matchAndRewrite(mlir::gpu::ReturnOp ret, + mlir::PatternRewriter &rewriter) const override { + return processReturnLikeOp(ret, newArg, rewriter); } private: @@ -373,6 +395,9 @@ public: patterns.insert<ReturnOpConversion>(context, newArg); target.addDynamicallyLegalOp<mlir::func::ReturnOp>( [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); + patterns.insert<GPUReturnOpConversion>(context, newArg); + target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>( + [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); }); assert(func.getFunctionType() == getNewFunctionType(funcTy, shouldBoxResult)); } else { @@ -460,17 +485,10 @@ public: const bool shouldBoxResult = this->passResultAsBox.getValue(); mlir::TypeSwitch<mlir::Operation *, void>(op) - .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { - runOnSpecificOperation(op, shouldBoxResult, patterns, target); - }) - .Case<mlir::gpu::GPUModuleOp>([&](auto op) { - auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(*op); - for (auto funcOp : gpuMod.template getOps<mlir::func::FuncOp>()) - runOnSpecificOperation(funcOp, shouldBoxResult, patterns, target); - for (auto gpuFuncOp : gpuMod.template getOps<mlir::gpu::GPUFuncOp>()) - runOnSpecificOperation(gpuFuncOp, shouldBoxResult, patterns, - target); - }); + .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>( + [&](auto op) { + runOnSpecificOperation(op, shouldBoxResult, patterns, target); + }); // 
Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, |
