diff options
| author | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
|---|---|---|
| committer | Michael Kruse <llvm-project@meinersbur.de> | 2025-01-03 10:22:51 +0100 |
| commit | 38500d63e14ce340236840f60d356cdefb56a52c (patch) | |
| tree | 17edbec446ce9b50d2f215a483b83afb293a635d /flang/lib/Optimizer/Transforms/CUFOpConversion.cpp | |
| parent | 1a3d5daaef7a6a63448a497da3eff7fc9e23df26 (diff) | |
| parent | 27f30029741ecf023baece7b3dde1ff9011ffefc (diff) | |
Merge branch 'main' into users/meinersbur/flang_runtime_split-headers (branch: users/meinersbur/flang_runtime_split-headers)
Diffstat (limited to 'flang/lib/Optimizer/Transforms/CUFOpConversion.cpp')
| -rw-r--r-- | flang/lib/Optimizer/Transforms/CUFOpConversion.cpp | 54 |
1 files changed, 39 insertions, 15 deletions
diff --git a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp index 7f6843d66d39..8c525fc6daff 100644 --- a/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp +++ b/flang/lib/Optimizer/Transforms/CUFOpConversion.cpp @@ -8,6 +8,7 @@ #include "flang/Optimizer/Transforms/CUFOpConversion.h" #include "flang/Common/Fortran.h" +#include "flang/Optimizer/Builder/CUFCommon.h" #include "flang/Optimizer/Builder/Runtime/RTBuilder.h" #include "flang/Optimizer/CodeGen/TypeConverter.h" #include "flang/Optimizer/Dialect/CUF/CUFOps.h" @@ -15,7 +16,6 @@ #include "flang/Optimizer/Dialect/FIROps.h" #include "flang/Optimizer/HLFIR/HLFIROps.h" #include "flang/Optimizer/Support/DataLayout.h" -#include "flang/Optimizer/Transforms/CUFCommon.h" #include "flang/Runtime/CUDA/allocatable.h" #include "flang/Runtime/CUDA/common.h" #include "flang/Runtime/CUDA/descriptor.h" @@ -81,15 +81,6 @@ static bool hasDoubleDescriptors(OpTy op) { return false; } -bool isDeviceGlobal(fir::GlobalOp op) { - auto attr = op.getDataAttr(); - if (attr && (*attr == cuf::DataAttribute::Device || - *attr == cuf::DataAttribute::Managed || - *attr == cuf::DataAttribute::Constant)) - return true; - return false; -} - static mlir::Value createConvertOp(mlir::PatternRewriter &rewriter, mlir::Location loc, mlir::Type toTy, mlir::Value val) { @@ -351,7 +342,7 @@ struct CUFAllocOpConversion : public mlir::OpRewritePattern<cuf::AllocOp> { // Convert descriptor allocations to function call. 
auto boxTy = mlir::dyn_cast_or_null<fir::BaseBoxType>(op.getInType()); mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDesciptor)>(loc, builder); + fir::runtime::getRuntimeFunc<mkRTKey(CUFAllocDescriptor)>(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); @@ -388,7 +379,7 @@ struct DeclareOpConversion : public mlir::OpRewritePattern<fir::DeclareOp> { if (auto addrOfOp = op.getMemref().getDefiningOp<fir::AddrOfOp>()) { if (auto global = symTab.lookup<fir::GlobalOp>( addrOfOp.getSymbol().getRootReference().getValue())) { - if (isDeviceGlobal(global)) { + if (cuf::isRegisteredDeviceGlobal(global)) { rewriter.setInsertionPointAfter(addrOfOp); auto mod = op->getParentOfType<mlir::ModuleOp>(); fir::FirOpBuilder builder(rewriter, mod); @@ -460,7 +451,7 @@ struct CUFFreeOpConversion : public mlir::OpRewritePattern<cuf::FreeOp> { // Convert cuf.free on descriptors. mlir::func::FuncOp func = - fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDesciptor)>(loc, builder); + fir::runtime::getRuntimeFunc<mkRTKey(CUFFreeDescriptor)>(loc, builder); auto fTy = func.getFunctionType(); mlir::Value sourceLine = fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); @@ -797,6 +788,38 @@ private: const mlir::SymbolTable &symTab; }; +struct CUFSyncDescriptorOpConversion + : public mlir::OpRewritePattern<cuf::SyncDescriptorOp> { + using OpRewritePattern::OpRewritePattern; + + mlir::LogicalResult + matchAndRewrite(cuf::SyncDescriptorOp op, + mlir::PatternRewriter &rewriter) const override { + auto mod = op->getParentOfType<mlir::ModuleOp>(); + fir::FirOpBuilder builder(rewriter, mod); + mlir::Location loc = op.getLoc(); + + auto globalOp = mod.lookupSymbol<fir::GlobalOp>(op.getGlobalName()); + if (!globalOp) + return mlir::failure(); + + auto hostAddr = builder.create<fir::AddrOfOp>( + loc, fir::ReferenceType::get(globalOp.getType()), op.getGlobalName()); + mlir::func::FuncOp 
callee = + fir::runtime::getRuntimeFunc<mkRTKey(CUFSyncGlobalDescriptor)>(loc, + builder); + auto fTy = callee.getFunctionType(); + mlir::Value sourceFile = fir::factory::locationToFilename(builder, loc); + mlir::Value sourceLine = + fir::factory::locationToLineNo(builder, loc, fTy.getInput(2)); + llvm::SmallVector<mlir::Value> args{fir::runtime::createArguments( + builder, loc, fTy, hostAddr, sourceFile, sourceLine)}; + builder.create<fir::CallOp>(loc, callee, args); + op.erase(); + return mlir::success(); + } +}; + class CUFOpConversion : public fir::impl::CUFOpConversionBase<CUFOpConversion> { public: void runOnOperation() override { @@ -833,7 +856,7 @@ public: addrOfOp.getSymbol().getRootReference().getValue())) { if (mlir::isa<fir::BaseBoxType>(fir::unwrapRefType(global.getType()))) return true; - if (isDeviceGlobal(global)) + if (cuf::isRegisteredDeviceGlobal(global)) return false; } } @@ -857,7 +880,8 @@ void cuf::populateCUFToFIRConversionPatterns( const mlir::SymbolTable &symtab, mlir::RewritePatternSet &patterns) { patterns.insert<CUFAllocOpConversion>(patterns.getContext(), &dl, &converter); patterns.insert<CUFAllocateOpConversion, CUFDeallocateOpConversion, - CUFFreeOpConversion>(patterns.getContext()); + CUFFreeOpConversion, CUFSyncDescriptorOpConversion>( + patterns.getContext()); patterns.insert<CUFDataTransferOpConversion>(patterns.getContext(), symtab, &dl, &converter); patterns.insert<CUFLaunchOpConversion>(patterns.getContext(), symtab); |
