summaryrefslogtreecommitdiff
path: root/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Lower/OpenMP/ClauseProcessor.cpp')
-rw-r--r--flang/lib/Lower/OpenMP/ClauseProcessor.cpp101
1 files changed, 99 insertions, 2 deletions
diff --git a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
index d289f2fdfab2..f78cd0f9df1a 100644
--- a/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
+++ b/flang/lib/Lower/OpenMP/ClauseProcessor.cpp
@@ -16,6 +16,7 @@
#include "flang/Lower/PFTBuilder.h"
#include "flang/Parser/tools.h"
#include "flang/Semantics/tools.h"
+#include "llvm/Frontend/OpenMP/OMPIRBuilder.h"
namespace Fortran {
namespace lower {
@@ -317,6 +318,20 @@ bool ClauseProcessor::processDeviceType(
return false;
}
+bool ClauseProcessor::processDistSchedule(
+ lower::StatementContext &stmtCtx,
+ mlir::omp::DistScheduleClauseOps &result) const {
+ if (auto *clause = findUniqueClause<omp::clause::DistSchedule>()) {
+ result.distScheduleStaticAttr = converter.getFirOpBuilder().getUnitAttr();
+ const auto &chunkSize = std::get<std::optional<ExprTy>>(clause->t);
+ if (chunkSize)
+ result.distScheduleChunkSizeVar =
+ fir::getBase(converter.genExprValue(*chunkSize, stmtCtx));
+ return true;
+ }
+ return false;
+}
+
bool ClauseProcessor::processFinal(lower::StatementContext &stmtCtx,
mlir::omp::FinalClauseOps &result) const {
const parser::CharBlock *source = nullptr;
@@ -379,6 +394,28 @@ bool ClauseProcessor::processNumThreads(
return false;
}
+bool ClauseProcessor::processOrder(mlir::omp::OrderClauseOps &result) const {
+ using Order = omp::clause::Order;
+ if (auto *clause = findUniqueClause<Order>()) {
+ fir::FirOpBuilder &firOpBuilder = converter.getFirOpBuilder();
+ result.orderAttr = mlir::omp::ClauseOrderKindAttr::get(
+ firOpBuilder.getContext(), mlir::omp::ClauseOrderKind::Concurrent);
+ const auto &modifier =
+ std::get<std::optional<Order::OrderModifier>>(clause->t);
+ if (modifier && *modifier == Order::OrderModifier::Unconstrained) {
+ result.orderModAttr = mlir::omp::OrderModifierAttr::get(
+ firOpBuilder.getContext(), mlir::omp::OrderModifier::unconstrained);
+ } else {
+ // "If order-modifier is not unconstrained, the behavior is as if the
+ // reproducible modifier is present."
+ result.orderModAttr = mlir::omp::OrderModifierAttr::get(
+ firOpBuilder.getContext(), mlir::omp::OrderModifier::reproducible);
+ }
+ return true;
+ }
+ return false;
+}
+
bool ClauseProcessor::processOrdered(
mlir::omp::OrderedClauseOps &result) const {
if (auto *clause = findUniqueClause<omp::clause::Ordered>()) {
@@ -500,6 +537,65 @@ bool ClauseProcessor::processUntied(mlir::omp::UntiedClauseOps &result) const {
//===----------------------------------------------------------------------===//
// ClauseProcessor repeatable clauses
//===----------------------------------------------------------------------===//
+static llvm::StringMap<bool> getTargetFeatures(mlir::ModuleOp module) {
+ llvm::StringMap<bool> featuresMap;
+ llvm::SmallVector<llvm::StringRef> targetFeaturesVec;
+ if (mlir::LLVM::TargetFeaturesAttr features =
+ fir::getTargetFeatures(module)) {
+ llvm::ArrayRef<mlir::StringAttr> featureAttrs = features.getFeatures();
+ for (auto &featureAttr : featureAttrs) {
+ llvm::StringRef featureKeyString = featureAttr.strref();
+ featuresMap[featureKeyString.substr(1)] = (featureKeyString[0] == '+');
+ }
+ }
+ return featuresMap;
+}
+
+static void
+addAlignedClause(lower::AbstractConverter &converter,
+ const omp::clause::Aligned &clause,
+ llvm::SmallVectorImpl<mlir::Value> &alignedVars,
+ llvm::SmallVectorImpl<mlir::Attribute> &alignmentAttrs) {
+ using Aligned = omp::clause::Aligned;
+ lower::StatementContext stmtCtx;
+ mlir::IntegerAttr alignmentValueAttr;
+ int64_t alignment = 0;
+ fir::FirOpBuilder &builder = converter.getFirOpBuilder();
+
+ if (auto &alignmentValueParserExpr =
+ std::get<std::optional<Aligned::Alignment>>(clause.t)) {
+ mlir::Value operand = fir::getBase(
+ converter.genExprValue(*alignmentValueParserExpr, stmtCtx));
+ alignment = *fir::getIntIfConstant(operand);
+ } else {
+ llvm::StringMap<bool> featuresMap = getTargetFeatures(builder.getModule());
+ llvm::Triple triple = fir::getTargetTriple(builder.getModule());
+ alignment =
+ llvm::OpenMPIRBuilder::getOpenMPDefaultSimdAlign(triple, featuresMap);
+ }
+
+ // The default alignment for some targets is equal to 0.
+ // Do not generate alignment assumption if alignment is less than or equal to
+ // 0.
+ if (alignment > 0) {
+ auto &objects = std::get<omp::ObjectList>(clause.t);
+ if (!objects.empty())
+ genObjectList(objects, converter, alignedVars);
+ alignmentValueAttr = builder.getI64IntegerAttr(alignment);
+ // All the list items in a aligned clause will have same alignment
+ for (std::size_t i = 0; i < objects.size(); i++)
+ alignmentAttrs.push_back(alignmentValueAttr);
+ }
+}
+
+bool ClauseProcessor::processAligned(
+ mlir::omp::AlignedClauseOps &result) const {
+ return findRepeatableClause<omp::clause::Aligned>(
+ [&](const omp::clause::Aligned &clause, const parser::CharBlock &) {
+ addAlignedClause(converter, clause, result.alignedVars,
+ result.alignmentAttrs);
+ });
+}
bool ClauseProcessor::processAllocate(
mlir::omp::AllocateClauseOps &result) const {
@@ -655,7 +751,7 @@ createCopyFunc(mlir::Location loc, lower::AbstractConverter &converter,
auto declSrc = builder.create<hlfir::DeclareOp>(
loc, funcOp.getArgument(1), copyFuncName + "_src", shape, typeparams,
/*dummy_scope=*/nullptr, attrs);
- converter.copyVar(loc, declDst.getBase(), declSrc.getBase());
+ converter.copyVar(loc, declDst.getBase(), declSrc.getBase(), varAttrs);
builder.create<mlir::func::ReturnOp>(loc);
return funcOp;
}
@@ -931,7 +1027,8 @@ bool ClauseProcessor::processReduction(
// Copy local lists into the output.
llvm::copy(reductionVars, std::back_inserter(result.reductionVars));
- llvm::copy(reduceVarByRef, std::back_inserter(result.reduceVarByRef));
+ llvm::copy(reduceVarByRef,
+ std::back_inserter(result.reductionVarsByRef));
llvm::copy(reductionDeclSymbols,
std::back_inserter(result.reductionDeclSymbols));