diff options
| author | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
|---|---|---|
| committer | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
| commit | 38500d63e14ce340236840f60d356cdefb56a52c (patch) | |
| tree | 17edbec446ce9b50d2f215a483b83afb293a635d /flang/lib/Optimizer/Transforms | |
| parent | 1a3d5daaef7a6a63448a497da3eff7fc9e23df26 (diff) | |
| parent | 27f30029741ecf023baece7b3dde1ff9011ffefc (diff) | |
Merge branch 'main' into users/meinersbur/flang_runtime_split-headers
Diffstat (limited to 'flang/lib/Optimizer/Transforms')
18 files changed, 358 insertions, 267 deletions
diff --git a/flang/lib/Optimizer/Transforms/AbstractResult.cpp b/flang/lib/Optimizer/Transforms/AbstractResult.cpp index 2eca349110f3..b0327cc10e9d 100644 --- a/flang/lib/Optimizer/Transforms/AbstractResult.cpp +++ b/flang/lib/Optimizer/Transforms/AbstractResult.cpp @@ -234,6 +234,60 @@ public: } }; +template <typename OpTy> +static mlir::LogicalResult +processReturnLikeOp(OpTy ret, mlir::Value newArg, + mlir::PatternRewriter &rewriter) { + auto loc = ret.getLoc(); + rewriter.setInsertionPoint(ret); + mlir::Value resultValue = ret.getOperand(0); + fir::LoadOp resultLoad; + mlir::Value resultStorage; + // Identify result local storage. + if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { + resultLoad = load; + resultStorage = load.getMemref(); + // The result alloca may be behind a fir.declare, if any. + if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) + resultStorage = declare.getMemref(); + } + // Replace old local storage with new storage argument, unless + // the derived type is C_PTR/C_FUN_PTR, in which case the return + // type is updated to return void* (no new argument is passed). + if (fir::isa_builtin_cptr_type(resultValue.getType())) { + auto module = ret->template getParentOfType<mlir::ModuleOp>(); + FirOpBuilder builder(rewriter, module); + mlir::Value cptr = resultValue; + if (resultLoad) { + // Replace whole derived type load by component load. 
+ cptr = resultLoad.getMemref(); + rewriter.setInsertionPoint(resultLoad); + } + mlir::Value newResultValue = + fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); + newResultValue = builder.createConvert( + loc, getVoidPtrType(ret.getContext()), newResultValue); + rewriter.setInsertionPoint(ret); + rewriter.replaceOpWithNewOp<OpTy>(ret, mlir::ValueRange{newResultValue}); + } else if (resultStorage) { + resultStorage.replaceAllUsesWith(newArg); + rewriter.replaceOpWithNewOp<OpTy>(ret); + } else { + // The result storage may have been optimized out by a memory to + // register pass, this is possible for fir.box results, or fir.record + // with no length parameters. Simply store the result in the result + // storage. at the return point. + rewriter.create<fir::StoreOp>(loc, resultValue, newArg); + rewriter.replaceOpWithNewOp<OpTy>(ret); + } + // Delete result old local storage if unused. + if (resultStorage) + if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) + if (alloc->use_empty()) + rewriter.eraseOp(alloc); + return mlir::success(); +} + class ReturnOpConversion : public mlir::OpRewritePattern<mlir::func::ReturnOp> { public: using OpRewritePattern::OpRewritePattern; @@ -242,55 +296,23 @@ public: llvm::LogicalResult matchAndRewrite(mlir::func::ReturnOp ret, mlir::PatternRewriter &rewriter) const override { - auto loc = ret.getLoc(); - rewriter.setInsertionPoint(ret); - mlir::Value resultValue = ret.getOperand(0); - fir::LoadOp resultLoad; - mlir::Value resultStorage; - // Identify result local storage. - if (auto load = resultValue.getDefiningOp<fir::LoadOp>()) { - resultLoad = load; - resultStorage = load.getMemref(); - // The result alloca may be behind a fir.declare, if any. 
- if (auto declare = resultStorage.getDefiningOp<fir::DeclareOp>()) - resultStorage = declare.getMemref(); - } - // Replace old local storage with new storage argument, unless - // the derived type is C_PTR/C_FUN_PTR, in which case the return - // type is updated to return void* (no new argument is passed). - if (fir::isa_builtin_cptr_type(resultValue.getType())) { - auto module = ret->getParentOfType<mlir::ModuleOp>(); - FirOpBuilder builder(rewriter, module); - mlir::Value cptr = resultValue; - if (resultLoad) { - // Replace whole derived type load by component load. - cptr = resultLoad.getMemref(); - rewriter.setInsertionPoint(resultLoad); - } - mlir::Value newResultValue = - fir::factory::genCPtrOrCFunptrValue(builder, loc, cptr); - newResultValue = builder.createConvert( - loc, getVoidPtrType(ret.getContext()), newResultValue); - rewriter.setInsertionPoint(ret); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>( - ret, mlir::ValueRange{newResultValue}); - } else if (resultStorage) { - resultStorage.replaceAllUsesWith(newArg); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); - } else { - // The result storage may have been optimized out by a memory to - // register pass, this is possible for fir.box results, or fir.record - // with no length parameters. Simply store the result in the result - // storage. at the return point. - rewriter.create<fir::StoreOp>(loc, resultValue, newArg); - rewriter.replaceOpWithNewOp<mlir::func::ReturnOp>(ret); - } - // Delete result old local storage if unused. 
- if (resultStorage) - if (auto alloc = resultStorage.getDefiningOp<fir::AllocaOp>()) - if (alloc->use_empty()) - rewriter.eraseOp(alloc); - return mlir::success(); + return processReturnLikeOp(ret, newArg, rewriter); + } + +private: + mlir::Value newArg; +}; + +class GPUReturnOpConversion + : public mlir::OpRewritePattern<mlir::gpu::ReturnOp> { +public: + using OpRewritePattern::OpRewritePattern; + GPUReturnOpConversion(mlir::MLIRContext *context, mlir::Value newArg) + : OpRewritePattern(context), newArg{newArg} {} + llvm::LogicalResult + matchAndRewrite(mlir::gpu::ReturnOp ret, + mlir::PatternRewriter &rewriter) const override { + return processReturnLikeOp(ret, newArg, rewriter); } private: @@ -373,6 +395,9 @@ public: patterns.insert<ReturnOpConversion>(context, newArg); target.addDynamicallyLegalOp<mlir::func::ReturnOp>( [](mlir::func::ReturnOp ret) { return ret.getOperands().empty(); }); + patterns.insert<GPUReturnOpConversion>(context, newArg); + target.addDynamicallyLegalOp<mlir::gpu::ReturnOp>( + [](mlir::gpu::ReturnOp ret) { return ret.getOperands().empty(); }); assert(func.getFunctionType() == getNewFunctionType(funcTy, shouldBoxResult)); } else { @@ -460,17 +485,10 @@ public: const bool shouldBoxResult = this->passResultAsBox.getValue(); mlir::TypeSwitch<mlir::Operation *, void>(op) - .Case<mlir::func::FuncOp, fir::GlobalOp>([&](auto op) { - runOnSpecificOperation(op, shouldBoxResult, patterns, target); - }) - .Case<mlir::gpu::GPUModuleOp>([&](auto op) { - auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(*op); - for (auto funcOp : gpuMod.template getOps<mlir::func::FuncOp>()) - runOnSpecificOperation(funcOp, shouldBoxResult, patterns, target); - for (auto gpuFuncOp : gpuMod.template getOps<mlir::gpu::GPUFuncOp>()) - runOnSpecificOperation(gpuFuncOp, shouldBoxResult, patterns, - target); - }); + .Case<mlir::func::FuncOp, fir::GlobalOp, mlir::gpu::GPUFuncOp>( + [&](auto op) { + runOnSpecificOperation(op, shouldBoxResult, patterns, target); + }); // 
Convert the calls and, if needed, the ReturnOp in the function body. target.addLegalDialect<fir::FIROpsDialect, mlir::arith::ArithDialect, diff --git a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp index f1e70875de0b..e6fc2ed992e3 100644 --- a/flang/lib/Optimizer/Transforms/AddAliasTags.cpp +++ b/flang/lib/Optimizer/Transforms/AddAliasTags.cpp @@ -227,7 +227,7 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, source.kind == fir::AliasAnalysis::SourceKind::Argument) { LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to dummy argument at " << *op << "\n"); - std::string name = getFuncArgName(source.origin.u.get<mlir::Value>()); + std::string name = getFuncArgName(llvm::cast<mlir::Value>(source.origin.u)); if (!name.empty()) tag = state.getFuncTreeWithScope(func, scopeOp) .dummyArgDataTree.getTag(name); @@ -240,7 +240,7 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, } else if (enableGlobals && source.kind == fir::AliasAnalysis::SourceKind::Global && !source.isBoxData()) { - mlir::SymbolRefAttr glbl = source.origin.u.get<mlir::SymbolRefAttr>(); + mlir::SymbolRefAttr glbl = llvm::cast<mlir::SymbolRefAttr>(source.origin.u); const char *name = glbl.getRootReference().data(); LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to global " << name << " at " << *op << "\n"); @@ -250,8 +250,7 @@ void AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, } else if (enableDirect && source.kind == fir::AliasAnalysis::SourceKind::Global && source.isBoxData()) { - if (source.origin.u.is<mlir::SymbolRefAttr>()) { - mlir::SymbolRefAttr glbl = source.origin.u.get<mlir::SymbolRefAttr>(); + if (auto glbl = llvm::dyn_cast<mlir::SymbolRefAttr>(source.origin.u)) { const char *name = glbl.getRootReference().data(); LLVM_DEBUG(llvm::dbgs().indent(2) << "Found reference to direct " << name << " at " << *op << "\n"); @@ -269,7 +268,7 @@ void 
AddAliasTagsPass::runOnAliasInterface(fir::FirAliasTagOpInterface op, source.kind == fir::AliasAnalysis::SourceKind::Allocate) { std::optional<llvm::StringRef> name; mlir::Operation *sourceOp = - source.origin.u.get<mlir::Value>().getDefiningOp(); + llvm::cast<mlir::Value>(source.origin.u).getDefiningOp(); if (auto alloc = mlir::dyn_cast_or_null<fir::AllocaOp>(sourceOp)) name = alloc.getUniqName(); else if (auto alloc = mlir::dyn_cast_or_null<fir::AllocMemOp>(sourceOp)) diff --git a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp index fd58375da618..fab1f0299ede 100644 --- a/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp +++ b/flang/lib/Optimizer/Transforms/AlgebraicSimplification.cpp @@ -39,8 +39,7 @@ struct AlgebraicSimplification void AlgebraicSimplification::runOnOperation() { RewritePatternSet patterns(&getContext()); populateMathAlgebraicSimplificationPatterns(patterns); - (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns), - config); + (void)applyPatternsGreedily(getOperation(), std::move(patterns), config); } std::unique_ptr<mlir::Pass> fir::createAlgebraicSimplificationPass() { diff --git a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp index 2c9c73e8a539..eb59045a5fde 100644 --- a/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/AssumedRankOpConversion.cpp @@ -154,7 +154,7 @@ public: mlir::GreedyRewriteConfig config; config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; - (void)applyPatternsAndFoldGreedily(mod, std::move(patterns), config); + (void)applyPatternsGreedily(mod, std::move(patterns), config); } }; } // namespace diff --git a/flang/lib/Optimizer/Transforms/CMakeLists.txt b/flang/lib/Optimizer/Transforms/CMakeLists.txt index 9eafa4ec234b..d20d3bc4108c 100644 --- 
a/flang/lib/Optimizer/Transforms/CMakeLists.txt +++ b/flang/lib/Optimizer/Transforms/CMakeLists.txt @@ -9,7 +9,6 @@ add_flang_library(FIRTransforms CompilerGeneratedNames.cpp ConstantArgumentGlobalisation.cpp ControlFlowConverter.cpp - CUFCommon.cpp CUFAddConstructor.cpp CUFDeviceGlobal.cpp CUFOpConversion.cpp diff --git a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp index 73a46843f032..97551595db03 100644 --- a/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp +++ b/flang/lib/Optimizer/Transforms/CUFAddConstructor.cpp @@ -7,6 +7,7 @@ //===----------------------------------------------------------------------===// #include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/Builder/Todo.h" @@ -19,7 +20,6 @@ #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Dialect/FIRType.h" #include "flang/Optimizer/Support/DataLayout.h" -#include "flang/Optimizer/Transforms/CUFCommon.h" #include "flang/Runtime/CUDA/registration.h" #include "flang/Runtime/entry-names.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" @@ -106,7 +106,8 @@ struct CUFAddConstructor mlir::func::FuncOp func; switch (attr.getValue()) { - case cuf::DataAttribute::Device: { + case cuf::DataAttribute::Device: + case cuf::DataAttribute::Constant: { func = fir::runtime::getRuntimeFunc<mkRTKey(CUFRegisterVariable)>( loc, builder); auto fTy = func.getFunctionType(); diff --git a/flang/lib/Optimizer/Transforms/CUFCommon.cpp b/flang/lib/Optimizer/Transforms/CUFCommon.cpp deleted file mode 100644 index 5b7631bbacb5..000000000000 --- a/flang/lib/Optimizer/Transforms/CUFCommon.cpp +++ /dev/null @@ -1,45 +0,0 @@ -//===-- CUFCommon.cpp - Shared functions between passes ---------*- C++ -*-===// -// -// Part of the LLVM Project, under the Apache License v2.0 with LLVM 
Exceptions. -// See https://llvm.org/LICENSE.txt for license information. -// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception -// -//===----------------------------------------------------------------------===// - -#include "flang/Optimizer/Transforms/CUFCommon.h" -#include "flang/Optimizer/Dialect/CUF/CUFOps.h" -#include "mlir/Dialect/Func/IR/FuncOps.h" -#include "mlir/Dialect/LLVMIR/NVVMDialect.h" - -/// Retrieve or create the CUDA Fortran GPU module in the give in \p mod. -mlir::gpu::GPUModuleOp cuf::getOrCreateGPUModule(mlir::ModuleOp mod, - mlir::SymbolTable &symTab) { - if (auto gpuMod = symTab.lookup<mlir::gpu::GPUModuleOp>(cudaDeviceModuleName)) - return gpuMod; - - auto *ctx = mod.getContext(); - mod->setAttr(mlir::gpu::GPUDialect::getContainerModuleAttrName(), - mlir::UnitAttr::get(ctx)); - - mlir::OpBuilder builder(ctx); - auto gpuMod = builder.create<mlir::gpu::GPUModuleOp>(mod.getLoc(), - cudaDeviceModuleName); - mlir::Block::iterator insertPt(mod.getBodyRegion().front().end()); - symTab.insert(gpuMod, insertPt); - return gpuMod; -} - -bool cuf::isInCUDADeviceContext(mlir::Operation *op) { - if (!op) - return false; - if (op->getParentOfType<cuf::KernelOp>() || - op->getParentOfType<mlir::gpu::GPUFuncOp>()) - return true; - if (auto funcOp = op->getParentOfType<mlir::func::FuncOp>()) { - if (auto cudaProcAttr = funcOp->getAttrOfType<cuf::ProcAttributeAttr>( - cuf::getProcAttrName())) { - return cudaProcAttr.getValue() != cuf::ProcAttribute::Host; - } - } - return false; -} diff --git a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp index 714b0b291be1..2e6c272fa908 100644 --- a/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp +++ b/flang/lib/Optimizer/Transforms/CUFDeviceGlobal.cpp @@ -7,17 +7,19 @@ //===----------------------------------------------------------------------===// #include "flang/Common/Fortran.h" +#include "flang/Optimizer/Builder/CUFCommon.h" #include 
"flang/Optimizer/Dialect/CUF/CUFOps.h" #include "flang/Optimizer/Dialect/FIRDialect.h" #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" -#include "flang/Optimizer/Transforms/CUFCommon.h" +#include "flang/Optimizer/Support/InternalNames.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/allocatable.h" #include "mlir/Dialect/LLVMIR/NVVMDialect.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" #include "mlir/Transforms/DialectConversion.h" +#include "llvm/ADT/DenseSet.h" namespace fir { #define GEN_PASS_DEF_CUFDEVICEGLOBAL @@ -27,36 +29,53 @@ namespace fir { namespace { static void processAddrOfOp(fir::AddrOfOp addrOfOp, - mlir::SymbolTable &symbolTable, bool onlyConstant) { + mlir::SymbolTable &symbolTable, + llvm::DenseSet<fir::GlobalOp> &candidates, + bool recurseInGlobal) { if (auto globalOp = symbolTable.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { - bool isCandidate{(onlyConstant ? globalOp.getConstant() : true) && - !globalOp.getDataAttr()}; - if (isCandidate) - globalOp.setDataAttrAttr(cuf::DataAttributeAttr::get( - addrOfOp.getContext(), globalOp.getConstant() - ? cuf::DataAttribute::Constant - : cuf::DataAttribute::Device)); + // TO DO: limit candidates to non-scalars. Scalars appear to have been + // folded in already. 
+ if (globalOp.getConstant()) { + if (recurseInGlobal) + globalOp.walk([&](fir::AddrOfOp op) { + processAddrOfOp(op, symbolTable, candidates, recurseInGlobal); + }); + candidates.insert(globalOp); + } + } +} + +static void processEmboxOp(fir::EmboxOp emboxOp, mlir::SymbolTable &symbolTable, + llvm::DenseSet<fir::GlobalOp> &candidates) { + if (auto recTy = mlir::dyn_cast<fir::RecordType>( + fir::unwrapRefType(emboxOp.getMemref().getType()))) { + if (auto globalOp = symbolTable.lookup<fir::GlobalOp>( + fir::NameUniquer::getTypeDescriptorName(recTy.getName()))) { + if (!candidates.contains(globalOp)) { + globalOp.walk([&](fir::AddrOfOp op) { + processAddrOfOp(op, symbolTable, candidates, + /*recurseInGlobal=*/true); + }); + candidates.insert(globalOp); + } + } } } -static void prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp, - mlir::SymbolTable &symbolTable, - bool onlyConstant = true) { +static void +prepareImplicitDeviceGlobals(mlir::func::FuncOp funcOp, + mlir::SymbolTable &symbolTable, + llvm::DenseSet<fir::GlobalOp> &candidates) { auto cudaProcAttr{ funcOp->getAttrOfType<cuf::ProcAttributeAttr>(cuf::getProcAttrName())}; - if (!cudaProcAttr || cudaProcAttr.getValue() == cuf::ProcAttribute::Host) { - // Look for globlas in CUF KERNEL DO operations. 
- for (auto cufKernelOp : funcOp.getBody().getOps<cuf::KernelOp>()) { - cufKernelOp.walk([&](fir::AddrOfOp addrOfOp) { - processAddrOfOp(addrOfOp, symbolTable, onlyConstant); - }); - } - return; + if (cudaProcAttr && cudaProcAttr.getValue() != cuf::ProcAttribute::Host) { + funcOp.walk([&](fir::AddrOfOp op) { + processAddrOfOp(op, symbolTable, candidates, /*recurseInGlobal=*/false); + }); + funcOp.walk( + [&](fir::EmboxOp op) { processEmboxOp(op, symbolTable, candidates); }); } - funcOp.walk([&](fir::AddrOfOp addrOfOp) { - processAddrOfOp(addrOfOp, symbolTable, onlyConstant); - }); } class CUFDeviceGlobal : public fir::impl::CUFDeviceGlobalBase<CUFDeviceGlobal> { @@ -67,11 +86,18 @@ public: if (!mod) return signalPassFailure(); + llvm::DenseSet<fir::GlobalOp> candidates; mlir::SymbolTable symTable(mod); mod.walk([&](mlir::func::FuncOp funcOp) { - prepareImplicitDeviceGlobals(funcOp, symTable); + prepareImplicitDeviceGlobals(funcOp, symTable, candidates); return mlir::WalkResult::advance(); }); + mod.walk([&](cuf::KernelOp kernelOp) { + kernelOp.walk([&](fir::AddrOfOp addrOfOp) { + processAddrOfOp(addrOfOp, symTable, candidates, + /*recurseInGlobal=*/false); + }); + }); // Copying the device global variable into the gpu module mlir::SymbolTable parentSymTable(mod); @@ -80,22 +106,15 @@ public: return signalPassFailure(); mlir::SymbolTable gpuSymTable(gpuMod); for (auto globalOp : mod.getOps<fir::GlobalOp>()) { - auto attr = globalOp.getDataAttrAttr(); - if (!attr) - continue; - switch (attr.getValue()) { - case cuf::DataAttribute::Device: - case cuf::DataAttribute::Constant: - case cuf::DataAttribute::Managed: { - auto globalName{globalOp.getSymbol().getValue()}; - if (gpuSymTable.lookup<fir::GlobalOp>(globalName)) { - break; - } - gpuSymTable.insert(globalOp->clone()); - } break; - default: + if (cuf::isRegisteredDeviceGlobal(globalOp)) + candidates.insert(globalOp); + } + for (auto globalOp : candidates) { + auto globalName{globalOp.getSymbol().getValue()}; + if 
(gpuSymTable.lookup<fir::GlobalOp>(globalName)) { break; } + gpuSymTable.insert(globalOp->clone()); } } }; diff --git a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp index c64f35542a6e..60aa401e1cc8 100644 --- a/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFGPUToLLVMConversion.cpp @@ -42,6 +42,8 @@ static mlir::Value createKernelArgArray(mlir::Location loc, auto structTy = mlir::LLVM::LLVMStructType::getLiteral(ctx, structTypes); auto ptrTy = mlir::LLVM::LLVMPointerType::get(rewriter.getContext()); mlir::Type i32Ty = rewriter.getI32Type(); + auto zero = rewriter.create<mlir::LLVM::ConstantOp>( + loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 0)); auto one = rewriter.create<mlir::LLVM::ConstantOp>( loc, i32Ty, rewriter.getIntegerAttr(i32Ty, 1)); mlir::Value argStruct = @@ -55,10 +57,11 @@ static mlir::Value createKernelArgArray(mlir::Location loc, auto indice = rewriter.create<mlir::LLVM::ConstantOp>( loc, i32Ty, rewriter.getIntegerAttr(i32Ty, i)); mlir::Value structMember = rewriter.create<LLVM::GEPOp>( - loc, ptrTy, structTy, argStruct, mlir::ArrayRef<mlir::Value>({indice})); + loc, ptrTy, structTy, argStruct, + mlir::ArrayRef<mlir::Value>({zero, indice})); rewriter.create<LLVM::StoreOp>(loc, arg, structMember); mlir::Value arrayMember = rewriter.create<LLVM::GEPOp>( - loc, ptrTy, structTy, argArray, mlir::ArrayRef<mlir::Value>({indice})); + loc, ptrTy, ptrTy, argArray, mlir::ArrayRef<mlir::Value>({indice})); rewriter.create<LLVM::StoreOp>(loc, structMember, arrayMember); } return argArray; diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 7f6843d66d39..8c525fc6daff 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -8,6 +8,7 @@ #include "flang/Optimizer/Transforms/CUFOpConversion.h" #include 
"flang/Common/Fortran.h" +#include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/CUF/CUFOps.h" @@ -15,7 +16,6 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" -#include "flang/Optimizer/Transforms/CUFCommon.h" #include "flang/Runtime/CUDA/allocatable.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/CUDA/descriptor.h" @@ -81,15 +81,6 @@ static bool hasDoubleDescriptors(OpTy op) { return false; } -bool isDeviceGlobal(fir::GlobalOp op) { - auto attr = op.getDataAttr(); - if (attr && (*attr == cuf::DataAttribute::Device || - *attr == cuf::DataAttribute::Managed || - *attr == cuf::DataAttribute::Constant)) - return true; - return false; -} - static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { @@ -351,7 +342,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> { // Convert descriptor allocations to function call. 
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType()); mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder); + fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); @@ -388,7 +379,7 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { if (auto global = symTab.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { - if (isDeviceGlobal(global)) { + if (cuf::isRegisteredDeviceGlobal(global)) { rewriter.setInsertionPointAfter(addrOfOp); auto mod = op->getParentOfType<mlir::ModuleOp>(); fir::FirOpBuilder builder(rewriter, mod); @@ -460,7 +451,7 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> { // Convert cuf.free on descriptors. mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder); + fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); @@ -797,6 +788,38 @@ private: const mlir::SymbolTable &symTab; }; +struct CUFSyncDescriptorOpConversion + : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cuf::SyncDescriptorOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName()); + if (!globalOp) + return mlir::failure(); + + auto hostAddr = builder.create<fir::AddrOfOp>( + loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName()); + mlir::func::FuncOp 
callee = + fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc, + builder); + auto fTy = callee.getFunctionType(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, hostAddr, sourceFile, sourceLine)}; + builder.create<fir::CallOp>(loc, callee, args); + op.erase(); + return mlir::success(); + } +}; + class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { public: void runOnOperation() override { @@ -833,7 +856,7 @@ public: addrOfOp.getSymbol().getRootReference().getValue())) { if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType()))) return true; - if (isDeviceGlobal(global)) + if (cuf::isRegisteredDeviceGlobal(global)) return false; } } @@ -857,7 +880,8 @@ void cuf::populateCUFToFIRConversionPatterns( const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion, - CUFFreeOpConversion>(patterns.getContext()); + CUFFreeOpConversion, CUFSyncDescriptorOpConversion>( + patterns.getContext()); patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, &dl, &converter); patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab); diff --git a/flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp b/flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp index 7f2cc41275e5..f92c60908b14 100644 --- a/flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp +++ b/flang/lib/Optimizer/Transforms/CompilerGeneratedNames.cpp @@ -11,6 +11,7 @@ #include "flang/Optimizer/Dialect/FIROpsSupport.h" #include "flang/Optimizer/Support/InternalNames.h" #include "flang/Optimizer/Transforms/Passes.h" +#include "mlir/Dialect/GPU/IR/GPUDialect.h" 
#include "mlir/IR/Attributes.h" #include "mlir/IR/SymbolTable.h" #include "mlir/Pass/Pass.h" @@ -42,24 +43,31 @@ void CompilerGeneratedNamesConversionPass::runOnOperation() { auto *context = &getContext(); llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings; - for (auto &funcOrGlobal : op->getRegion(0).front()) { - if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) || - llvm::isa<fir::GlobalOp>(funcOrGlobal)) { - auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>( - mlir::SymbolTable::getSymbolAttrName()); - auto deconstructedName = fir::NameUniquer::deconstruct(symName); - if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED && - !fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { - std::string newName = - fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str()); - if (newName != symName) { - auto newAttr = mlir::StringAttr::get(context, newName); - mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr); - auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr); - remappings.try_emplace(symName, newSymRef); - } + + auto processOp = [&](mlir::Operation &op) { + auto symName = op.getAttrOfType<mlir::StringAttr>( + mlir::SymbolTable::getSymbolAttrName()); + auto deconstructedName = fir::NameUniquer::deconstruct(symName); + if (deconstructedName.first != fir::NameUniquer::NameKind::NOT_UNIQUED && + !fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { + std::string newName = + fir::NameUniquer::replaceSpecialSymbols(symName.getValue().str()); + if (newName != symName) { + auto newAttr = mlir::StringAttr::get(context, newName); + mlir::SymbolTable::setSymbolName(&op, newAttr); + auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr); + remappings.try_emplace(symName, newSymRef); } } + }; + for (auto &op : op->getRegion(0).front()) { + if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op)) + processOp(op); + else if (auto gpuMod = 
mlir::dyn_cast<mlir::gpu::GPUModuleOp>(&op)) + for (auto &op : gpuMod->getRegion(0).front()) + if (llvm::isa<mlir::func::FuncOp>(op) || llvm::isa<fir::GlobalOp>(op) || + llvm::isa<mlir::gpu::GPUFuncOp>(op)) + processOp(op); } if (remappings.empty()) diff --git a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp index eef6f047fc1b..562f3058f20f 100644 --- a/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp +++ b/flang/lib/Optimizer/Transforms/ConstantArgumentGlobalisation.cpp @@ -173,8 +173,8 @@ public: config.strictMode = mlir::GreedyRewriteStrictness::ExistingOps; patterns.insert<CallOpRewriter>(context, *di); - if (mlir::failed(mlir::applyPatternsAndFoldGreedily( - mod, std::move(patterns), config))) { + if (mlir::failed( + mlir::applyPatternsGreedily(mod, std::move(patterns), config))) { mlir::emitError(mod.getLoc(), "error in constant globalisation optimization\n"); signalPassFailure(); diff --git a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp index 3b79d6d311b7..b09bbf6106db 100644 --- a/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp +++ b/flang/lib/Optimizer/Transforms/ControlFlowConverter.cpp @@ -332,8 +332,6 @@ class CfgConversion : public fir::impl::CFGConversionBase<CfgConversion> { public: using CFGConversionBase<CfgConversion>::CFGConversionBase; - CfgConversion(bool setNSW) { this->setNSW = setNSW; } - void runOnOperation() override { auto *context = &this->getContext(); mlir::RewritePatternSet patterns(context); @@ -365,7 +363,3 @@ void fir::populateCfgConversionRewrites(mlir::RewritePatternSet &patterns, patterns.insert<CfgLoopConv, CfgIfConv, CfgIterWhileConv>( patterns.getContext(), forceLoopToExecuteOnce, setNSW); } - -std::unique_ptr<mlir::Pass> fir::createCFGConversionPassWithNSW() { - return std::make_unique<CfgConversion>(true); -} diff --git 
a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp index cc99698ead33..8ae3d313d881 100644 --- a/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp +++ b/flang/lib/Optimizer/Transforms/DebugTypeGenerator.cpp @@ -325,7 +325,7 @@ static bool canCacheThisType(mlir::LLVM::DICompositeTypeAttr comTy) { std::pair<std::uint64_t, unsigned short> DebugTypeGenerator::getFieldSizeAndAlign(mlir::Type fieldTy) { mlir::Type llvmTy; - if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(fieldTy)) + if (auto boxTy = mlir::dyn_cast_if_present<fir::BaseBoxType>(fieldTy)) llvmTy = llvmTypeConverter.convertBoxTypeAsStruct(boxTy, getBoxRank(boxTy)); else llvmTy = llvmTypeConverter.convertType(fieldTy); @@ -371,7 +371,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertRecordType( std::optional<llvm::ArrayRef<int64_t>> lowerBounds = fir::getComponentLowerBoundsIfNonDefault(Ty, fieldName, module, symbolTable); - auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(fieldTy); + auto seqTy = mlir::dyn_cast_if_present<fir::SequenceType>(fieldTy); // For members of the derived types, the information about the shift in // lower bounds is not part of the declOp but has to be extracted from the @@ -622,10 +622,10 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertPointerLikeType( // Arrays and character need different treatment because DWARF have special // constructs for them to get the location from the descriptor. Rest of // types are handled like pointer to underlying type. 
- if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(elTy)) + if (auto seqTy = mlir::dyn_cast_if_present<fir::SequenceType>(elTy)) return convertBoxedSequenceType(seqTy, fileAttr, scope, declOp, genAllocated, genAssociated); - if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(elTy)) + if (auto charTy = mlir::dyn_cast_if_present<fir::CharacterType>(elTy)) return convertCharacterType(charTy, fileAttr, scope, declOp, /*hasDescriptor=*/true); @@ -638,7 +638,7 @@ mlir::LLVM::DITypeAttr DebugTypeGenerator::convertPointerLikeType( return mlir::LLVM::DIDerivedTypeAttr::get( context, llvm::dwarf::DW_TAG_pointer_type, - mlir::StringAttr::get(context, ""), elTyAttr, ptrSize, + mlir::StringAttr::get(context, ""), elTyAttr, /*sizeInBits=*/ptrSize * 8, /*alignInBits=*/0, /*offset=*/0, /*optional<address space>=*/std::nullopt, /*extra data=*/nullptr); } @@ -654,22 +654,22 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr, } else if (mlir::isa<mlir::FloatType>(Ty)) { return genBasicType(context, mlir::StringAttr::get(context, "real"), Ty.getIntOrFloatBitWidth(), llvm::dwarf::DW_ATE_float); - } else if (auto logTy = mlir::dyn_cast_or_null<fir::LogicalType>(Ty)) { + } else if (auto logTy = mlir::dyn_cast_if_present<fir::LogicalType>(Ty)) { return genBasicType(context, mlir::StringAttr::get(context, logTy.getMnemonic()), kindMapping.getLogicalBitsize(logTy.getFKind()), llvm::dwarf::DW_ATE_boolean); - } else if (auto cplxTy = mlir::dyn_cast_or_null<mlir::ComplexType>(Ty)) { + } else if (auto cplxTy = mlir::dyn_cast_if_present<mlir::ComplexType>(Ty)) { auto floatTy = mlir::cast<mlir::FloatType>(cplxTy.getElementType()); unsigned bitWidth = floatTy.getWidth(); return genBasicType(context, mlir::StringAttr::get(context, "complex"), bitWidth * 2, llvm::dwarf::DW_ATE_complex_float); - } else if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(Ty)) { + } else if (auto seqTy = mlir::dyn_cast_if_present<fir::SequenceType>(Ty)) { return 
convertSequenceType(seqTy, fileAttr, scope, declOp); - } else if (auto charTy = mlir::dyn_cast_or_null<fir::CharacterType>(Ty)) { + } else if (auto charTy = mlir::dyn_cast_if_present<fir::CharacterType>(Ty)) { return convertCharacterType(charTy, fileAttr, scope, declOp, /*hasDescriptor=*/false); - } else if (auto recTy = mlir::dyn_cast_or_null<fir::RecordType>(Ty)) { + } else if (auto recTy = mlir::dyn_cast_if_present<fir::RecordType>(Ty)) { return convertRecordType(recTy, fileAttr, scope, declOp); } else if (auto tupleTy = mlir::dyn_cast_if_present<mlir::TupleType>(Ty)) { return convertTupleType(tupleTy, fileAttr, scope, declOp); @@ -678,22 +678,22 @@ DebugTypeGenerator::convertType(mlir::Type Ty, mlir::LLVM::DIFileAttr fileAttr, return convertPointerLikeType(elTy, fileAttr, scope, declOp, /*genAllocated=*/false, /*genAssociated=*/false); - } else if (auto vecTy = mlir::dyn_cast_or_null<fir::VectorType>(Ty)) { + } else if (auto vecTy = mlir::dyn_cast_if_present<fir::VectorType>(Ty)) { return convertVectorType(vecTy, fileAttr, scope, declOp); } else if (mlir::isa<mlir::IndexType>(Ty)) { return genBasicType(context, mlir::StringAttr::get(context, "integer"), llvmTypeConverter.getIndexTypeBitwidth(), llvm::dwarf::DW_ATE_signed); - } else if (auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(Ty)) { + } else if (auto boxTy = mlir::dyn_cast_if_present<fir::BaseBoxType>(Ty)) { auto elTy = boxTy.getEleTy(); - if (auto seqTy = mlir::dyn_cast_or_null<fir::SequenceType>(elTy)) + if (auto seqTy = mlir::dyn_cast_if_present<fir::SequenceType>(elTy)) return convertBoxedSequenceType(seqTy, fileAttr, scope, declOp, false, false); - if (auto heapTy = mlir::dyn_cast_or_null<fir::HeapType>(elTy)) + if (auto heapTy = mlir::dyn_cast_if_present<fir::HeapType>(elTy)) return convertPointerLikeType(heapTy.getElementType(), fileAttr, scope, declOp, /*genAllocated=*/true, /*genAssociated=*/false); - if (auto ptrTy = mlir::dyn_cast_or_null<fir::PointerType>(elTy)) + if (auto ptrTy = 
mlir::dyn_cast_if_present<fir::PointerType>(elTy)) return convertPointerLikeType(ptrTy.getElementType(), fileAttr, scope, declOp, /*genAllocated=*/false, /*genAssociated=*/true); diff --git a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp index cfd90ff72379..4f6974ee5269 100644 --- a/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp +++ b/flang/lib/Optimizer/Transforms/ExternalNameConversion.cpp @@ -60,23 +60,30 @@ void ExternalNameConversionPass::runOnOperation() { llvm::DenseMap<mlir::StringAttr, mlir::FlatSymbolRefAttr> remappings; + auto processFctOrGlobal = [&](mlir::Operation &funcOrGlobal) { + auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>( + mlir::SymbolTable::getSymbolAttrName()); + auto deconstructedName = fir::NameUniquer::deconstruct(symName); + if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { + auto newName = mangleExternalName(deconstructedName, appendUnderscoreOpt); + auto newAttr = mlir::StringAttr::get(context, newName); + mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr); + auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr); + remappings.try_emplace(symName, newSymRef); + if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal)) + funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName); + } + }; + auto renameFuncOrGlobalInModule = [&](mlir::Operation *module) { - for (auto &funcOrGlobal : module->getRegion(0).front()) { - if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal) || - llvm::isa<fir::GlobalOp>(funcOrGlobal)) { - auto symName = funcOrGlobal.getAttrOfType<mlir::StringAttr>( - mlir::SymbolTable::getSymbolAttrName()); - auto deconstructedName = fir::NameUniquer::deconstruct(symName); - if (fir::NameUniquer::isExternalFacingUniquedName(deconstructedName)) { - auto newName = - mangleExternalName(deconstructedName, appendUnderscoreOpt); - auto newAttr = mlir::StringAttr::get(context, newName); - 
mlir::SymbolTable::setSymbolName(&funcOrGlobal, newAttr); - auto newSymRef = mlir::FlatSymbolRefAttr::get(newAttr); - remappings.try_emplace(symName, newSymRef); - if (llvm::isa<mlir::func::FuncOp>(funcOrGlobal)) - funcOrGlobal.setAttr(fir::getInternalFuncNameAttrName(), symName); - } + for (auto &op : module->getRegion(0).front()) { + if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp>(op)) { + processFctOrGlobal(op); + } else if (auto gpuMod = mlir::dyn_cast<mlir::gpu::GPUModuleOp>(op)) { + for (auto &gpuOp : gpuMod.getBodyRegion().front()) + if (mlir::isa<mlir::func::FuncOp, fir::GlobalOp, + mlir::gpu::GPUFuncOp>(gpuOp)) + processFctOrGlobal(gpuOp); } } }; @@ -85,11 +92,6 @@ void ExternalNameConversionPass::runOnOperation() { // globals. renameFuncOrGlobalInModule(op); - // Do the same in GPU modules. - if (auto mod = mlir::dyn_cast_or_null<mlir::ModuleOp>(*op)) - for (auto gpuMod : mod.getOps<mlir::gpu::GPUModuleOp>()) - renameFuncOrGlobalInModule(gpuMod); - if (remappings.empty()) return; @@ -97,11 +99,18 @@ void ExternalNameConversionPass::runOnOperation() { op.walk([&remappings](mlir::Operation *nestedOp) { llvm::SmallVector<std::pair<mlir::StringAttr, mlir::SymbolRefAttr>> updates; for (const mlir::NamedAttribute &attr : nestedOp->getAttrDictionary()) - if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) - if (auto remap = remappings.find(symRef.getRootReference()); - remap != remappings.end()) + if (auto symRef = llvm::dyn_cast<mlir::SymbolRefAttr>(attr.getValue())) { + if (auto remap = remappings.find(symRef.getLeafReference()); + remap != remappings.end()) { + mlir::SymbolRefAttr symAttr = mlir::FlatSymbolRefAttr(remap->second); + if (mlir::isa<mlir::gpu::LaunchFuncOp>(nestedOp)) + symAttr = mlir::SymbolRefAttr::get( + symRef.getRootReference(), + {mlir::FlatSymbolRefAttr(remap->second)}); updates.emplace_back(std::pair<mlir::StringAttr, mlir::SymbolRefAttr>{ - attr.getName(), mlir::SymbolRefAttr(remap->second)}); + attr.getName(), 
symAttr}); + } + } for (auto update : updates) nestedOp->setAttr(update.first, update.second); }); diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp index adc39861840a..b534ec160ce2 100644 --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -145,11 +145,45 @@ struct ArgsUsageInLoop { }; } // namespace -static fir::SequenceType getAsSequenceType(mlir::Value *v) { - mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType())); +static fir::SequenceType getAsSequenceType(mlir::Value v) { + mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType())); return mlir::dyn_cast<fir::SequenceType>(argTy); } +/// Return the rank and the element size (in bytes) of the given +/// value \p v. If it is not an array or the element type is not +/// supported, then return <0, 0>. Only trivial data types +/// are currently supported. +/// When \p isArgument is true, \p v is assumed to be a function +/// argument. If \p v's type does not look like a type of an assumed +/// shape array, then the function returns <0, 0>. +/// When \p isArgument is false, array types with known innermost +/// dimension are allowed to proceed. 
+static std::pair<unsigned, size_t> +getRankAndElementSize(const fir::KindMapping &kindMap, + const mlir::DataLayout &dl, mlir::Value v, + bool isArgument = false) { + if (auto seqTy = getAsSequenceType(v)) { + unsigned rank = seqTy.getDimension(); + if (rank > 0 && + (!isArgument || + seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) { + size_t typeSize = 0; + mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType()); + if (fir::isa_trivial(elementType)) { + auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash( + v.getLoc(), elementType, dl, kindMap); + typeSize = llvm::alignTo(eleSize, eleAlign); + } + if (typeSize) + return {rank, typeSize}; + } + } + + LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n'); + return {0, 0}; +} + /// if a value comes from a fir.declare, follow it to the original source, /// otherwise return the value static mlir::Value unwrapFirDeclare(mlir::Value val) { @@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) { return val; } +/// Return true, if \p rebox operation keeps the input array +/// continuous in the innermost dimension, if it is initially continuous +/// in the innermost dimension. +static bool reboxPreservesContinuity(fir::ReboxOp rebox) { + // If slicing is not involved, then the rebox does not affect + // the continuity of the array. + auto sliceArg = rebox.getSlice(); + if (!sliceArg) + return true; + + // A slice with step=1 in the innermost dimension preserves + // the continuity of the array in the innermost dimension. 
+ if (auto sliceOp = + mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) { + if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) { + auto triples = sliceOp.getTriples(); + if (triples.size() > 2) + if (auto innermostStep = fir::getIntIfConstant(triples[2])) + if (*innermostStep == 1) + return true; + } + + LLVM_DEBUG(llvm::dbgs() + << "REBOX with slicing may produce non-contiguous array: " + << sliceOp << '\n' + << rebox << '\n'); + return false; + } + + LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n' + << rebox << '\n'); + return false; +} + /// if a value comes from a fir.rebox, follow the rebox to the original source, /// of the value, otherwise return the value static mlir::Value unwrapReboxOp(mlir::Value val) { - // don't support reboxes of reboxes - if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) + while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) { + if (!reboxPreservesContinuity(rebox)) + break; val = rebox.getBox(); + } return val; } @@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() { continue; } - if (auto seqTy = getAsSequenceType(&arg)) { - unsigned rank = seqTy.getDimension(); - if (rank > 0 && - seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) { - size_t typeSize = 0; - mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType()); - if (mlir::isa<mlir::FloatType>(elementType) || - mlir::isa<mlir::IntegerType>(elementType) || - mlir::isa<mlir::ComplexType>(elementType)) { - auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash( - arg.getLoc(), elementType, *dl, kindMap); - typeSize = llvm::alignTo(eleSize, eleAlign); - } - if (typeSize) - argsOfInterest.push_back({arg, typeSize, rank, {}}); - else - LLVM_DEBUG(llvm::dbgs() << "Type not supported\n"); - } - } + auto [rank, typeSize] = + getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true); + if (rank != 0 && typeSize != 0) + argsOfInterest.push_back({arg, typeSize, rank, 
{}}); } if (argsOfInterest.empty()) { @@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() { if (arrayCoor.getSlice()) argsInLoop.cannotTransform.insert(a.arg); + // We need to compute the rank and element size + // based on the operand, not the original argument, + // because array slicing may affect it. + std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg); + if (a.rank == 0 || a.size == 0) + argsInLoop.cannotTransform.insert(a.arg); + if (argsInLoop.cannotTransform.contains(a.arg)) { // Remove any previously recorded usage, if any. argsInLoop.usageInfo.erase(a.arg); @@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() { mlir::Location loc = builder.getUnknownLoc(); mlir::IndexType idxTy = builder.getIndexType(); - LLVM_DEBUG(llvm::dbgs() << "Module Before transformation:"); - LLVM_DEBUG(module->dump()); + LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n"); + LLVM_DEBUG(func->dump()); LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size() << "\n"); @@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() { } } - LLVM_DEBUG(llvm::dbgs() << "After transform:\n"); - LLVM_DEBUG(module->dump()); + LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n"); + LLVM_DEBUG(func->dump()); LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); } diff --git a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp index d3567f453fce..fa6a7b23624e 100644 --- a/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp +++ b/flang/lib/Optimizer/Transforms/SimplifyIntrinsics.cpp @@ -24,6 +24,7 @@ #include "flang/Common/Fortran.h" #include "flang/Optimizer/Builder/BoxValue.h" +#include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/FIRBuilder.h" #include "flang/Optimizer/Builder/LowLevelIntrinsics.h" #include "flang/Optimizer/Builder/Todo.h" @@ -31,7 +32,6 @@ #include "flang/Optimizer/Dialect/FIRType.h" #include 
"flang/Optimizer/Dialect/Support/FIRContext.h" #include "flang/Optimizer/HLFIR/HLFIRDialect.h" -#include "flang/Optimizer/Transforms/CUFCommon.h" #include "flang/Optimizer/Transforms/Passes.h" #include "flang/Optimizer/Transforms/Utils.h" #include "flang/Runtime/entry-names.h" diff --git a/flang/lib/Optimizer/Transforms/StackArrays.cpp b/flang/lib/Optimizer/Transforms/StackArrays.cpp index 0c474f463f09..bdcb8199b790 100644 --- a/flang/lib/Optimizer/Transforms/StackArrays.cpp +++ b/flang/lib/Optimizer/Transforms/StackArrays.cpp @@ -76,8 +76,9 @@ class InsertionPoint { /// Get contained pointer type or nullptr template <class T> T *tryGetPtr() const { - if (location.is<T *>()) - return location.get<T *>(); + // Use llvm::dyn_cast_if_present because location may be null here. + if (T *ptr = llvm::dyn_cast_if_present<T *>(location)) + return ptr; return nullptr; } @@ -793,8 +794,8 @@ void StackArraysPass::runOnOperation() { config.enableRegionSimplification = mlir::GreedySimplifyRegionLevel::Disabled; patterns.insert<AllocMemConversion>(&context, *candidateOps); - if (mlir::failed(mlir::applyOpPatternsAndFold(opsToConvert, - std::move(patterns), config))) { + if (mlir::failed(mlir::applyOpPatternsGreedily( + opsToConvert, std::move(patterns), config))) { mlir::emitError(func->getLoc(), "error in stack arrays optimization\n"); signalPassFailure(); } |
