diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 165 |
1 files changed, 117 insertions, 48 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 6ec4c120c11e..eabc4b30f57a 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -30,6 +30,7 @@ #include "llvm/Frontend/OpenMP/OMPIRBuilder.h" #include "llvm/IR/DebugInfoMetadata.h" #include "llvm/IR/IRBuilder.h" +#include "llvm/IR/ReplaceConstant.h" #include "llvm/Support/FileSystem.h" #include "llvm/TargetParser/Triple.h" #include "llvm/Transforms/Utils/ModuleUtils.h" @@ -388,8 +389,19 @@ static LogicalResult inlineConvertOmpRegions( // be processed multiple times. moduleTranslation.forgetMapping(region); - if (potentialTerminator && potentialTerminator->isTerminator()) - potentialTerminator->insertAfter(&builder.GetInsertBlock()->back()); + if (potentialTerminator && potentialTerminator->isTerminator()) { + llvm::BasicBlock *block = builder.GetInsertBlock(); + if (block->empty()) { + // this can happen for really simple reduction init regions e.g. + // %0 = llvm.mlir.constant(0 : i32) : i32 + // omp.yield(%0 : i32) + // because the llvm.mlir.constant (MLIR op) isn't converted into any + // llvm op + potentialTerminator->insertInto(block, block->begin()); + } else { + potentialTerminator->insertAfter(&block->back()); + } + } return success(); } @@ -762,7 +774,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder, /// Allocate space for privatized reduction variables. template <typename T> static void allocByValReductionVars( - T loop, llvm::IRBuilderBase &builder, + T loop, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation, llvm::OpenMPIRBuilder::InsertPointTy &allocaIP, SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls, @@ -770,17 +782,15 @@ static void allocByValReductionVars( DenseMap<Value, llvm::Value *> &reductionVariableMap, llvm::ArrayRef<bool> isByRefs) { llvm::IRBuilderBase::InsertPointGuard guard(builder); - builder.restoreIP(allocaIP); - auto args = - loop.getRegion().getArguments().take_back(loop.getNumReductionVars()); + builder.SetInsertPoint(allocaIP.getBlock()->getTerminator()); for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) { if (isByRefs[i]) continue; llvm::Value *var = builder.CreateAlloca( moduleTranslation.convertType(reductionDecls[i].getType())); - moduleTranslation.mapValue(args[i], var); - privateReductionVariables.push_back(var); + moduleTranslation.mapValue(reductionArgs[i], var); + privateReductionVariables[i] = var; reductionVariableMap.try_emplace(loop.getReductionVars()[i], var); } } @@ -826,14 +836,17 @@ static void collectReductionInfo( // Collect the reduction information. reductionInfos.reserve(numReductions); for (unsigned i = 0; i < numReductions; ++i) { - llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr; + llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr; if (owningAtomicReductionGens[i]) atomicGen = owningAtomicReductionGens[i]; llvm::Value *variable = moduleTranslation.lookupValue(loop.getReductionVars()[i]); reductionInfos.push_back( {moduleTranslation.convertType(reductionDecls[i].getType()), variable, - privateReductionVariables[i], owningReductionGens[i], atomicGen}); + privateReductionVariables[i], + /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar, + owningReductionGens[i], + /*ReductionGenClang=*/nullptr, atomicGen}); } } @@ -911,16 +924,20 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, llvm::OpenMPIRBuilder::InsertPointTy allocaIP = findAllocaInsertPoint(builder, moduleTranslation); - SmallVector<llvm::Value *> privateReductionVariables; + SmallVector<llvm::Value *> privateReductionVariables( + wsloopOp.getNumReductionVars()); DenseMap<Value, llvm::Value *> reductionVariableMap; - allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP, - reductionDecls, privateReductionVariables, + + MutableArrayRef<BlockArgument> reductionArgs = + wsloopOp.getRegion().getArguments(); + + allocByValReductionVars(wsloopOp, reductionArgs, builder, moduleTranslation, + allocaIP, reductionDecls, privateReductionVariables, reductionVariableMap, isByRef); // Before the loop, store the initial values of reductions into reduction // variables. Although this could be done after allocas, we don't want to mess // up with the alloca insertion point. - ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments(); for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) { SmallVector<llvm::Value *> phis; @@ -942,7 +959,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder, // ptr builder.CreateStore(phis[0], var); - privateReductionVariables.push_back(var); + privateReductionVariables[i] = var; moduleTranslation.mapValue(reductionArgs[i], phis[0]); reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]); } else { @@ -1140,20 +1157,40 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // Collect reduction declarations SmallVector<omp::DeclareReductionOp> reductionDecls; collectReductionDecls(opInst, reductionDecls); - SmallVector<llvm::Value *> privateReductionVariables; + SmallVector<llvm::Value *> privateReductionVariables( + opInst.getNumReductionVars()); auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) { // Allocate reduction vars DenseMap<Value, llvm::Value *> reductionVariableMap; - allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP, - reductionDecls, privateReductionVariables, + + MutableArrayRef<BlockArgument> reductionArgs = + opInst.getRegion().getArguments().slice( + opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(), + opInst.getNumReductionVars()); + + allocByValReductionVars(opInst, reductionArgs, builder, moduleTranslation, + allocaIP, reductionDecls, privateReductionVariables, reductionVariableMap, isByRef); // Initialize reduction vars builder.restoreIP(allocaIP); - MutableArrayRef<BlockArgument> reductionArgs = - opInst.getRegion().getArguments().take_back( - opInst.getNumReductionVars()); + llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init"); + allocaIP = + InsertPointTy(allocaIP.getBlock(), + allocaIP.getBlock()->getTerminator()->getIterator()); + SmallVector<llvm::Value *> byRefVars(opInst.getNumReductionVars()); + for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) { + if (isByRef[i]) { + // Allocate reduction variable (which is a pointer to the real reduciton + // variable allocated in the inlined region) + byRefVars[i] = builder.CreateAlloca( + moduleTranslation.convertType(reductionDecls[i].getType())); + } + } + + builder.SetInsertPoint(initBlock->getFirstNonPHIOrDbgOrAlloca()); + for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) { SmallVector<llvm::Value *> phis; @@ -1166,18 +1203,17 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, assert(phis.size() == 1 && "expected one value to be yielded from the " "reduction neutral element declaration region"); - builder.restoreIP(allocaIP); + + // mapInitializationArg finishes its block with a terminator. We need to + // insert before that terminator. + builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator()); if (isByRef[i]) { - // Allocate reduction variable (which is a pointer to the real reduciton - // variable allocated in the inlined region) - llvm::Value *var = builder.CreateAlloca( - moduleTranslation.convertType(reductionDecls[i].getType())); // Store the result of the inlined region to the allocated reduction var // ptr - builder.CreateStore(phis[0], var); + builder.CreateStore(phis[0], byRefVars[i]); - privateReductionVariables.push_back(var); + privateReductionVariables[i] = byRefVars[i]; moduleTranslation.mapValue(reductionArgs[i], phis[0]); reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]); } else { @@ -1275,7 +1311,26 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder, // region. The privatizer is processed in-place (see below) before it // gets inlined in the parallel region and therefore processing the // original op is dangerous. - return {privVar, privatizer.clone()}; + + MLIRContext &context = moduleTranslation.getContext(); + mlir::IRRewriter opCloner(&context); + opCloner.setInsertionPoint(privatizer); + auto clone = llvm::cast<mlir::omp::PrivateClauseOp>( + opCloner.clone(*privatizer)); + + // Unique the clone name to avoid clashes in the symbol table. + unsigned counter = 0; + SmallString<256> cloneName = SymbolTable::generateSymbolName<256>( + privatizer.getSymName(), + [&](llvm::StringRef candidate) { + return SymbolTable::lookupNearestSymbolFrom( + opInst, StringAttr::get(&context, candidate)) != + nullptr; + }, + counter); + + clone.setSymName(cloneName); + return {privVar, clone}; } } @@ -1925,12 +1980,6 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, Operation *clauseOp, llvm::Value *basePointer, llvm::Type *baseType, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { - // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives - // the size in inconsistent byte or bit format. - uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type); - if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type)) - underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl); - if (auto memberClause = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) { // This calculates the size to transfer based on bounds and the underlying @@ -1956,6 +2005,12 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, } } + // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives + // the size in inconsistent byte or bit format. + uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type); + if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type)) + underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl); + // The size in bytes x number of elements, the sizeInBytes stored is // the underyling types size, e.g. if ptr<i32>, it'll be the i32's // size, so we do some on the fly runtime math to get the size in @@ -1966,7 +2021,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type, } } - return builder.getInt64(underlyingTypeSzInBits / 8); + return builder.getInt64(dl.getTypeSizeInBits(type) / 8); } void collectMapDataFromMapOperands(MapInfoData &mapData, @@ -2263,7 +2318,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers( // This creates the initial MEMBER_OF mapping that consists of // the parent/top level container (same as above effectively, except - // with a fixed initial compile time size and seperate maptype which + // with a fixed initial compile time size and separate maptype which // indicates the true mape type (tofrom etc.). This parent mapping is // only relevant if the structure in its totality is being mapped, // otherwise the above suffices. @@ -2388,7 +2443,7 @@ static void processMapWithMembersOf( // If we have a partial map (no parent referenced in the map clauses of the // directive, only members) and only a single member, we do not need to bind - // the map of the member to the parent, we can pass the member seperately. + // the map of the member to the parent, we can pass the member separately. if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) { auto memberClause = llvm::cast<mlir::omp::MapInfoOp>( parentClause.getMembers()[0].getDefiningOp()); @@ -2425,7 +2480,7 @@ createAlteredByCaptureMap(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, llvm::IRBuilderBase &builder) { for (size_t i = 0; i < mapData.MapClause.size(); ++i) { - // if it's declare target, skip it, it's handled seperately. + // if it's declare target, skip it, it's handled separately. if (!mapData.IsDeclareTarget[i]) { auto mapOp = mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]); @@ -2847,7 +2902,7 @@ static bool targetOpSupported(Operation &opInst) { static void handleDeclareTargetMapVar(MapInfoData &mapData, LLVM::ModuleTranslation &moduleTranslation, - llvm::IRBuilderBase &builder) { + llvm::IRBuilderBase &builder, llvm::Function *func) { for (size_t i = 0; i < mapData.MapClause.size(); ++i) { // In the case of declare target mapped variables, the basePointer is // the reference pointer generated by the convertDeclareTargetAttr @@ -2862,19 +2917,31 @@ handleDeclareTargetMapVar(MapInfoData &mapData, // reference pointer and the pointer are assigned in the kernel argument // structure for the host. if (mapData.IsDeclareTarget[i]) { + // If the original map value is a constant, then we have to make sure all + // of it's uses within the current kernel/function that we are going to + // rewrite are converted to instructions, as we will be altering the old + // use (OriginalValue) from a constant to an instruction, which will be + // illegal and ICE the compiler if the user is a constant expression of + // some kind e.g. a constant GEP. + if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i])) + convertUsersOfConstantsToInstructions(constant, func, false); + // The users iterator will get invalidated if we modify an element, - // so we populate this vector of uses to alter each user on an individual - // basis to emit its own load (rather than one load for all). + // so we populate this vector of uses to alter each user on an + // individual basis to emit its own load (rather than one load for + // all). llvm::SmallVector<llvm::User *> userVec; for (llvm::User *user : mapData.OriginalValue[i]->users()) userVec.push_back(user); for (llvm::User *user : userVec) { if (auto *insn = dyn_cast<llvm::Instruction>(user)) { - auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(), - mapData.BasePointers[i]); - load->moveBefore(insn); - user->replaceUsesOfWith(mapData.OriginalValue[i], load); + if (insn->getFunction() == func) { + auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(), + mapData.BasePointers[i]); + load->moveBefore(insn); + user->replaceUsesOfWith(mapData.OriginalValue[i], load); + } } } } @@ -2992,6 +3059,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, auto &targetRegion = targetOp.getRegion(); DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>()); SmallVector<Value> mapOperands = targetOp.getMapOperands(); + llvm::Function *llvmOutlinedFn = nullptr; LogicalResult bodyGenStatus = success(); using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy; @@ -3001,7 +3069,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, // original function to the new outlined function. llvm::Function *llvmParentFn = moduleTranslation.lookupFunction(parentFn.getName()); - llvm::Function *llvmOutlinedFn = codeGenIP.getBlock()->getParent(); + llvmOutlinedFn = codeGenIP.getBlock()->getParent(); assert(llvmParentFn && llvmOutlinedFn && "Both parent and outlined functions must exist at this point"); @@ -3096,7 +3164,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder, // Remap access operations to declare target reference pointers for the // device, essentially generating extra loadop's as necessary if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice()) - handleDeclareTargetMapVar(mapData, moduleTranslation, builder); + handleDeclareTargetMapVar(mapData, moduleTranslation, builder, + llvmOutlinedFn); return bodyGenStatus; } |
