diff options
Diffstat (limited to 'flang/lib/Optimizer/Transforms/LoopVersioning.cpp')
| -rw-r--r-- | flang/lib/Optimizer/Transforms/LoopVersioning.cpp | 116 |
1 files changed, 89 insertions, 27 deletions
diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp index adc39861840a..b534ec160ce2 100644 --- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp +++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp @@ -145,11 +145,45 @@ struct ArgsUsageInLoop { }; } // namespace -static fir::SequenceType getAsSequenceType(mlir::Value *v) { - mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType())); +static fir::SequenceType getAsSequenceType(mlir::Value v) { + mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType())); return mlir::dyn_cast<fir::SequenceType>(argTy); } +/// Return the rank and the element size (in bytes) of the given +/// value \p v. If it is not an array or the element type is not +/// supported, then return <0, 0>. Only trivial data types +/// are currently supported. +/// When \p isArgument is true, \p v is assumed to be a function +/// argument. If \p v's type does not look like a type of an assumed +/// shape array, then the function returns <0, 0>. +/// When \p isArgument is false, array types with known innermost +/// dimension are allowed to proceed. +static std::pair<unsigned, size_t> +getRankAndElementSize(const fir::KindMapping &kindMap, + const mlir::DataLayout &dl, mlir::Value v, + bool isArgument = false) { + if (auto seqTy = getAsSequenceType(v)) { + unsigned rank = seqTy.getDimension(); + if (rank > 0 && + (!isArgument || + seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) { + size_t typeSize = 0; + mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType()); + if (fir::isa_trivial(elementType)) { + auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash( + v.getLoc(), elementType, dl, kindMap); + typeSize = llvm::alignTo(eleSize, eleAlign); + } + if (typeSize) + return {rank, typeSize}; + } + } + + LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n'); + return {0, 0}; +} + /// if a value comes from a fir.declare, follow it to the original source, /// otherwise return the value static mlir::Value unwrapFirDeclare(mlir::Value val) { @@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) { return val; } +/// Return true, if \p rebox operation keeps the input array +/// continuous in the innermost dimension, if it is initially continuous +/// in the innermost dimension. +static bool reboxPreservesContinuity(fir::ReboxOp rebox) { + // If slicing is not involved, then the rebox does not affect + // the continuity of the array. + auto sliceArg = rebox.getSlice(); + if (!sliceArg) + return true; + + // A slice with step=1 in the innermost dimension preserves + // the continuity of the array in the innermost dimension. + if (auto sliceOp = + mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) { + if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) { + auto triples = sliceOp.getTriples(); + if (triples.size() > 2) + if (auto innermostStep = fir::getIntIfConstant(triples[2])) + if (*innermostStep == 1) + return true; + } + + LLVM_DEBUG(llvm::dbgs() + << "REBOX with slicing may produce non-contiguous array: " + << sliceOp << '\n' + << rebox << '\n'); + return false; + } + + LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n' + << rebox << '\n'); + return false; +} + /// if a value comes from a fir.rebox, follow the rebox to the original source, /// of the value, otherwise return the value static mlir::Value unwrapReboxOp(mlir::Value val) { - // don't support reboxes of reboxes - if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) + while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) { + if (!reboxPreservesContinuity(rebox)) + break; val = rebox.getBox(); + } return val; } @@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() { continue; } - if (auto seqTy = getAsSequenceType(&arg)) { - unsigned rank = seqTy.getDimension(); - if (rank > 0 && - seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) { - size_t typeSize = 0; - mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType()); - if (mlir::isa<mlir::FloatType>(elementType) || - mlir::isa<mlir::IntegerType>(elementType) || - mlir::isa<mlir::ComplexType>(elementType)) { - auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash( - arg.getLoc(), elementType, *dl, kindMap); - typeSize = llvm::alignTo(eleSize, eleAlign); - } - if (typeSize) - argsOfInterest.push_back({arg, typeSize, rank, {}}); - else - LLVM_DEBUG(llvm::dbgs() << "Type not supported\n"); - } - } + auto [rank, typeSize] = + getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true); + if (rank != 0 && typeSize != 0) + argsOfInterest.push_back({arg, typeSize, rank, {}}); } if (argsOfInterest.empty()) { @@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() { if (arrayCoor.getSlice()) argsInLoop.cannotTransform.insert(a.arg); + // We need to compute the rank and element size + // based on the operand, not the original argument, + // because array slicing may affect it. + std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg); + if (a.rank == 0 || a.size == 0) + argsInLoop.cannotTransform.insert(a.arg); + if (argsInLoop.cannotTransform.contains(a.arg)) { // Remove any previously recorded usage, if any. argsInLoop.usageInfo.erase(a.arg); @@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() { mlir::Location loc = builder.getUnknownLoc(); mlir::IndexType idxTy = builder.getIndexType(); - LLVM_DEBUG(llvm::dbgs() << "Module Before transformation:"); - LLVM_DEBUG(module->dump()); + LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n"); + LLVM_DEBUG(func->dump()); LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size() << "\n"); @@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() { } } - LLVM_DEBUG(llvm::dbgs() << "After transform:\n"); - LLVM_DEBUG(module->dump()); + LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n"); + LLVM_DEBUG(func->dump()); LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n"); } |
