diff options
Diffstat (limited to 'mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp')
| -rw-r--r-- | mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp | 61 |
1 files changed, 45 insertions, 16 deletions
diff --git a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp index 9a30266103b1..87cb7f03fec6 100644 --- a/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp +++ b/mlir/lib/Target/LLVMIR/Dialect/OpenMP/OpenMPToLLVMIRTranslation.cpp @@ -150,10 +150,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { << " operation"; }; - auto checkAligned = [&todo](auto op, LogicalResult &result) { - if (!op.getAlignedVars().empty() || op.getAlignments()) - result = todo("aligned"); - }; auto checkAllocate = [&todo](auto op, LogicalResult &result) { if (!op.getAllocateVars().empty() || !op.getAllocatorVars().empty()) result = todo("allocate"); @@ -275,7 +271,6 @@ static LogicalResult checkImplementationStatus(Operation &op) { }) .Case([&](omp::ParallelOp op) { checkAllocate(op, result); }) .Case([&](omp::SimdOp op) { - checkAligned(op, result); checkLinear(op, result); checkNontemporal(op, result); checkPrivate(op, result); @@ -2302,6 +2297,24 @@ convertOmpSimd(Operation &opInst, llvm::IRBuilderBase &builder, llvm::MapVector<llvm::Value *, llvm::Value *> alignedVars; llvm::omp::OrderKind order = convertOrderKind(simdOp.getOrder()); + llvm::BasicBlock *sourceBlock = builder.GetInsertBlock(); + std::optional<ArrayAttr> alignmentValues = simdOp.getAlignments(); + mlir::OperandRange operands = simdOp.getAlignedVars(); + for (size_t i = 0; i < operands.size(); ++i) { + llvm::Value *alignment = nullptr; + llvm::Value *llvmVal = moduleTranslation.lookupValue(operands[i]); + llvm::Type *ty = llvmVal->getType(); + if (auto intAttr = llvm::dyn_cast<IntegerAttr>((*alignmentValues)[i])) { + alignment = builder.getInt64(intAttr.getInt()); + assert(ty->isPointerTy() && "Invalid type for aligned variable"); + assert(alignment && "Invalid alignment value"); + auto curInsert = builder.saveIP(); + builder.SetInsertPoint(sourceBlock->getTerminator()); + llvmVal = builder.CreateLoad(ty, llvmVal); + builder.restoreIP(curInsert); + alignedVars[llvmVal] = alignment; + } + } ompBuilder->applySimd(loopInfo, alignedVars, simdOp.getIfExpr() ? moduleTranslation.lookupValue(simdOp.getIfExpr()) @@ -2575,6 +2588,7 @@ static LogicalResult convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::ModuleTranslation &moduleTranslation) { llvm::OpenMPIRBuilder::LocationDescription ompLoc(builder); + llvm::OpenMPIRBuilder *ompBuilder = moduleTranslation.getOpenMPBuilder(); auto threadprivateOp = cast<omp::ThreadprivateOp>(opInst); if (failed(checkImplementationStatus(opInst))) @@ -2582,6 +2596,10 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, Value symAddr = threadprivateOp.getSymAddr(); auto *symOp = symAddr.getDefiningOp(); + + if (auto asCast = dyn_cast<LLVM::AddrSpaceCastOp>(symOp)) + symOp = asCast.getOperand().getDefiningOp(); + if (!isa<LLVM::AddressOfOp>(symOp)) return opInst.emitError("Addressing symbol not found"); LLVM::AddressOfOp addressOfOp = dyn_cast<LLVM::AddressOfOp>(symOp); @@ -2589,17 +2607,20 @@ convertOmpThreadprivate(Operation &opInst, llvm::IRBuilderBase &builder, LLVM::GlobalOp global = addressOfOp.getGlobal(moduleTranslation.symbolTable()); llvm::GlobalValue *globalValue = moduleTranslation.lookupGlobal(global); - llvm::Type *type = globalValue->getValueType(); - llvm::TypeSize typeSize = - builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( - type); - llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue()); - llvm::StringRef suffix = llvm::StringRef(".cache", 6); - std::string cacheName = (Twine(global.getSymName()).concat(suffix)).str(); - llvm::Value *callInst = - moduleTranslation.getOpenMPBuilder()->createCachedThreadPrivate( - ompLoc, globalValue, size, cacheName); - moduleTranslation.mapValue(opInst.getResult(0), callInst); + + if (!ompBuilder->Config.isTargetDevice()) { + llvm::Type *type = globalValue->getValueType(); + llvm::TypeSize typeSize = + builder.GetInsertBlock()->getModule()->getDataLayout().getTypeStoreSize( + type); + llvm::ConstantInt *size = builder.getInt64(typeSize.getFixedValue()); + llvm::Value *callInst = ompBuilder->createCachedThreadPrivate( + ompLoc, globalValue, size, global.getSymName() + ".cache"); + moduleTranslation.mapValue(opInst.getResult(0), callInst); + } else { + moduleTranslation.mapValue(opInst.getResult(0), globalValue); + } + return success(); } @@ -4199,6 +4220,14 @@ static bool isTargetDeviceOp(Operation *op) { if (op->getParentOfType<omp::TargetOp>()) return true; + // Certain operations return results, and whether utilised in host or + // target there is a chance an LLVM Dialect operation depends on it + // by taking it in as an operand, so we must always lower these in + // some manner or result in an ICE (whether they end up in a no-op + // or otherwise). + if (mlir::isa<omp::ThreadprivateOp>(op)) + return true; + if (auto parentFn = op->getParentOfType<LLVM::LLVMFuncOp>()) if (auto declareTargetIface = llvm::dyn_cast<mlir::omp::DeclareTargetInterface>( |
