summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp')
-rw-r--r--mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp123
1 files changed, 96 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
index 110873011fe3..c0be9e919d2f 100644
--- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
+++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp
@@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context,
return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs);
}
+static DenseBoolArrayAttr
+makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) {
+ return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray);
+}
+
namespace {
struct MemRefPointerLikeModel
: public PointerLikeType::ExternalModel<MemRefPointerLikeModel,
@@ -434,6 +439,45 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op,
}
//===----------------------------------------------------------------------===//
+// Parser and printer for Order Clause
+//===----------------------------------------------------------------------===//
+
+// order ::= `order` `(` [order-modifier ':'] concurrent `)`
+// order-modifier ::= reproducible | unconstrained
+static ParseResult parseOrderClause(OpAsmParser &parser,
+ ClauseOrderKindAttr &kindAttr,
+ OrderModifierAttr &modifierAttr) {
+ StringRef enumStr;
+ SMLoc loc = parser.getCurrentLocation();
+ if (parser.parseKeyword(&enumStr))
+ return failure();
+ if (std::optional<OrderModifier> enumValue =
+ symbolizeOrderModifier(enumStr)) {
+ modifierAttr = OrderModifierAttr::get(parser.getContext(), *enumValue);
+ if (parser.parseOptionalColon())
+ return failure();
+ loc = parser.getCurrentLocation();
+ if (parser.parseKeyword(&enumStr))
+ return failure();
+ }
+ if (std::optional<ClauseOrderKind> enumValue =
+ symbolizeClauseOrderKind(enumStr)) {
+ kindAttr = ClauseOrderKindAttr::get(parser.getContext(), *enumValue);
+ return success();
+ }
+ return parser.emitError(loc, "invalid clause value: '") << enumStr << "'";
+}
+
+static void printOrderClause(OpAsmPrinter &p, Operation *op,
+ ClauseOrderKindAttr kindAttr,
+ OrderModifierAttr modifierAttr) {
+ if (modifierAttr)
+ p << stringifyOrderModifier(modifierAttr.getValue()) << ":";
+ if (kindAttr)
+ p << stringifyClauseOrderKind(kindAttr.getValue());
+}
+
+//===----------------------------------------------------------------------===//
// Parser, printer and verifier for ReductionVarList
//===----------------------------------------------------------------------===//
@@ -460,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs(
return success();
})))
return failure();
- isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec);
+ isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
auto *argsBegin = regionPrivateArgs.begin();
MutableArrayRef argsSubrange(argsBegin + regionArgOffset,
@@ -552,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
mlir::SmallVector<bool> isByRefVec;
isByRefVec.resize(privateVarTypes.size(), false);
DenseBoolArrayAttr isByRef =
- DenseBoolArrayAttr::get(op->getContext(), isByRefVec);
+ makeDenseBoolArrayAttr(op->getContext(), isByRefVec);
printClauseWithRegionArgs(p, op, argsSubrange, "private",
privateVarOperands, privateVarTypes, isByRef,
@@ -568,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region &region,
static ParseResult
parseReductionVarList(OpAsmParser &parser,
SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands,
- SmallVectorImpl<Type> &types,
+ SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef,
ArrayAttr &redcuctionSymbols) {
SmallVector<SymbolRefAttr> reductionVec;
+ SmallVector<bool> isByRefVec;
if (failed(parser.parseCommaSeparatedList([&]() {
+ ParseResult optionalByref = parser.parseOptionalKeyword("byref");
if (parser.parseAttribute(reductionVec.emplace_back()) ||
parser.parseArrow() ||
parser.parseOperand(operands.emplace_back()) ||
parser.parseColonType(types.emplace_back()))
return failure();
+ isByRefVec.push_back(optionalByref.succeeded());
return success();
})))
return failure();
+ isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec);
SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end());
redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions);
return success();
@@ -589,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser,
static void printReductionVarList(OpAsmPrinter &p, Operation *op,
OperandRange reductionVars,
TypeRange reductionTypes,
+ std::optional<DenseBoolArrayAttr> isByRef,
std::optional<ArrayAttr> reductions) {
- for (unsigned i = 0, e = reductions->size(); i < e; ++i) {
+ auto getByRef = [&](unsigned i) -> const char * {
+ if (!isByRef || !*isByRef)
+ return "";
+ assert(isByRef->empty() || i < isByRef->size());
+ if (!isByRef->empty() && (*isByRef)[i])
+ return "byref ";
+ return "";
+ };
+
+ for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) {
if (i != 0)
p << ", ";
- p << (*reductions)[i] << " -> " << reductionVars[i] << " : "
+ p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : "
<< reductionVars[i].getType();
}
}
@@ -602,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op,
static LogicalResult
verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions,
OperandRange reductionVars,
- std::optional<ArrayRef<bool>> byRef = std::nullopt) {
+ std::optional<ArrayRef<bool>> byRef) {
if (!reductionVars.empty()) {
if (!reductions || reductions->size() != reductionVars.size())
return op->emitOpError()
<< "expected as many reduction symbol references "
"as reduction variables";
- if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op))
- assert(byRef);
- else
- assert(!byRef); // TODO: support byref reductions on other operations
if (byRef && byRef->size() != reductionVars.size())
return op->emitError() << "expected as many reduction variable by "
"reference attributes as reduction variables";
@@ -1453,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state,
ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar,
clauses.allocateVars, clauses.allocatorVars,
clauses.reductionVars,
- DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.procBindKindAttr, clauses.privateVars,
makeArrayAttr(ctx, clauses.privatizers));
@@ -1551,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state,
clauses.numTeamsUpperVar, clauses.ifVar,
clauses.threadLimitVar, clauses.allocateVars,
clauses.allocatorVars, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols));
}
@@ -1582,7 +1637,8 @@ LogicalResult TeamsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- return verifyReductionVarList(*this, getReductions(), getReductionVars());
+ return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1594,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state,
MLIRContext *ctx = builder.getContext();
// TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers.
SectionsOp::build(builder, state, clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.allocateVars, clauses.allocatorVars,
clauses.nowaitAttr);
@@ -1604,7 +1661,8 @@ LogicalResult SectionsOp::verify() {
return emitError(
"expected equal sizes for allocate and allocator variables");
- return verifyReductionVarList(*this, getReductions(), getReductionVars());
+ return verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref());
}
LogicalResult SectionsOp::verifyRegions() {
@@ -1682,7 +1740,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
/*reductions=*/nullptr, /*schedule_val=*/nullptr,
/*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr,
/*simd_modifier=*/false, /*nowait=*/false,
- /*ordered_val=*/nullptr, /*order_val=*/nullptr);
+ /*ordered_val=*/nullptr, /*order_val=*/nullptr,
+ /*order_modifier=*/nullptr);
state.addAttributes(attributes);
}
@@ -1693,11 +1752,12 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state,
// privatizers.
WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars,
clauses.reductionVars,
- DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef),
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols),
clauses.scheduleValAttr, clauses.scheduleChunkVar,
clauses.scheduleModAttr, clauses.scheduleSimdAttr,
- clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr);
+ clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr,
+ clauses.orderModAttr);
}
LogicalResult WsloopOp::verify() {
@@ -1726,8 +1786,8 @@ void SimdOp::build(OpBuilder &builder, OperationState &state,
// privatizers, reductionDeclSymbols.
SimdOp::build(builder, state, clauses.alignedVars,
makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar,
- clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr,
- clauses.safelenAttr);
+ clauses.nontemporalVars, clauses.orderAttr,
+ clauses.orderModAttr, clauses.simdlenAttr, clauses.safelenAttr);
}
LogicalResult SimdOp::verify() {
@@ -1762,7 +1822,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state,
// TODO Store clauses in op: privateVars, privatizers.
DistributeOp::build(builder, state, clauses.distScheduleStaticAttr,
clauses.distScheduleChunkSizeVar, clauses.allocateVars,
- clauses.allocatorVars, clauses.orderAttr);
+ clauses.allocatorVars, clauses.orderAttr,
+ clauses.orderModAttr);
}
LogicalResult DistributeOp::verify() {
@@ -1892,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state,
TaskOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar,
makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars,
clauses.allocateVars, clauses.allocatorVars);
@@ -1903,7 +1965,8 @@ LogicalResult TaskOp::verify() {
return failed(verifyDependVars)
? verifyDependVars
: verifyReductionVarList(*this, getInReductions(),
- getInReductionVars());
+ getInReductionVars(),
+ getInReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1913,14 +1976,17 @@ LogicalResult TaskOp::verify() {
void TaskgroupOp::build(OpBuilder &builder, OperationState &state,
const TaskgroupClauseOps &clauses) {
MLIRContext *ctx = builder.getContext();
- TaskgroupOp::build(builder, state, clauses.taskReductionVars,
- makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
- clauses.allocateVars, clauses.allocatorVars);
+ TaskgroupOp::build(
+ builder, state, clauses.taskReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef),
+ makeArrayAttr(ctx, clauses.taskReductionDeclSymbols),
+ clauses.allocateVars, clauses.allocatorVars);
}
LogicalResult TaskgroupOp::verify() {
return verifyReductionVarList(*this, getTaskReductions(),
- getTaskReductionVars());
+ getTaskReductionVars(),
+ getTaskReductionVarsByref());
}
//===----------------------------------------------------------------------===//
@@ -1934,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state,
TaskloopOp::build(
builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr,
clauses.mergeableAttr, clauses.inReductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef),
makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars,
+ makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef),
makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar,
clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar,
clauses.numTasksVar, clauses.nogroupAttr);
@@ -1952,10 +2020,11 @@ LogicalResult TaskloopOp::verify() {
if (getAllocateVars().size() != getAllocatorsVars().size())
return emitError(
"expected equal sizes for allocate and allocator variables");
- if (failed(
- verifyReductionVarList(*this, getReductions(), getReductionVars())) ||
+ if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(),
+ getReductionVarsByref())) ||
failed(verifyReductionVarList(*this, getInReductions(),
- getInReductionVars())))
+ getInReductionVars(),
+ getInReductionVarsByref())))
return failure();
if (!getReductionVars().empty() && getNogroup())