summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp')
-rw-r--r--mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp442
1 files changed, 442 insertions, 0 deletions
diff --git a/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
new file mode 100644
index 000000000000..db54eaa2628a
--- /dev/null
+++ b/mlir/lib/Dialect/OpenMP/Transforms/OpenMPOffloadPrivatizationPrepare.cpp
@@ -0,0 +1,442 @@
+//===- OpenMPOffloadPrivatizationPrepare.cpp - Prepare OMP privatization --===//
+//
+// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions.
+// See https://llvm.org/LICENSE.txt for license information.
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception
+//
+//===----------------------------------------------------------------------===//
+
+#include "mlir/Analysis/SliceAnalysis.h"
+#include "mlir/Dialect/LLVMIR/FunctionCallUtils.h"
+#include "mlir/Dialect/LLVMIR/LLVMDialect.h"
+#include "mlir/Dialect/OpenMP/OpenMPDialect.h"
+#include "mlir/IR/Builders.h"
+#include "mlir/IR/Dominance.h"
+#include "mlir/IR/IRMapping.h"
+#include "mlir/Pass/Pass.h"
+#include "mlir/Support/LLVM.h"
+#include "llvm/Support/DebugLog.h"
+#include "llvm/Support/FormatVariadic.h"
+#include <cstdint>
+#include <iterator>
+#include <utility>
+
+//===----------------------------------------------------------------------===//
+// A pass that prepares OpenMP code for translation of delayed privatization
+// in the context of deferred target tasks. Deferred target tasks are created
+// when the nowait clause is used on the target directive.
+//===----------------------------------------------------------------------===//
+
+#define DEBUG_TYPE "omp-prepare-for-offload-privatization"
+
+namespace mlir {
+namespace omp {
+
+#define GEN_PASS_DEF_PREPAREFOROMPOFFLOADPRIVATIZATIONPASS
+#include "mlir/Dialect/OpenMP/Transforms/Passes.h.inc"
+
+} // namespace omp
+} // namespace mlir
+
+using namespace mlir;
+namespace {
+
+//===----------------------------------------------------------------------===//
+// PrepareForOMPOffloadPrivatizationPass
+//===----------------------------------------------------------------------===//
+
+class PrepareForOMPOffloadPrivatizationPass
+ : public omp::impl::PrepareForOMPOffloadPrivatizationPassBase<
+ PrepareForOMPOffloadPrivatizationPass> {
+
+ void runOnOperation() override {
+ ModuleOp mod = getOperation();
+
+ // In this pass, we make host-allocated privatized variables persist for
+ // deferred target tasks by copying them to the heap. Once the target task
+ // is done, this heap memory is freed. Since all of this happens on the host
+ // we can skip device modules.
+ auto offloadModuleInterface =
+ dyn_cast<omp::OffloadModuleInterface>(mod.getOperation());
+ if (offloadModuleInterface && offloadModuleInterface.getIsTargetDevice())
+ return;
+
+ getOperation()->walk([&](omp::TargetOp targetOp) {
+ if (!hasPrivateVars(targetOp) || !isTargetTaskDeferred(targetOp))
+ return;
+ IRRewriter rewriter(&getContext());
+ OperandRange privateVars = targetOp.getPrivateVars();
+ SmallVector<mlir::Value> newPrivVars;
+ Value fakeDependVar;
+ omp::TaskOp cleanupTaskOp;
+
+ newPrivVars.reserve(privateVars.size());
+ std::optional<ArrayAttr> privateSyms = targetOp.getPrivateSyms();
+ for (auto [privVarIdx, privVarSymPair] :
+ llvm::enumerate(llvm::zip_equal(privateVars, *privateSyms))) {
+ Value privVar = std::get<0>(privVarSymPair);
+ Attribute privSym = std::get<1>(privVarSymPair);
+
+ omp::PrivateClauseOp privatizer = findPrivatizer(targetOp, privSym);
+ if (!privatizer.needsMap()) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+ bool isFirstPrivate = privatizer.getDataSharingType() ==
+ omp::DataSharingClauseType::FirstPrivate;
+
+ Value mappedValue = targetOp.getMappedValueForPrivateVar(privVarIdx);
+ auto mapInfoOp = cast<omp::MapInfoOp>(mappedValue.getDefiningOp());
+
+ if (mapInfoOp.getMapCaptureType() == omp::VariableCaptureKind::ByCopy) {
+ newPrivVars.push_back(privVar);
+ continue;
+ }
+
+ // For deferred target tasks (!$omp target nowait), we need to keep
+ // a copy of the original, i.e. host variable being privatized so
+ // that it is available when the target task is eventually executed.
+ // We do this by first allocating as much heap memory as is needed by
+ // the original variable. Then, we use the init and copy regions of the
+ // privatizer, an instance of omp::PrivateClauseOp to set up the heap-
+ // allocated copy.
+ // After the target task is done, we need to use the dealloc region
+ // of the privatizer to clean up everything. We also need to free
+ // the heap memory we allocated. But due to the deferred nature
+ // of the target task, we cannot simply deallocate right after the
+ // omp.target operation else we may end up freeing memory before
+ // its eventual use by the target task. So, we create a dummy
+ // dependence between the target task and new omp.task. In the omp.task,
+ // we do all the cleanup. So, we end up with the following structure
+ //
+ // omp.target map_entries(..) ... nowait depend(out:fakeDependVar) {
+ // ...
+ // omp.terminator
+ // }
+ // omp.task depend(in: fakeDependVar) {
+ // /*cleanup_code*/
+ // omp.terminator
+ // }
+ // fakeDependVar is the address of the first heap-allocated copy of the
+ // host variable being privatized.
+
+ bool needsCleanupTask = !privatizer.getDeallocRegion().empty();
+
+ // Allocate heap memory that corresponds to the type of memory
+ // pointed to by varPtr
+ // For boxchars this won't be a pointer. But, MapsForPrivatizedSymbols
+ // should have mapped the pointer to the boxchar so use that as varPtr.
+ Value varPtr = mapInfoOp.getVarPtr();
+ Type varType = mapInfoOp.getVarType();
+ bool isPrivatizedByValue =
+ !isa<LLVM::LLVMPointerType>(privVar.getType());
+
+ assert(isa<LLVM::LLVMPointerType>(varPtr.getType()));
+ Value heapMem =
+ allocateHeapMem(targetOp, varPtr, varType, mod, rewriter);
+ if (!heapMem)
+ targetOp.emitError(
+ "Unable to allocate heap memory when trying to move "
+ "a private variable out of the stack and into the "
+ "heap for use by a deferred target task");
+
+ if (needsCleanupTask && !fakeDependVar)
+ fakeDependVar = heapMem;
+
+ // The types of private vars should match before and after the
+ // transformation. In particular, if the type is a pointer,
+ // simply record the newly allocated malloc location as the
+ // new private variable. If, however, the type is not a pointer
+ // then, we need to load the value from the newly allocated
+ // location. We'll insert that load later after we have updated
+ // the malloc'd location with the contents of the original
+ // variable.
+ if (!isPrivatizedByValue)
+ newPrivVars.push_back(heapMem);
+
+ // We now need to copy the original private variable into the newly
+ // allocated location in the heap.
+ // Find the earliest insertion point for the copy. This will be before
+ // the first in the list of omp::MapInfoOp instances that use varPtr.
+ // After the copy these omp::MapInfoOp instances will refer to heapMem
+ // instead.
+ Operation *varPtrDefiningOp = varPtr.getDefiningOp();
+ DenseSet<Operation *> users;
+ if (varPtrDefiningOp) {
+ users.insert(varPtrDefiningOp->user_begin(),
+ varPtrDefiningOp->user_end());
+ } else {
+ auto blockArg = cast<BlockArgument>(varPtr);
+ users.insert(blockArg.user_begin(), blockArg.user_end());
+ }
+ auto usesVarPtr = [&users](Operation *op) -> bool {
+ return users.count(op);
+ };
+
+ SmallVector<Operation *> chainOfOps;
+ chainOfOps.push_back(mapInfoOp);
+ for (auto member : mapInfoOp.getMembers()) {
+ omp::MapInfoOp memberMap =
+ cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (usesVarPtr(memberMap))
+ chainOfOps.push_back(memberMap);
+ if (memberMap.getVarPtrPtr()) {
+ Operation *defOp = memberMap.getVarPtrPtr().getDefiningOp();
+ if (defOp && usesVarPtr(defOp))
+ chainOfOps.push_back(defOp);
+ }
+ }
+
+ DominanceInfo dom;
+ llvm::sort(chainOfOps, [&](Operation *l, Operation *r) {
+ return dom.dominates(l, r);
+ });
+
+ rewriter.setInsertionPoint(chainOfOps.front());
+
+ Operation *firstOp = chainOfOps.front();
+ Location loc = firstOp->getLoc();
+
+ // Create a llvm.func for 'region' that is marked always_inline and call
+ // it.
+ auto createAlwaysInlineFuncAndCallIt =
+ [&](Region &region, llvm::StringRef funcName,
+ llvm::ArrayRef<Value> args, bool returnsValue) -> Value {
+ assert(!region.empty() && "region cannot be empty");
+ LLVM::LLVMFuncOp func = createFuncOpForRegion(
+ loc, mod, region, funcName, rewriter, returnsValue);
+ auto call = rewriter.create<LLVM::CallOp>(loc, func, args);
+ return call.getResult();
+ };
+
+ Value moldArg, newArg;
+ if (isPrivatizedByValue) {
+ moldArg = rewriter.create<LLVM::LoadOp>(loc, varType, varPtr);
+ newArg = rewriter.create<LLVM::LoadOp>(loc, varType, heapMem);
+ } else {
+ moldArg = varPtr;
+ newArg = heapMem;
+ }
+
+ Value initializedVal;
+ if (!privatizer.getInitRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getInitRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "init").str(),
+ {moldArg, newArg}, /*returnsValue=*/true);
+ else
+ initializedVal = newArg;
+
+ if (isFirstPrivate && !privatizer.getCopyRegion().empty())
+ initializedVal = createAlwaysInlineFuncAndCallIt(
+ privatizer.getCopyRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "copy").str(),
+ {moldArg, initializedVal}, /*returnsValue=*/true);
+
+ if (isPrivatizedByValue)
+ (void)rewriter.create<LLVM::StoreOp>(loc, initializedVal, heapMem);
+
+ // clone origOp, replace all uses of varPtr with heapMem and
+ // erase origOp.
+ auto cloneModifyAndErase = [&](Operation *origOp) -> Operation * {
+ Operation *clonedOp = rewriter.clone(*origOp);
+ rewriter.replaceAllOpUsesWith(origOp, clonedOp);
+ rewriter.modifyOpInPlace(clonedOp, [&]() {
+ clonedOp->replaceUsesOfWith(varPtr, heapMem);
+ });
+ rewriter.eraseOp(origOp);
+ return clonedOp;
+ };
+
+ // Now that we have set up the heap-allocated copy of the private
+ // variable, rewrite all the uses of the original variable with
+ // the heap-allocated variable.
+ rewriter.setInsertionPoint(targetOp);
+ rewriter.setInsertionPoint(cloneModifyAndErase(mapInfoOp));
+
+ // Fix any members that may use varPtr to now use heapMem
+ for (auto member : mapInfoOp.getMembers()) {
+ auto memberMapInfoOp = cast<omp::MapInfoOp>(member.getDefiningOp());
+ if (!usesVarPtr(memberMapInfoOp))
+ continue;
+ rewriter.setInsertionPoint(cloneModifyAndErase(memberMapInfoOp));
+
+ if (memberMapInfoOp.getVarPtrPtr()) {
+ Operation *varPtrPtrdefOp =
+ memberMapInfoOp.getVarPtrPtr().getDefiningOp();
+ rewriter.setInsertionPoint(cloneModifyAndErase(varPtrPtrdefOp));
+ }
+ }
+
+ // If the type of the private variable is not a pointer,
+ // which is typically the case with !fir.boxchar types, then
+ // we need to ensure that the new private variable is also
+ // not a pointer. Insert a load from heapMem right before
+ // targetOp.
+ if (isPrivatizedByValue) {
+ rewriter.setInsertionPoint(targetOp);
+ auto newPrivVar = rewriter.create<LLVM::LoadOp>(mapInfoOp.getLoc(),
+ varType, heapMem);
+ newPrivVars.push_back(newPrivVar);
+ }
+
+ // Deallocate
+ if (needsCleanupTask) {
+ if (!cleanupTaskOp) {
+ assert(fakeDependVar &&
+ "Need a valid value to set up a dependency");
+ rewriter.setInsertionPointAfter(targetOp);
+ omp::TaskOperands taskOperands;
+ auto inDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependin);
+ taskOperands.dependKinds.push_back(inDepend);
+ taskOperands.dependVars.push_back(fakeDependVar);
+ cleanupTaskOp = omp::TaskOp::create(rewriter, loc, taskOperands);
+ Block *taskBlock = rewriter.createBlock(&cleanupTaskOp.getRegion());
+ rewriter.setInsertionPointToEnd(taskBlock);
+ rewriter.create<omp::TerminatorOp>(cleanupTaskOp.getLoc());
+ }
+ rewriter.setInsertionPointToStart(
+ &*cleanupTaskOp.getRegion().getBlocks().begin());
+ (void)createAlwaysInlineFuncAndCallIt(
+ privatizer.getDeallocRegion(),
+ llvm::formatv("{0}_{1}", privatizer.getSymName(), "dealloc")
+ .str(),
+ {initializedVal}, /*returnsValue=*/false);
+ llvm::FailureOr<LLVM::LLVMFuncOp> freeFunc =
+ LLVM::lookupOrCreateFreeFn(rewriter, mod);
+ assert(llvm::succeeded(freeFunc) &&
+ "Could not find free in the module");
+ (void)rewriter.create<LLVM::CallOp>(loc, freeFunc.value(),
+ ValueRange{heapMem});
+ }
+ }
+ assert(newPrivVars.size() == privateVars.size() &&
+ "The number of private variables must match before and after "
+ "transformation");
+ if (fakeDependVar) {
+ omp::ClauseTaskDependAttr outDepend = omp::ClauseTaskDependAttr::get(
+ rewriter.getContext(), omp::ClauseTaskDepend::taskdependout);
+ SmallVector<Attribute> newDependKinds;
+ if (!targetOp.getDependVars().empty()) {
+ std::optional<ArrayAttr> dependKinds = targetOp.getDependKinds();
+ assert(dependKinds && "bad depend clause in omp::TargetOp");
+ llvm::copy(*dependKinds, std::back_inserter(newDependKinds));
+ }
+ newDependKinds.push_back(outDepend);
+ ArrayAttr newDependKindsAttr =
+ ArrayAttr::get(rewriter.getContext(), newDependKinds);
+ targetOp.getDependVarsMutable().append(fakeDependVar);
+ targetOp.setDependKindsAttr(newDependKindsAttr);
+ }
+ rewriter.setInsertionPoint(targetOp);
+ targetOp.getPrivateVarsMutable().clear();
+ targetOp.getPrivateVarsMutable().assign(newPrivVars);
+ });
+ }
+
+private:
+ bool hasPrivateVars(omp::TargetOp targetOp) const {
+ return !targetOp.getPrivateVars().empty();
+ }
+
+ bool isTargetTaskDeferred(omp::TargetOp targetOp) const {
+ return targetOp.getNowait();
+ }
+
+ template <typename OpTy>
+ omp::PrivateClauseOp findPrivatizer(OpTy op, Attribute privSym) const {
+ SymbolRefAttr privatizerName = llvm::cast<SymbolRefAttr>(privSym);
+ omp::PrivateClauseOp privatizer =
+ SymbolTable::lookupNearestSymbolFrom<omp::PrivateClauseOp>(
+ op, privatizerName);
+ return privatizer;
+ }
+
+ // Get the (compile-time constant) size of varType as per the
+ // given DataLayout dl.
+ std::int64_t getSizeInBytes(const DataLayout &dl, Type varType) const {
+ llvm::TypeSize size = dl.getTypeSize(varType);
+ unsigned short alignment = dl.getTypeABIAlignment(varType);
+ return llvm::alignTo(size, alignment);
+ }
+
+ LLVM::LLVMFuncOp getMalloc(ModuleOp mod, IRRewriter &rewriter) const {
+ llvm::FailureOr<LLVM::LLVMFuncOp> mallocCall =
+ LLVM::lookupOrCreateMallocFn(rewriter, mod, rewriter.getI64Type());
+ assert(llvm::succeeded(mallocCall) &&
+ "Could not find malloc in the module");
+ return mallocCall.value();
+ }
+
+ Value allocateHeapMem(omp::TargetOp targetOp, Value privVar, Type varType,
+ ModuleOp mod, IRRewriter &rewriter) const {
+ OpBuilder::InsertionGuard guard(rewriter);
+ Value varPtr = privVar;
+ Operation *definingOp = varPtr.getDefiningOp();
+ BlockArgument blockArg;
+ if (!definingOp) {
+ blockArg = mlir::dyn_cast<BlockArgument>(varPtr);
+ rewriter.setInsertionPointToStart(blockArg.getParentBlock());
+ } else {
+ rewriter.setInsertionPoint(definingOp);
+ }
+ Location loc = definingOp ? definingOp->getLoc() : blockArg.getLoc();
+ LLVM::LLVMFuncOp mallocFn = getMalloc(mod, rewriter);
+
+ assert(mod.getDataLayoutSpec() &&
+ "MLIR module with no datalayout spec not handled yet");
+
+ const DataLayout &dl = DataLayout(mod);
+ std::int64_t distance = getSizeInBytes(dl, varType);
+
+ Value sizeBytes = rewriter.create<LLVM::ConstantOp>(
+ loc, mallocFn.getFunctionType().getParamType(0), distance);
+
+ auto mallocCallOp =
+ rewriter.create<LLVM::CallOp>(loc, mallocFn, ValueRange{sizeBytes});
+ return mallocCallOp.getResult();
+ }
+
+ // Create a function for srcRegion and attribute it to be always_inline.
+ // The big assumption here is that srcRegion is one of init, copy or dealloc
+ // regions of a omp::PrivateClauseop. Accordingly, the return type is assumed
+ // to either be the same as the types of the two arguments of the region (for
+ // init and copy regions) or void as would be the case for dealloc regions.
+ LLVM::LLVMFuncOp createFuncOpForRegion(Location loc, ModuleOp mod,
+ Region &srcRegion,
+ llvm::StringRef funcName,
+ IRRewriter &rewriter,
+ bool returnsValue = false) {
+
+ OpBuilder::InsertionGuard guard(rewriter);
+ rewriter.setInsertionPoint(mod.getBody(), mod.getBody()->end());
+ Region clonedRegion;
+ IRMapping mapper;
+ srcRegion.cloneInto(&clonedRegion, mapper);
+
+ SmallVector<Type> paramTypes;
+ llvm::copy(srcRegion.getArgumentTypes(), std::back_inserter(paramTypes));
+ Type resultType = returnsValue
+ ? srcRegion.getArgument(0).getType()
+ : LLVM::LLVMVoidType::get(rewriter.getContext());
+ LLVM::LLVMFunctionType funcType =
+ LLVM::LLVMFunctionType::get(resultType, paramTypes);
+
+ LLVM::LLVMFuncOp func =
+ LLVM::LLVMFuncOp::create(rewriter, loc, funcName, funcType);
+ func.setAlwaysInline(true);
+ rewriter.inlineRegionBefore(clonedRegion, func.getRegion(),
+ func.getRegion().end());
+ for (auto &block : func.getRegion().getBlocks()) {
+ if (isa<omp::YieldOp>(block.getTerminator())) {
+ omp::YieldOp yieldOp = cast<omp::YieldOp>(block.getTerminator());
+ rewriter.setInsertionPoint(yieldOp);
+ rewriter.replaceOpWithNewOp<LLVM::ReturnOp>(yieldOp, TypeRange(),
+ yieldOp.getOperands());
+ }
+ }
+ return func;
+ }
+};
+} // namespace