summaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp165
1 files changed, 117 insertions, 48 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 6ec4c120c11e..eabc4b30f57a 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -30,6 +30,7 @@
#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
#include "llvm/IR/DebugInfoMetadata.h"
#include "llvm/IR/IRBuilder.h"
+#include "llvm/IR/ReplaceConstant.h"
#include "llvm/Support/FileSystem.h"
#include "llvm/TargetParser/Triple.h"
#include "llvm/Transforms/Utils/ModuleUtils.h"
@@ -388,8 +389,19 @@ static LogicalResult inlineConvertOmpRegions(
// be processed multiple times.
moduleTranslation.forgetMapping(region);
- if (potentialTerminator && potentialTerminator->isTerminator())
- potentialTerminator->insertAfter(&builder.GetInsertBlock()->back());
+ if (potentialTerminator && potentialTerminator->isTerminator()) {
+ llvm::BasicBlock *block = builder.GetInsertBlock();
+ if (block->empty()) {
+ // this can happen for really simple reduction init regions e.g.
+ // %0 = llvm.mlir.constant(0 : i32) : i32
+ // omp.yield(%0 : i32)
+ // because the llvm.mlir.constant (MLIR op) isn't converted into any
+ // llvm op
+ potentialTerminator->insertInto(block, block->begin());
+ } else {
+ potentialTerminator->insertAfter(&block->back());
+ }
+ }
return success();
}
@@ -762,7 +774,7 @@ convertOmpTaskgroupOp(omp::TaskgroupOp tgOp, llvm::IRBuilderBase &builder,
/// Allocate space for privatized reduction variables.
template <typename T>
static void allocByValReductionVars(
- T loop, llvm::IRBuilderBase &builder,
+ T loop, ArrayRef<BlockArgument> reductionArgs, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation,
llvm::OpenMPIRBuilder::InsertPointTy &allocaIP,
SmallVectorImpl<omp::DeclareReductionOp> &reductionDecls,
@@ -770,17 +782,15 @@ static void allocByValReductionVars(
DenseMap<Value, llvm::Value *> &reductionVariableMap,
llvm::ArrayRef<bool> isByRefs) {
llvm::IRBuilderBase::InsertPointGuard guard(builder);
- builder.restoreIP(allocaIP);
- auto args =
- loop.getRegion().getArguments().take_back(loop.getNumReductionVars());
+ builder.SetInsertPoint(allocaIP.getBlock()->getTerminator());
for (std::size_t i = 0; i < loop.getNumReductionVars(); ++i) {
if (isByRefs[i])
continue;
llvm::Value *var = builder.CreateAlloca(
moduleTranslation.convertType(reductionDecls[i].getType()));
- moduleTranslation.mapValue(args[i], var);
- privateReductionVariables.push_back(var);
+ moduleTranslation.mapValue(reductionArgs[i], var);
+ privateReductionVariables[i] = var;
reductionVariableMap.try_emplace(loop.getReductionVars()[i], var);
}
}
@@ -826,14 +836,17 @@ static void collectReductionInfo(
// Collect the reduction information.
reductionInfos.reserve(numReductions);
for (unsigned i = 0; i < numReductions; ++i) {
- llvm::OpenMPIRBuilder::AtomicReductionGenTy atomicGen = nullptr;
+ llvm::OpenMPIRBuilder::ReductionGenAtomicCBTy atomicGen = nullptr;
if (owningAtomicReductionGens[i])
atomicGen = owningAtomicReductionGens[i];
llvm::Value *variable =
moduleTranslation.lookupValue(loop.getReductionVars()[i]);
reductionInfos.push_back(
{moduleTranslation.convertType(reductionDecls[i].getType()), variable,
- privateReductionVariables[i], owningReductionGens[i], atomicGen});
+ privateReductionVariables[i],
+ /*EvaluationKind=*/llvm::OpenMPIRBuilder::EvalKind::Scalar,
+ owningReductionGens[i],
+ /*ReductionGenClang=*/nullptr, atomicGen});
}
}
@@ -911,16 +924,20 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::OpenMPIRBuilder::InsertPointTy allocaIP =
findAllocaInsertPoint(builder, moduleTranslation);
- SmallVector<llvm::Value *> privateReductionVariables;
+ SmallVector<llvm::Value *> privateReductionVariables(
+ wsloopOp.getNumReductionVars());
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocByValReductionVars(wsloopOp, builder, moduleTranslation, allocaIP,
- reductionDecls, privateReductionVariables,
+
+ MutableArrayRef<BlockArgument> reductionArgs =
+ wsloopOp.getRegion().getArguments();
+
+ allocByValReductionVars(wsloopOp, reductionArgs, builder, moduleTranslation,
+ allocaIP, reductionDecls, privateReductionVariables,
reductionVariableMap, isByRef);
// Before the loop, store the initial values of reductions into reduction
// variables. Although this could be done after allocas, we don't want to mess
// up with the alloca insertion point.
- ArrayRef<BlockArgument> reductionArgs = wsloopOp.getRegion().getArguments();
for (unsigned i = 0; i < wsloopOp.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *> phis;
@@ -942,7 +959,7 @@ convertOmpWsloop(Operation &opInst, llvm::IRBuilderBase &builder,
// ptr
builder.CreateStore(phis[0], var);
- privateReductionVariables.push_back(var);
+ privateReductionVariables[i] = var;
moduleTranslation.mapValue(reductionArgs[i], phis[0]);
reductionVariableMap.try_emplace(wsloopOp.getReductionVars()[i], phis[0]);
} else {
@@ -1140,20 +1157,40 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// Collect reduction declarations
SmallVector<omp::DeclareReductionOp> reductionDecls;
collectReductionDecls(opInst, reductionDecls);
- SmallVector<llvm::Value *> privateReductionVariables;
+ SmallVector<llvm::Value *> privateReductionVariables(
+ opInst.getNumReductionVars());
auto bodyGenCB = [&](InsertPointTy allocaIP, InsertPointTy codeGenIP) {
// Allocate reduction vars
DenseMap<Value, llvm::Value *> reductionVariableMap;
- allocByValReductionVars(opInst, builder, moduleTranslation, allocaIP,
- reductionDecls, privateReductionVariables,
+
+ MutableArrayRef<BlockArgument> reductionArgs =
+ opInst.getRegion().getArguments().slice(
+ opInst.getNumAllocateVars() + opInst.getNumAllocatorsVars(),
+ opInst.getNumReductionVars());
+
+ allocByValReductionVars(opInst, reductionArgs, builder, moduleTranslation,
+ allocaIP, reductionDecls, privateReductionVariables,
reductionVariableMap, isByRef);
// Initialize reduction vars
builder.restoreIP(allocaIP);
- MutableArrayRef<BlockArgument> reductionArgs =
- opInst.getRegion().getArguments().take_back(
- opInst.getNumReductionVars());
+ llvm::BasicBlock *initBlock = splitBB(builder, true, "omp.reduction.init");
+ allocaIP =
+ InsertPointTy(allocaIP.getBlock(),
+ allocaIP.getBlock()->getTerminator()->getIterator());
+ SmallVector<llvm::Value *> byRefVars(opInst.getNumReductionVars());
+ for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
+ if (isByRef[i]) {
+ // Allocate reduction variable (which is a pointer to the real reduction
+ // variable allocated in the inlined region)
+ byRefVars[i] = builder.CreateAlloca(
+ moduleTranslation.convertType(reductionDecls[i].getType()));
+ }
+ }
+
+ builder.SetInsertPoint(initBlock->getFirstNonPHIOrDbgOrAlloca());
+
for (unsigned i = 0; i < opInst.getNumReductionVars(); ++i) {
SmallVector<llvm::Value *> phis;
@@ -1166,18 +1203,17 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
assert(phis.size() == 1 &&
"expected one value to be yielded from the "
"reduction neutral element declaration region");
- builder.restoreIP(allocaIP);
+
+ // mapInitializationArg finishes its block with a terminator. We need to
+ // insert before that terminator.
+ builder.SetInsertPoint(builder.GetInsertBlock()->getTerminator());
if (isByRef[i]) {
- // Allocate reduction variable (which is a pointer to the real reduciton
- // variable allocated in the inlined region)
- llvm::Value *var = builder.CreateAlloca(
- moduleTranslation.convertType(reductionDecls[i].getType()));
// Store the result of the inlined region to the allocated reduction var
// ptr
- builder.CreateStore(phis[0], var);
+ builder.CreateStore(phis[0], byRefVars[i]);
- privateReductionVariables.push_back(var);
+ privateReductionVariables[i] = byRefVars[i];
moduleTranslation.mapValue(reductionArgs[i], phis[0]);
reductionVariableMap.try_emplace(opInst.getReductionVars()[i], phis[0]);
} else {
@@ -1275,7 +1311,26 @@ convertOmpParallel(omp::ParallelOp opInst, llvm::IRBuilderBase &builder,
// region. The privatizer is processed in-place (see below) before it
// gets inlined in the parallel region and therefore processing the
// original op is dangerous.
- return {privVar, privatizer.clone()};
+
+ MLIRContext &context = moduleTranslation.getContext();
+ mlir::IRRewriter opCloner(&context);
+ opCloner.setInsertionPoint(privatizer);
+ auto clone = llvm::cast<mlir::omp::PrivateClauseOp>(
+ opCloner.clone(*privatizer));
+
+ // Unique the clone name to avoid clashes in the symbol table.
+ unsigned counter = 0;
+ SmallString<256> cloneName = SymbolTable::generateSymbolName<256>(
+ privatizer.getSymName(),
+ [&](llvm::StringRef candidate) {
+ return SymbolTable::lookupNearestSymbolFrom(
+ opInst, StringAttr::get(&context, candidate)) !=
+ nullptr;
+ },
+ counter);
+
+ clone.setSymName(cloneName);
+ return {privVar, clone};
}
}
@@ -1925,12 +1980,6 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
Operation *clauseOp, llvm::Value *basePointer,
llvm::Type *baseType, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
- // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
- // the size in inconsistent byte or bit format.
- uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
- if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
- underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
-
if (auto memberClause =
mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(clauseOp)) {
// This calculates the size to transfer based on bounds and the underlying
@@ -1956,6 +2005,12 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
}
}
+ // utilising getTypeSizeInBits instead of getTypeSize as getTypeSize gives
+ // the size in inconsistent byte or bit format.
+ uint64_t underlyingTypeSzInBits = dl.getTypeSizeInBits(type);
+ if (auto arrTy = llvm::dyn_cast_if_present<LLVM::LLVMArrayType>(type))
+ underlyingTypeSzInBits = getArrayElementSizeInBits(arrTy, dl);
+
// The size in bytes x number of elements, the sizeInBytes stored is
// the underyling types size, e.g. if ptr<i32>, it'll be the i32's
// size, so we do some on the fly runtime math to get the size in
@@ -1966,7 +2021,7 @@ llvm::Value *getSizeInBytes(DataLayout &dl, const mlir::Type &type,
}
}
- return builder.getInt64(underlyingTypeSzInBits / 8);
+ return builder.getInt64(dl.getTypeSizeInBits(type) / 8);
}
void collectMapDataFromMapOperands(MapInfoData &mapData,
@@ -2263,7 +2318,7 @@ static llvm::omp::OpenMPOffloadMappingFlags mapParentWithMembers(
// This creates the initial MEMBER_OF mapping that consists of
// the parent/top level container (same as above effectively, except
- // with a fixed initial compile time size and seperate maptype which
+ // with a fixed initial compile time size and separate maptype which
// indicates the true mape type (tofrom etc.). This parent mapping is
// only relevant if the structure in its totality is being mapped,
// otherwise the above suffices.
@@ -2388,7 +2443,7 @@ static void processMapWithMembersOf(
// If we have a partial map (no parent referenced in the map clauses of the
// directive, only members) and only a single member, we do not need to bind
- // the map of the member to the parent, we can pass the member seperately.
+ // the map of the member to the parent, we can pass the member separately.
if (parentClause.getMembers().size() == 1 && parentClause.getPartialMap()) {
auto memberClause = llvm::cast<mlir::omp::MapInfoOp>(
parentClause.getMembers()[0].getDefiningOp());
@@ -2425,7 +2480,7 @@ createAlteredByCaptureMap(MapInfoData &mapData,
LLVM::ModuleTranslation &moduleTranslation,
llvm::IRBuilderBase &builder) {
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
- // if it's declare target, skip it, it's handled seperately.
+ // if it's declare target, skip it, it's handled separately.
if (!mapData.IsDeclareTarget[i]) {
auto mapOp =
mlir::dyn_cast_if_present<mlir::omp::MapInfoOp>(mapData.MapClause[i]);
@@ -2847,7 +2902,7 @@ static bool targetOpSupported(Operation &opInst) {
static void
handleDeclareTargetMapVar(MapInfoData &mapData,
LLVM::ModuleTranslation &moduleTranslation,
- llvm::IRBuilderBase &builder) {
+ llvm::IRBuilderBase &builder, llvm::Function *func) {
for (size_t i = 0; i < mapData.MapClause.size(); ++i) {
// In the case of declare target mapped variables, the basePointer is
// the reference pointer generated by the convertDeclareTargetAttr
@@ -2862,19 +2917,31 @@ handleDeclareTargetMapVar(MapInfoData &mapData,
// reference pointer and the pointer are assigned in the kernel argument
// structure for the host.
if (mapData.IsDeclareTarget[i]) {
+ // If the original map value is a constant, then we have to make sure all
+ // of its uses within the current kernel/function that we are going to
+ // rewrite are converted to instructions, as we will be altering the old
+ // use (OriginalValue) from a constant to an instruction, which will be
+ // illegal and ICE the compiler if the user is a constant expression of
+ // some kind e.g. a constant GEP.
+ if (auto *constant = dyn_cast<llvm::Constant>(mapData.OriginalValue[i]))
+ convertUsersOfConstantsToInstructions(constant, func, false);
+
// The users iterator will get invalidated if we modify an element,
- // so we populate this vector of uses to alter each user on an individual
- // basis to emit its own load (rather than one load for all).
+ // so we populate this vector of uses to alter each user on an
+ // individual basis to emit its own load (rather than one load for
+ // all).
llvm::SmallVector<llvm::User *> userVec;
for (llvm::User *user : mapData.OriginalValue[i]->users())
userVec.push_back(user);
for (llvm::User *user : userVec) {
if (auto *insn = dyn_cast<llvm::Instruction>(user)) {
- auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
- mapData.BasePointers[i]);
- load->moveBefore(insn);
- user->replaceUsesOfWith(mapData.OriginalValue[i], load);
+ if (insn->getFunction() == func) {
+ auto *load = builder.CreateLoad(mapData.BasePointers[i]->getType(),
+ mapData.BasePointers[i]);
+ load->moveBefore(insn);
+ user->replaceUsesOfWith(mapData.OriginalValue[i], load);
+ }
}
}
}
@@ -2992,6 +3059,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
auto &targetRegion = targetOp.getRegion();
DataLayout dl = DataLayout(opInst.getParentOfType<ModuleOp>());
SmallVector<Value> mapOperands = targetOp.getMapOperands();
+ llvm::Function *llvmOutlinedFn = nullptr;
LogicalResult bodyGenStatus = success();
using InsertPointTy = llvm::OpenMPIRBuilder::InsertPointTy;
@@ -3001,7 +3069,7 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
// original function to the new outlined function.
llvm::Function *llvmParentFn =
moduleTranslation.lookupFunction(parentFn.getName());
- llvm::Function *llvmOutlinedFn = codeGenIP.getBlock()->getParent();
+ llvmOutlinedFn = codeGenIP.getBlock()->getParent();
assert(llvmParentFn && llvmOutlinedFn &&
"Both parent and outlined functions must exist at this point");
@@ -3096,7 +3164,8 @@ convertOmpTarget(Operation &opInst, llvm::IRBuilderBase &builder,
// Remap access operations to declare target reference pointers for the
// device, essentially generating extra loadop's as necessary
if (moduleTranslation.getOpenMPBuilder()->Config.isTargetDevice())
- handleDeclareTargetMapVar(mapData, moduleTranslation, builder);
+ handleDeclareTargetMapVar(mapData, moduleTranslation, builder,
+ llvmOutlinedFn);
return bodyGenStatus;
}