summaryrefslogtreecommitdiff
path: root/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
-rw-r--r--mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp61
1 files changed, 45 insertions, 16 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
index 9a30266103b1..87cb7f03fec6 100644
--- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
+++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp
@@ -150,10 +150,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
<< " operation";
};
- auto checkAligned = [&todo](auto op, LogicalResult &result) {
- if (!op.getAlignedVars().empty() || op.getAlignments())
- result = todo("aligned");
- };
auto checkAllocate = [&todo](auto op, LogicalResult &result) {
if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty())
result = todo("allocate");
@@ -275,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) {
})
.Case([&](omp::ParallelOp op) { checkAllocate(op, result); })
.Case([&](omp::SimdOp op) {
- checkAligned(op, result);
checkLinear(op, result);
checkNontemporal(op, result);
checkPrivate(op, result);
@@ -2302,6 +2297,24 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder,
llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars;
llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder());
+ llvm::BasicBlock *sourceBlock = builder.GetInsertBlock();
+ std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments();
+ mlir::OperandRange operands = simdOp.getAlignedVars();
+ for (size_t i = 0; i < operands.size(); ++i) {
+ llvm::Value *alignment = nullptr;
+ llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]);
+ llvm::Type *ty = llvmVal->getType();
+ if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) {
+ alignment = builder.getInt64(intAttr.getInt());
+ assert(ty->isPointerTy() && "Invalid type for aligned variable");
+ assert(alignment && "Invalid alignment value");
+ auto curInsert = builder.saveIP();
+ builder.SetInsertPoint(sourceBlock->getTerminator());
+ llvmVal = builder.CreateLoad(ty, llvmVal);
+ builder.restoreIP(curInsert);
+ alignedVars[llvmVal] = alignment;
+ }
+ }
ompBuilder->applySimd(loopInfo, alignedVars,
simdOp.getIfExpr()
? moduleTranslation.lookupValue(simdOp.getIfExpr())
@@ -2575,6 +2588,7 @@ static LogicalResult
convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::ModuleTranslation &moduleTranslation) {
llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder);
+ llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder();
auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst);
if (failed(checkImplementationStatus(opInst)))
@@ -2582,6 +2596,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
Value symAddr = threadprivateOp.getSymAddr();
auto *symOp = symAddr.getDefiningOp();
+
+ if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp))
+ symOp = asCast.getOperand().getDefiningOp();
+
if (!isa<LLVM::AddressOfOp>(symOp))
return opInst.emitError("Addressing symbol not found");
LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp);
@@ -2589,17 +2607,20 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder,
LLVM::GlobalOp global =
addressOfOp.getGlobal(moduleTranslation.symbolTable());
llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global);
- llvm::Type *type = globalValue->getValueType();
- llvm::TypeSize typeSize =
- builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
- type);
- llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
- llvm::StringRef suffix = llvm::StringRef(".cache", 6);
- std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str();
- llvm::Value *callInst =
- moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate(
- ompLoc, globalValue, size, cacheName);
- moduleTranslation.mapValue(opInst.getResult(0), callInst);
+
+ if (!ompBuilder->Config.isTargetDevice()) {
+ llvm::Type *type = globalValue->getValueType();
+ llvm::TypeSize typeSize =
+ builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize(
+ type);
+ llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue());
+ llvm::Value *callInst = ompBuilder->createCachedThreadPrivate(
+ ompLoc, globalValue, size, global.getSymName() + ".cache");
+ moduleTranslation.mapValue(opInst.getResult(0), callInst);
+ } else {
+ moduleTranslation.mapValue(opInst.getResult(0), globalValue);
+ }
+
return success();
}
@@ -4199,6 +4220,14 @@ static bool isTargetDeviceOp(Operation *op) {
if (op->getParentOfType<omp::TargetOp>())
return true;
+ // Certain operations return results, and whether utilised in host or
+ // target there is a chance an LLVM Dialect operation depends on it
+ // by taking it in as an operand, so we must always lower these in
+ // some manner or result in an ICE (whether they end up in a no-op
+ // or otherwise).
+ if (mlir::isa<omp::ThreadprivateOp>(op))
+ return true;
+
if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>())
if (auto declareTargetIface =
llvm::dyn_cast<mlir::omp::DeclareTargetInterface>(