diff options
Diffstat (limited to 'mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp')
| -rw-r--r-- | mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp | 123 |
1 files changed, 96 insertions, 27 deletions
diff --git a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp index 110873011fe3..c0be9e919d2f 100644 --- a/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp +++ b/mlir/lib/Dialect/OpenMP/IR/OpenMPDialect.cpp @@ -48,6 +48,11 @@ static ArrayAttr makeArrayAttr(MLIRContext *context, return attrs.empty() ? nullptr : ArrayAttr::get(context, attrs); } +static DenseBoolArrayAttr +makeDenseBoolArrayAttr(MLIRContext *ctx, const ArrayRef<bool> boolArray) { + return boolArray.empty() ? nullptr : DenseBoolArrayAttr::get(ctx, boolArray); +} + namespace { struct MemRefPointerLikeModel : public PointerLikeType::ExternalModel<MemRefPointerLikeModel, @@ -434,6 +439,45 @@ static void printScheduleClause(OpAsmPrinter &p, Operation *op, } //===----------------------------------------------------------------------===// +// Parser and printer for Order Clause +//===----------------------------------------------------------------------===// + +// order ::= `order` `(` [order-modifier ':'] concurrent `)` +// order-modifier ::= reproducible | unconstrained +static ParseResult parseOrderClause(OpAsmParser &parser, + ClauseOrderKindAttr &kindAttr, + OrderModifierAttr &modifierAttr) { + StringRef enumStr; + SMLoc loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&enumStr)) + return failure(); + if (std::optional<OrderModifier> enumValue = + symbolizeOrderModifier(enumStr)) { + modifierAttr = OrderModifierAttr::get(parser.getContext(), *enumValue); + if (parser.parseOptionalColon()) + return failure(); + loc = parser.getCurrentLocation(); + if (parser.parseKeyword(&enumStr)) + return failure(); + } + if (std::optional<ClauseOrderKind> enumValue = + symbolizeClauseOrderKind(enumStr)) { + kindAttr = ClauseOrderKindAttr::get(parser.getContext(), *enumValue); + return success(); + } + return parser.emitError(loc, "invalid clause value: '") << enumStr << "'"; +} + +static void printOrderClause(OpAsmPrinter &p, Operation *op, + ClauseOrderKindAttr kindAttr, + OrderModifierAttr modifierAttr) { + if (modifierAttr) + p << stringifyOrderModifier(modifierAttr.getValue()) << ":"; + if (kindAttr) + p << stringifyClauseOrderKind(kindAttr.getValue()); +} + +//===----------------------------------------------------------------------===// // Parser, printer and verifier for ReductionVarList //===----------------------------------------------------------------------===// @@ -460,7 +504,7 @@ static ParseResult parseClauseWithRegionArgs( return success(); }))) return failure(); - isByRef = DenseBoolArrayAttr::get(parser.getContext(), isByRefVec); + isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); auto *argsBegin = regionPrivateArgs.begin(); MutableArrayRef argsSubrange(argsBegin + regionArgOffset, @@ -552,7 +596,7 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, mlir::SmallVector<bool> isByRefVec; isByRefVec.resize(privateVarTypes.size(), false); DenseBoolArrayAttr isByRef = - DenseBoolArrayAttr::get(op->getContext(), isByRefVec); + makeDenseBoolArrayAttr(op->getContext(), isByRefVec); printClauseWithRegionArgs(p, op, argsSubrange, "private", privateVarOperands, privateVarTypes, isByRef, @@ -568,18 +612,22 @@ static void printParallelRegion(OpAsmPrinter &p, Operation *op, Region ®ion, static ParseResult parseReductionVarList(OpAsmParser &parser, SmallVectorImpl<OpAsmParser::UnresolvedOperand> &operands, - SmallVectorImpl<Type> &types, + SmallVectorImpl<Type> &types, DenseBoolArrayAttr &isByRef, ArrayAttr &redcuctionSymbols) { SmallVector<SymbolRefAttr> reductionVec; + SmallVector<bool> isByRefVec; if (failed(parser.parseCommaSeparatedList([&]() { + ParseResult optionalByref = parser.parseOptionalKeyword("byref"); if (parser.parseAttribute(reductionVec.emplace_back()) || parser.parseArrow() || parser.parseOperand(operands.emplace_back()) || parser.parseColonType(types.emplace_back())) return failure(); + isByRefVec.push_back(optionalByref.succeeded()); return success(); }))) return failure(); + isByRef = makeDenseBoolArrayAttr(parser.getContext(), isByRefVec); SmallVector<Attribute> reductions(reductionVec.begin(), reductionVec.end()); redcuctionSymbols = ArrayAttr::get(parser.getContext(), reductions); return success(); @@ -589,11 +637,21 @@ parseReductionVarList(OpAsmParser &parser, static void printReductionVarList(OpAsmPrinter &p, Operation *op, OperandRange reductionVars, TypeRange reductionTypes, + std::optional<DenseBoolArrayAttr> isByRef, std::optional<ArrayAttr> reductions) { - for (unsigned i = 0, e = reductions->size(); i < e; ++i) { + auto getByRef = [&](unsigned i) -> const char * { + if (!isByRef || !*isByRef) + return ""; + assert(isByRef->empty() || i < isByRef->size()); + if (!isByRef->empty() && (*isByRef)[i]) + return "byref "; + return ""; + }; + + for (unsigned i = 0, e = reductionVars.size(); i < e; ++i) { if (i != 0) p << ", "; - p << (*reductions)[i] << " -> " << reductionVars[i] << " : " + p << getByRef(i) << (*reductions)[i] << " -> " << reductionVars[i] << " : " << reductionVars[i].getType(); } } @@ -602,16 +660,12 @@ static void printReductionVarList(OpAsmPrinter &p, Operation *op, static LogicalResult verifyReductionVarList(Operation *op, std::optional<ArrayAttr> reductions, OperandRange reductionVars, - std::optional<ArrayRef<bool>> byRef = std::nullopt) { + std::optional<ArrayRef<bool>> byRef) { if (!reductionVars.empty()) { if (!reductions || reductions->size() != reductionVars.size()) return op->emitOpError() << "expected as many reduction symbol references " "as reduction variables"; - if (mlir::isa<omp::WsloopOp, omp::ParallelOp>(op)) - assert(byRef); - else - assert(!byRef); // TODO: support byref reductions on other operations if (byRef && byRef->size() != reductionVars.size()) return op->emitError() << "expected as many reduction variable by " "reference attributes as reduction variables"; @@ -1453,7 +1507,7 @@ void ParallelOp::build(OpBuilder &builder, OperationState &state, ParallelOp::build(builder, state, clauses.ifVar, clauses.numThreadsVar, clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars, - DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef), + makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef), makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.procBindKindAttr, clauses.privateVars, makeArrayAttr(ctx, clauses.privatizers)); @@ -1551,6 +1605,7 @@ void TeamsOp::build(OpBuilder &builder, OperationState &state, clauses.numTeamsUpperVar, clauses.ifVar, clauses.threadLimitVar, clauses.allocateVars, clauses.allocatorVars, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef), makeArrayAttr(ctx, clauses.reductionDeclSymbols)); } @@ -1582,7 +1637,8 @@ LogicalResult TeamsOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - return verifyReductionVarList(*this, getReductions(), getReductionVars()); + return verifyReductionVarList(*this, getReductions(), getReductionVars(), + getReductionVarsByref()); } //===----------------------------------------------------------------------===// @@ -1594,6 +1650,7 @@ void SectionsOp::build(OpBuilder &builder, OperationState &state, MLIRContext *ctx = builder.getContext(); // TODO Store clauses in op: reductionByRefAttr, privateVars, privatizers. SectionsOp::build(builder, state, clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef), makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.allocateVars, clauses.allocatorVars, clauses.nowaitAttr); @@ -1604,7 +1661,8 @@ LogicalResult SectionsOp::verify() { return emitError( "expected equal sizes for allocate and allocator variables"); - return verifyReductionVarList(*this, getReductions(), getReductionVars()); + return verifyReductionVarList(*this, getReductions(), getReductionVars(), + getReductionVarsByref()); } LogicalResult SectionsOp::verifyRegions() { @@ -1682,7 +1740,8 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, /*reductions=*/nullptr, /*schedule_val=*/nullptr, /*schedule_chunk_var=*/nullptr, /*schedule_modifier=*/nullptr, /*simd_modifier=*/false, /*nowait=*/false, - /*ordered_val=*/nullptr, /*order_val=*/nullptr); + /*ordered_val=*/nullptr, /*order_val=*/nullptr, + /*order_modifier=*/nullptr); state.addAttributes(attributes); } @@ -1693,11 +1752,12 @@ void WsloopOp::build(OpBuilder &builder, OperationState &state, // privatizers. WsloopOp::build(builder, state, clauses.linearVars, clauses.linearStepVars, clauses.reductionVars, - DenseBoolArrayAttr::get(ctx, clauses.reduceVarByRef), + makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef), makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.scheduleValAttr, clauses.scheduleChunkVar, clauses.scheduleModAttr, clauses.scheduleSimdAttr, - clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr); + clauses.nowaitAttr, clauses.orderedAttr, clauses.orderAttr, + clauses.orderModAttr); } LogicalResult WsloopOp::verify() { @@ -1726,8 +1786,8 @@ void SimdOp::build(OpBuilder &builder, OperationState &state, // privatizers, reductionDeclSymbols. SimdOp::build(builder, state, clauses.alignedVars, makeArrayAttr(ctx, clauses.alignmentAttrs), clauses.ifVar, - clauses.nontemporalVars, clauses.orderAttr, clauses.simdlenAttr, - clauses.safelenAttr); + clauses.nontemporalVars, clauses.orderAttr, + clauses.orderModAttr, clauses.simdlenAttr, clauses.safelenAttr); } LogicalResult SimdOp::verify() { @@ -1762,7 +1822,8 @@ void DistributeOp::build(OpBuilder &builder, OperationState &state, // TODO Store clauses in op: privateVars, privatizers. DistributeOp::build(builder, state, clauses.distScheduleStaticAttr, clauses.distScheduleChunkSizeVar, clauses.allocateVars, - clauses.allocatorVars, clauses.orderAttr); + clauses.allocatorVars, clauses.orderAttr, + clauses.orderModAttr); } LogicalResult DistributeOp::verify() { @@ -1892,6 +1953,7 @@ void TaskOp::build(OpBuilder &builder, OperationState &state, TaskOp::build( builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef), makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.priorityVar, makeArrayAttr(ctx, clauses.dependTypeAttrs), clauses.dependVars, clauses.allocateVars, clauses.allocatorVars); @@ -1903,7 +1965,8 @@ LogicalResult TaskOp::verify() { return failed(verifyDependVars) ? verifyDependVars : verifyReductionVarList(*this, getInReductions(), - getInReductionVars()); + getInReductionVars(), + getInReductionVarsByref()); } //===----------------------------------------------------------------------===// @@ -1913,14 +1976,17 @@ LogicalResult TaskOp::verify() { void TaskgroupOp::build(OpBuilder &builder, OperationState &state, const TaskgroupClauseOps &clauses) { MLIRContext *ctx = builder.getContext(); - TaskgroupOp::build(builder, state, clauses.taskReductionVars, - makeArrayAttr(ctx, clauses.taskReductionDeclSymbols), - clauses.allocateVars, clauses.allocatorVars); + TaskgroupOp::build( + builder, state, clauses.taskReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.taskReductionVarsByRef), + makeArrayAttr(ctx, clauses.taskReductionDeclSymbols), + clauses.allocateVars, clauses.allocatorVars); } LogicalResult TaskgroupOp::verify() { return verifyReductionVarList(*this, getTaskReductions(), - getTaskReductionVars()); + getTaskReductionVars(), + getTaskReductionVarsByref()); } //===----------------------------------------------------------------------===// @@ -1934,7 +2000,9 @@ void TaskloopOp::build(OpBuilder &builder, OperationState &state, TaskloopOp::build( builder, state, clauses.ifVar, clauses.finalVar, clauses.untiedAttr, clauses.mergeableAttr, clauses.inReductionVars, + makeDenseBoolArrayAttr(ctx, clauses.inReductionVarsByRef), makeArrayAttr(ctx, clauses.inReductionDeclSymbols), clauses.reductionVars, + makeDenseBoolArrayAttr(ctx, clauses.reductionVarsByRef), makeArrayAttr(ctx, clauses.reductionDeclSymbols), clauses.priorityVar, clauses.allocateVars, clauses.allocatorVars, clauses.grainsizeVar, clauses.numTasksVar, clauses.nogroupAttr); @@ -1952,10 +2020,11 @@ LogicalResult TaskloopOp::verify() { if (getAllocateVars().size() != getAllocatorsVars().size()) return emitError( "expected equal sizes for allocate and allocator variables"); - if (failed( - verifyReductionVarList(*this, getReductions(), getReductionVars())) || + if (failed(verifyReductionVarList(*this, getReductions(), getReductionVars(), + getReductionVarsByref())) || failed(verifyReductionVarList(*this, getInReductions(), - getInReductionVars()))) + getInReductionVars(), + getInReductionVarsByref()))) return failure(); if (!getReductionVars().empty() && getNogroup()) |
