summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Utils/Utils.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/Utils/Utils.cpp')
-rw-r--r--mlir/lib/Dialect/SCF/Utils/Utils.cpp316
1 files changed, 254 insertions, 62 deletions
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
index 6658cca03eba..c0ee9d2afe91 100644
--- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp
+++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp
@@ -18,26 +18,23 @@
#include "mlir/Dialect/SCF/IR/SCF.h"
#include "mlir/IR/BuiltinOps.h"
#include "mlir/IR/IRMapping.h"
+#include "mlir/IR/OpDefinition.h"
#include "mlir/IR/PatternMatch.h"
#include "mlir/Interfaces/SideEffectInterfaces.h"
-#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/STLExtras.h"
#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/SmallPtrSet.h"
#include "llvm/ADT/SmallVector.h"
+#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
+#include <cstdint>
using namespace mlir;
-namespace {
-// This structure is to pass and return sets of loop parameters without
-// confusing the order.
-struct LoopParams {
- Value lowerBound;
- Value upperBound;
- Value step;
-};
-} // namespace
+#define DEBUG_TYPE "scf-utils"
+#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ")
+#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n")
SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields(
RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest,
@@ -296,6 +293,25 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend,
return builder.create<arith::DivUIOp>(loc, sum, divisor);
}
+/// Returns the trip count of `forOp` if its' low bound, high bound and step are
+/// constants, or optional otherwise. Trip count is computed as ceilDiv(highBound
+/// - lowBound, step).
+static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) {
+ std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
+ std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
+ std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
+ if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value())
+ return {};
+
+ // Constant loop bounds computation.
+ int64_t lbCst = lbCstOp.value();
+ int64_t ubCst = ubCstOp.value();
+ int64_t stepCst = stepCstOp.value();
+ assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 &&
+ "expected positive loop bounds and step");
+ return llvm::divideCeilSigned(ubCst - lbCst, stepCst);
+}
+
/// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with
/// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap
/// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each
@@ -372,25 +388,21 @@ LogicalResult mlir::loopUnrollByFactor(
Value stepUnrolled;
bool generateEpilogueLoop = true;
- std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound());
- std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound());
- std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep());
- if (lbCstOp && ubCstOp && stepCstOp) {
+ std::optional<int64_t> constTripCount = getConstantTripCount(forOp);
+ if (constTripCount) {
// Constant loop bounds computation.
- int64_t lbCst = lbCstOp.value();
- int64_t ubCst = ubCstOp.value();
- int64_t stepCst = stepCstOp.value();
- assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 &&
- "expected positive loop bounds and step");
- int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst);
-
+ int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value();
+ int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value();
+ int64_t stepCst = getConstantIntValue(forOp.getStep()).value();
if (unrollFactor == 1) {
- if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter)))
+ if (*constTripCount == 1 &&
+ failed(forOp.promoteIfSingleIteration(rewriter)))
return failure();
return success();
}
- int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor);
+ int64_t tripCountEvenMultiple =
+ *constTripCount - (*constTripCount % unrollFactor);
int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst;
int64_t stepUnrolledCst = stepCst * unrollFactor;
@@ -473,17 +485,188 @@ LogicalResult mlir::loopUnrollByFactor(
return success();
}
-/// Transform a loop with a strictly positive step
-/// for %i = %lb to %ub step %s
-/// into a 0-based loop with step 1
-/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 {
-/// %i = %ii * %s + %lb
-/// Insert the induction variable remapping in the body of `inner`, which is
-/// expected to be either `loop` or another loop perfectly nested under `loop`.
-/// Insert the definition of new bounds immediate before `outer`, which is
-/// expected to be either `loop` or its parent in the loop nest.
-static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
- Value lb, Value ub, Value step) {
+/// Check if bounds of all inner loops are defined outside of `forOp`
+/// and return false if not.
+static bool areInnerBoundsInvariant(scf::ForOp forOp) {
+ auto walkResult = forOp.walk([&](scf::ForOp innerForOp) {
+ if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) ||
+ !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) ||
+ !forOp.isDefinedOutsideOfLoop(innerForOp.getStep()))
+ return WalkResult::interrupt();
+
+ return WalkResult::advance();
+ });
+ return !walkResult.wasInterrupted();
+}
+
+/// Unrolls and jams this loop by the specified factor.
+LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp,
+ uint64_t unrollJamFactor) {
+ assert(unrollJamFactor > 0 && "unroll jam factor should be positive");
+
+ if (unrollJamFactor == 1)
+ return success();
+
+ // If any control operand of any inner loop of `forOp` is defined within
+ // `forOp`, no unroll jam.
+ if (!areInnerBoundsInvariant(forOp)) {
+ LDBG("failed to unroll and jam: inner bounds are not invariant");
+ return failure();
+ }
+
+ // Currently, for operations with results are not supported.
+ if (forOp->getNumResults() > 0) {
+ LDBG("failed to unroll and jam: unsupported loop with results");
+ return failure();
+ }
+
+ // Currently, only constant trip count that divided by the unroll factor is
+ // supported.
+ std::optional<uint64_t> tripCount = getConstantTripCount(forOp);
+ if (!tripCount.has_value()) {
+ // If the trip count is dynamic, do not unroll & jam.
+ LDBG("failed to unroll and jam: trip count could not be determined");
+ return failure();
+ }
+ if (unrollJamFactor > *tripCount) {
+ LDBG("unroll and jam factor is greater than trip count, set factor to trip "
+ "count");
+ unrollJamFactor = *tripCount;
+ } else if (*tripCount % unrollJamFactor != 0) {
+ LDBG("failed to unroll and jam: unsupported trip count that is not a "
+ "multiple of unroll jam factor");
+ return failure();
+ }
+
+ // Nothing in the loop body other than the terminator.
+ if (llvm::hasSingleElement(forOp.getBody()->getOperations()))
+ return success();
+
+ // Gather all sub-blocks to jam upon the loop being unrolled.
+ JamBlockGatherer<scf::ForOp> jbg;
+ jbg.walk(forOp);
+ auto &subBlocks = jbg.subBlocks;
+
+ // Collect inner loops.
+ SmallVector<scf::ForOp> innerLoops;
+ forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); });
+
+ // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled
+ // iteration. There are (`unrollJamFactor` - 1) iterations.
+ SmallVector<IRMapping> operandMaps(unrollJamFactor - 1);
+
+ // For any loop with iter_args, replace it with a new loop that has
+ // `unrollJamFactor` copies of its iterOperands, iter_args and yield
+ // operands.
+ SmallVector<scf::ForOp> newInnerLoops;
+ IRRewriter rewriter(forOp.getContext());
+ for (scf::ForOp oldForOp : innerLoops) {
+ SmallVector<Value> dupIterOperands, dupYieldOperands;
+ ValueRange oldIterOperands = oldForOp.getInits();
+ ValueRange oldIterArgs = oldForOp.getRegionIterArgs();
+ ValueRange oldYieldOperands =
+ cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands();
+ // Get additional iterOperands, iterArgs, and yield operands. We will
+ // fix iterOperands and yield operands after cloning of sub-blocks.
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end());
+ dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end());
+ }
+ // Create a new loop with additional iterOperands, iter_args and yield
+ // operands. This new loop will take the loop body of the original loop.
+ bool forOpReplaced = oldForOp == forOp;
+ scf::ForOp newForOp =
+ cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields(
+ rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false,
+ [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) {
+ return dupYieldOperands;
+ }));
+ newInnerLoops.push_back(newForOp);
+ // `forOp` has been replaced with a new loop.
+ if (forOpReplaced)
+ forOp = newForOp;
+ // Update `operandMaps` for `newForOp` iterArgs and results.
+ ValueRange newIterArgs = newForOp.getRegionIterArgs();
+ unsigned oldNumIterArgs = oldIterArgs.size();
+ ValueRange newResults = newForOp.getResults();
+ unsigned oldNumResults = newResults.size() / unrollJamFactor;
+ assert(oldNumIterArgs == oldNumResults &&
+ "oldNumIterArgs must be the same as oldNumResults");
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (unsigned j = 0; j < oldNumIterArgs; ++j) {
+ // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and
+ // results. Update `operandMaps[i - 1]` to map old iterArgs and results
+ // to those in the `i`th new set.
+ operandMaps[i - 1].map(newIterArgs[j],
+ newIterArgs[i * oldNumIterArgs + j]);
+ operandMaps[i - 1].map(newResults[j],
+ newResults[i * oldNumResults + j]);
+ }
+ }
+ }
+
+ // Scale the step of loop being unroll-jammed by the unroll-jam factor.
+ rewriter.setInsertionPoint(forOp);
+ int64_t step = forOp.getConstantStep()->getSExtValue();
+ auto newStep = rewriter.createOrFold<arith::MulIOp>(
+ forOp.getLoc(), forOp.getStep(),
+ rewriter.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor)));
+ forOp.setStep(newStep);
+ auto forOpIV = forOp.getInductionVar();
+
+ // Unroll and jam (appends unrollJamFactor - 1 additional copies).
+ for (unsigned i = unrollJamFactor - 1; i >= 1; --i) {
+ for (auto &subBlock : subBlocks) {
+ // Builder to insert unroll-jammed bodies. Insert right at the end of
+ // sub-block.
+ OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second));
+
+ // If the induction variable is used, create a remapping to the value for
+ // this unrolled instance.
+ if (!forOpIV.use_empty()) {
+ // iv' = iv + i * step, i = 1 to unrollJamFactor-1.
+ auto ivTag = builder.createOrFold<arith::ConstantOp>(
+ forOp.getLoc(), builder.getIndexAttr(step * i));
+ auto ivUnroll =
+ builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag);
+ operandMaps[i - 1].map(forOpIV, ivUnroll);
+ }
+ // Clone the sub-block being unroll-jammed.
+ for (auto it = subBlock.first; it != std::next(subBlock.second); ++it)
+ builder.clone(*it, operandMaps[i - 1]);
+ }
+ // Fix iterOperands and yield op operands of newly created loops.
+ for (auto newForOp : newInnerLoops) {
+ unsigned oldNumIterOperands =
+ newForOp.getNumRegionIterArgs() / unrollJamFactor;
+ unsigned numControlOperands = newForOp.getNumControlOperands();
+ auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator());
+ unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor;
+ assert(oldNumIterOperands == oldNumYieldOperands &&
+ "oldNumIterOperands must be the same as oldNumYieldOperands");
+ for (unsigned j = 0; j < oldNumIterOperands; ++j) {
+ // The `i`th duplication of an old iterOperand or yield op operand
+ // needs to be replaced with a mapped value from `operandMaps[i - 1]`
+ // if such mapped value exists.
+ newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j,
+ operandMaps[i - 1].lookupOrDefault(
+ newForOp.getOperand(numControlOperands + j)));
+ yieldOp.setOperand(
+ i * oldNumYieldOperands + j,
+ operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j)));
+ }
+ }
+ }
+
+ // Promote the loop body up if this has turned into a single iteration loop.
+ (void)forOp.promoteIfSingleIteration(rewriter);
+ return success();
+}
+
+Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
+ OpFoldResult lb, OpFoldResult ub,
+ OpFoldResult step) {
// For non-index types, generate `arith` instructions
// Check if the loop is already known to have a constant zero lower bound or
// a constant one step.
@@ -495,32 +678,38 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc,
if (auto stepCst = getConstantIntValue(step))
isStepOne = stepCst.value() == 1;
+ Type rangeType = getType(lb);
+ assert(rangeType == getType(ub) && rangeType == getType(step) &&
+ "expected matching types");
+
// Compute the number of iterations the loop executes: ceildiv(ub - lb, step)
// assuming the step is strictly positive. Update the bounds and the step
// of the loop to go from 0 to the number of iterations, if necessary.
if (isZeroBased && isStepOne)
return {lb, ub, step};
- Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb);
- Value newUpperBound =
- isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step);
+ OpFoldResult diff = ub;
+ if (!isZeroBased) {
+ diff = rewriter.createOrFold<arith::SubIOp>(
+ loc, getValueOrCreateConstantIntOp(rewriter, loc, ub),
+ getValueOrCreateConstantIntOp(rewriter, loc, lb));
+ }
+ OpFoldResult newUpperBound = diff;
+ if (!isStepOne) {
+ newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>(
+ loc, getValueOrCreateConstantIntOp(rewriter, loc, diff),
+ getValueOrCreateConstantIntOp(rewriter, loc, step));
+ }
- Value newLowerBound = isZeroBased
- ? lb
- : rewriter.create<arith::ConstantOp>(
- loc, rewriter.getZeroAttr(lb.getType()));
- Value newStep = isStepOne
- ? step
- : rewriter.create<arith::ConstantOp>(
- loc, rewriter.getIntegerAttr(step.getType(), 1));
+ OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType);
+ OpFoldResult newStep = rewriter.getOneAttr(rangeType);
return {newLowerBound, newUpperBound, newStep};
}
-/// Get back the original induction variable values after loop normalization
-static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
- Value normalizedIv, Value origLb,
- Value origStep) {
+void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
+ Value normalizedIv, OpFoldResult origLb,
+ OpFoldResult origStep) {
Value denormalizedIv;
SmallPtrSet<Operation *, 2> preserve;
bool isStepOne = isConstantIntValue(origStep, 1);
@@ -528,12 +717,15 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc,
Value scaled = normalizedIv;
if (!isStepOne) {
- scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep);
+ Value origStepValue =
+ getValueOrCreateConstantIntOp(rewriter, loc, origStep);
+ scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue);
preserve.insert(scaled.getDefiningOp());
}
denormalizedIv = scaled;
if (!isZeroBased) {
- denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb);
+ Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb);
+ denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue);
preserve.insert(denormalizedIv.getDefiningOp());
}
@@ -634,15 +826,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter,
Value lb = loop.getLowerBound();
Value ub = loop.getUpperBound();
Value step = loop.getStep();
- auto newLoopParams =
+ auto newLoopRange =
emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step);
rewriter.modifyOpInPlace(loop, [&]() {
- loop.setLowerBound(newLoopParams.lowerBound);
- loop.setUpperBound(newLoopParams.upperBound);
- loop.setStep(newLoopParams.step);
+ loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+ newLoopRange.offset));
+ loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+ newLoopRange.size));
+ loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(),
+ newLoopRange.stride));
});
-
rewriter.setInsertionPointToStart(innermost.getBody());
denormalizeInductionVariable(rewriter, loop.getLoc(),
loop.getInductionVar(), lb, step);
@@ -778,18 +972,16 @@ void mlir::collapseParallelLoops(
llvm::sort(dims);
// Normalize ParallelOp's iteration pattern.
- SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps,
- normalizedUpperBounds;
+ SmallVector<Value, 3> normalizedUpperBounds;
for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) {
OpBuilder::InsertionGuard g2(rewriter);
rewriter.setInsertionPoint(loops);
Value lb = loops.getLowerBound()[i];
Value ub = loops.getUpperBound()[i];
Value step = loops.getStep()[i];
- auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
- normalizedLowerBounds.push_back(newLoopParams.lowerBound);
- normalizedUpperBounds.push_back(newLoopParams.upperBound);
- normalizedSteps.push_back(newLoopParams.step);
+ auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step);
+ normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp(
+ rewriter, loops.getLoc(), newLoopRange.size));
rewriter.setInsertionPointToStart(loops.getBody());
denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb,