summary | refs | log | tree | commit | diff
path: root/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'flang/lib/Optimizer/Transforms/LoopVersioning.cpp')
-rw-r--r-- flang/lib/Optimizer/Transforms/LoopVersioning.cpp 116
1 files changed, 89 insertions, 27 deletions
diff --git a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
index adc39861840a..b534ec160ce2 100644
--- a/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
+++ b/flang/lib/Optimizer/Transforms/LoopVersioning.cpp
@@ -145,11 +145,45 @@ struct ArgsUsageInLoop {
};
} // namespace
-static fir::SequenceType getAsSequenceType(mlir::Value *v) {
- mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v->getType()));
+static fir::SequenceType getAsSequenceType(mlir::Value v) {
+ mlir::Type argTy = fir::unwrapPassByRefType(fir::unwrapRefType(v.getType()));
return mlir::dyn_cast<fir::SequenceType>(argTy);
}
+/// Return the rank and the element size (in bytes) of the given
+/// value \p v. If it is not an array or the element type is not
+/// supported, then return <0, 0>. Only trivial data types
+/// are currently supported.
+/// When \p isArgument is true, \p v is assumed to be a function
+/// argument. If \p v's type does not look like the type of an assumed
+/// shape array (i.e. its first extent is known), the function returns
+/// <0, 0>. When \p isArgument is false, array types with a known
+/// innermost dimension are also allowed to proceed.
+static std::pair<unsigned, size_t>
+getRankAndElementSize(const fir::KindMapping &kindMap,
+ const mlir::DataLayout &dl, mlir::Value v,
+ bool isArgument = false) {
+ if (auto seqTy = getAsSequenceType(v)) {
+ unsigned rank = seqTy.getDimension();
+ if (rank > 0 &&
+ (!isArgument ||
+ seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent())) {
+ size_t typeSize = 0;
+ mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(v.getType());
+ if (fir::isa_trivial(elementType)) {
+ auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
+ v.getLoc(), elementType, dl, kindMap);
+ typeSize = llvm::alignTo(eleSize, eleAlign);
+ }
+ if (typeSize)
+ return {rank, typeSize};
+ }
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "Unsupported rank/type: " << v << '\n');
+ return {0, 0};
+}
+
/// if a value comes from a fir.declare, follow it to the original source,
/// otherwise return the value
static mlir::Value unwrapFirDeclare(mlir::Value val) {
@@ -160,12 +194,48 @@ static mlir::Value unwrapFirDeclare(mlir::Value val) {
return val;
}
+/// Return true if the \p rebox operation keeps the input array
+/// contiguous in the innermost dimension, provided that it is initially
+/// contiguous in the innermost dimension.
+static bool reboxPreservesContinuity(fir::ReboxOp rebox) {
+ // If slicing is not involved, then the rebox does not affect
+ // the contiguity of the array.
+ auto sliceArg = rebox.getSlice();
+ if (!sliceArg)
+ return true;
+
+ // A slice with step=1 in the innermost dimension preserves
+ // the contiguity of the array in the innermost dimension.
+ if (auto sliceOp =
+ mlir::dyn_cast_or_null<fir::SliceOp>(sliceArg.getDefiningOp())) {
+ if (sliceOp.getFields().empty() && sliceOp.getSubstr().empty()) {
+ auto triples = sliceOp.getTriples();
+ if (triples.size() > 2)
+ if (auto innermostStep = fir::getIntIfConstant(triples[2]))
+ if (*innermostStep == 1)
+ return true;
+ }
+
+ LLVM_DEBUG(llvm::dbgs()
+ << "REBOX with slicing may produce non-contiguous array: "
+ << sliceOp << '\n'
+ << rebox << '\n');
+ return false;
+ }
+
+ LLVM_DEBUG(llvm::dbgs() << "REBOX with unknown slice" << sliceArg << '\n'
+ << rebox << '\n');
+ return false;
+}
+
/// if a value comes from a fir.rebox, follow the rebox to the original source,
/// of the value, otherwise return the value
static mlir::Value unwrapReboxOp(mlir::Value val) {
- // don't support reboxes of reboxes
- if (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>())
+ while (fir::ReboxOp rebox = val.getDefiningOp<fir::ReboxOp>()) {
+ if (!reboxPreservesContinuity(rebox))
+ break;
val = rebox.getBox();
+ }
return val;
}
@@ -257,25 +327,10 @@ void LoopVersioningPass::runOnOperation() {
continue;
}
- if (auto seqTy = getAsSequenceType(&arg)) {
- unsigned rank = seqTy.getDimension();
- if (rank > 0 &&
- seqTy.getShape()[0] == fir::SequenceType::getUnknownExtent()) {
- size_t typeSize = 0;
- mlir::Type elementType = fir::unwrapSeqOrBoxedSeqType(arg.getType());
- if (mlir::isa<mlir::FloatType>(elementType) ||
- mlir::isa<mlir::IntegerType>(elementType) ||
- mlir::isa<mlir::ComplexType>(elementType)) {
- auto [eleSize, eleAlign] = fir::getTypeSizeAndAlignmentOrCrash(
- arg.getLoc(), elementType, *dl, kindMap);
- typeSize = llvm::alignTo(eleSize, eleAlign);
- }
- if (typeSize)
- argsOfInterest.push_back({arg, typeSize, rank, {}});
- else
- LLVM_DEBUG(llvm::dbgs() << "Type not supported\n");
- }
- }
+ auto [rank, typeSize] =
+ getRankAndElementSize(kindMap, *dl, arg, /*isArgument=*/true);
+ if (rank != 0 && typeSize != 0)
+ argsOfInterest.push_back({arg, typeSize, rank, {}});
}
if (argsOfInterest.empty()) {
@@ -326,6 +381,13 @@ void LoopVersioningPass::runOnOperation() {
if (arrayCoor.getSlice())
argsInLoop.cannotTransform.insert(a.arg);
+ // We need to compute the rank and element size
+ // based on the operand, not the original argument,
+ // because array slicing may affect them.
+ std::tie(a.rank, a.size) = getRankAndElementSize(kindMap, *dl, a.arg);
+ if (a.rank == 0 || a.size == 0)
+ argsInLoop.cannotTransform.insert(a.arg);
+
if (argsInLoop.cannotTransform.contains(a.arg)) {
// Remove any previously recorded usage, if any.
argsInLoop.usageInfo.erase(a.arg);
@@ -416,8 +478,8 @@ void LoopVersioningPass::runOnOperation() {
mlir::Location loc = builder.getUnknownLoc();
mlir::IndexType idxTy = builder.getIndexType();
- LLVM_DEBUG(llvm::dbgs() << "Module Before transformation:");
- LLVM_DEBUG(module->dump());
+ LLVM_DEBUG(llvm::dbgs() << "Func Before transformation:\n");
+ LLVM_DEBUG(func->dump());
LLVM_DEBUG(llvm::dbgs() << "loopsOfInterest: " << loopsOfInterest.size()
<< "\n");
@@ -551,8 +613,8 @@ void LoopVersioningPass::runOnOperation() {
}
}
- LLVM_DEBUG(llvm::dbgs() << "After transform:\n");
- LLVM_DEBUG(module->dump());
+ LLVM_DEBUG(llvm::dbgs() << "Func After transform:\n");
+ LLVM_DEBUG(func->dump());
LLVM_DEBUG(llvm::dbgs() << "=== End " DEBUG_TYPE " ===\n");
}