diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/Utils/Utils.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/Utils/Utils.cpp | 316 |
1 files changed, 254 insertions, 62 deletions
diff --git a/mlir/lib/Dialect/SCF/Utils/Utils.cpp b/mlir/lib/Dialect/SCF/Utils/Utils.cpp index 6658cca03eba..c0ee9d2afe91 100644 --- a/mlir/lib/Dialect/SCF/Utils/Utils.cpp +++ b/mlir/lib/Dialect/SCF/Utils/Utils.cpp @@ -18,26 +18,23 @@ #include "mlir/Dialect/SCF/IR/SCF.h" #include "mlir/IR/BuiltinOps.h" #include "mlir/IR/IRMapping.h" +#include "mlir/IR/OpDefinition.h" #include "mlir/IR/PatternMatch.h" #include "mlir/Interfaces/SideEffectInterfaces.h" -#include "mlir/Support/MathExtras.h" #include "mlir/Transforms/RegionUtils.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/SetVector.h" #include "llvm/ADT/SmallPtrSet.h" #include "llvm/ADT/SmallVector.h" +#include "llvm/Support/Debug.h" +#include "llvm/Support/MathExtras.h" +#include <cstdint> using namespace mlir; -namespace { -// This structure is to pass and return sets of loop parameters without -// confusing the order. -struct LoopParams { - Value lowerBound; - Value upperBound; - Value step; -}; -} // namespace +#define DEBUG_TYPE "scf-utils" +#define DBGS() (llvm::dbgs() << '[' << DEBUG_TYPE << "] ") +#define LDBG(X) LLVM_DEBUG(DBGS() << X << "\n") SmallVector<scf::ForOp> mlir::replaceLoopNestWithNewYields( RewriterBase &rewriter, MutableArrayRef<scf::ForOp> loopNest, @@ -296,6 +293,25 @@ static Value ceilDivPositive(OpBuilder &builder, Location loc, Value dividend, return builder.create<arith::DivUIOp>(loc, sum, divisor); } +/// Returns the trip count of `forOp` if its' low bound, high bound and step are +/// constants, or optional otherwise. Trip count is computed as ceilDiv(highBound +/// - lowBound, step). +static std::optional<int64_t> getConstantTripCount(scf::ForOp forOp) { + std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound()); + std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound()); + std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep()); + if (!lbCstOp.has_value() || !ubCstOp.has_value() || !stepCstOp.has_value()) + return {}; + + // Constant loop bounds computation. + int64_t lbCst = lbCstOp.value(); + int64_t ubCst = ubCstOp.value(); + int64_t stepCst = stepCstOp.value(); + assert(lbCst >= 0 && ubCst >= 0 && stepCst > 0 && + "expected positive loop bounds and step"); + return llvm::divideCeilSigned(ubCst - lbCst, stepCst); +} + /// Generates unrolled copies of scf::ForOp 'loopBodyBlock', with /// associated 'forOpIV' by 'unrollFactor', calling 'ivRemapFn' to remap /// 'forOpIV' for each unrolled body. If specified, annotates the Ops in each @@ -372,25 +388,21 @@ LogicalResult mlir::loopUnrollByFactor( Value stepUnrolled; bool generateEpilogueLoop = true; - std::optional<int64_t> lbCstOp = getConstantIntValue(forOp.getLowerBound()); - std::optional<int64_t> ubCstOp = getConstantIntValue(forOp.getUpperBound()); - std::optional<int64_t> stepCstOp = getConstantIntValue(forOp.getStep()); - if (lbCstOp && ubCstOp && stepCstOp) { + std::optional<int64_t> constTripCount = getConstantTripCount(forOp); + if (constTripCount) { // Constant loop bounds computation. - int64_t lbCst = lbCstOp.value(); - int64_t ubCst = ubCstOp.value(); - int64_t stepCst = stepCstOp.value(); - assert(lbCst >= 0 && ubCst >= 0 && stepCst >= 0 && - "expected positive loop bounds and step"); - int64_t tripCount = mlir::ceilDiv(ubCst - lbCst, stepCst); - + int64_t lbCst = getConstantIntValue(forOp.getLowerBound()).value(); + int64_t ubCst = getConstantIntValue(forOp.getUpperBound()).value(); + int64_t stepCst = getConstantIntValue(forOp.getStep()).value(); if (unrollFactor == 1) { - if (tripCount == 1 && failed(forOp.promoteIfSingleIteration(rewriter))) + if (*constTripCount == 1 && + failed(forOp.promoteIfSingleIteration(rewriter))) return failure(); return success(); } - int64_t tripCountEvenMultiple = tripCount - (tripCount % unrollFactor); + int64_t tripCountEvenMultiple = + *constTripCount - (*constTripCount % unrollFactor); int64_t upperBoundUnrolledCst = lbCst + tripCountEvenMultiple * stepCst; int64_t stepUnrolledCst = stepCst * unrollFactor; @@ -473,17 +485,188 @@ LogicalResult mlir::loopUnrollByFactor( return success(); } -/// Transform a loop with a strictly positive step -/// for %i = %lb to %ub step %s -/// into a 0-based loop with step 1 -/// for %ii = 0 to ceildiv(%ub - %lb, %s) step 1 { -/// %i = %ii * %s + %lb -/// Insert the induction variable remapping in the body of `inner`, which is -/// expected to be either `loop` or another loop perfectly nested under `loop`. -/// Insert the definition of new bounds immediate before `outer`, which is -/// expected to be either `loop` or its parent in the loop nest. -static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, - Value lb, Value ub, Value step) { +/// Check if bounds of all inner loops are defined outside of `forOp` +/// and return false if not. +static bool areInnerBoundsInvariant(scf::ForOp forOp) { + auto walkResult = forOp.walk([&](scf::ForOp innerForOp) { + if (!forOp.isDefinedOutsideOfLoop(innerForOp.getLowerBound()) || + !forOp.isDefinedOutsideOfLoop(innerForOp.getUpperBound()) || + !forOp.isDefinedOutsideOfLoop(innerForOp.getStep())) + return WalkResult::interrupt(); + + return WalkResult::advance(); + }); + return !walkResult.wasInterrupted(); +} + +/// Unrolls and jams this loop by the specified factor. +LogicalResult mlir::loopUnrollJamByFactor(scf::ForOp forOp, + uint64_t unrollJamFactor) { + assert(unrollJamFactor > 0 && "unroll jam factor should be positive"); + + if (unrollJamFactor == 1) + return success(); + + // If any control operand of any inner loop of `forOp` is defined within + // `forOp`, no unroll jam. + if (!areInnerBoundsInvariant(forOp)) { + LDBG("failed to unroll and jam: inner bounds are not invariant"); + return failure(); + } + + // Currently, for operations with results are not supported. + if (forOp->getNumResults() > 0) { + LDBG("failed to unroll and jam: unsupported loop with results"); + return failure(); + } + + // Currently, only constant trip count that divided by the unroll factor is + // supported. + std::optional<uint64_t> tripCount = getConstantTripCount(forOp); + if (!tripCount.has_value()) { + // If the trip count is dynamic, do not unroll & jam. + LDBG("failed to unroll and jam: trip count could not be determined"); + return failure(); + } + if (unrollJamFactor > *tripCount) { + LDBG("unroll and jam factor is greater than trip count, set factor to trip " + "count"); + unrollJamFactor = *tripCount; + } else if (*tripCount % unrollJamFactor != 0) { + LDBG("failed to unroll and jam: unsupported trip count that is not a " + "multiple of unroll jam factor"); + return failure(); + } + + // Nothing in the loop body other than the terminator. + if (llvm::hasSingleElement(forOp.getBody()->getOperations())) + return success(); + + // Gather all sub-blocks to jam upon the loop being unrolled. + JamBlockGatherer<scf::ForOp> jbg; + jbg.walk(forOp); + auto &subBlocks = jbg.subBlocks; + + // Collect inner loops. + SmallVector<scf::ForOp> innerLoops; + forOp.walk([&](scf::ForOp innerForOp) { innerLoops.push_back(innerForOp); }); + + // `operandMaps[i - 1]` carries old->new operand mapping for the ith unrolled + // iteration. There are (`unrollJamFactor` - 1) iterations. + SmallVector<IRMapping> operandMaps(unrollJamFactor - 1); + + // For any loop with iter_args, replace it with a new loop that has + // `unrollJamFactor` copies of its iterOperands, iter_args and yield + // operands. + SmallVector<scf::ForOp> newInnerLoops; + IRRewriter rewriter(forOp.getContext()); + for (scf::ForOp oldForOp : innerLoops) { + SmallVector<Value> dupIterOperands, dupYieldOperands; + ValueRange oldIterOperands = oldForOp.getInits(); + ValueRange oldIterArgs = oldForOp.getRegionIterArgs(); + ValueRange oldYieldOperands = + cast<scf::YieldOp>(oldForOp.getBody()->getTerminator()).getOperands(); + // Get additional iterOperands, iterArgs, and yield operands. We will + // fix iterOperands and yield operands after cloning of sub-blocks. + for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { + dupIterOperands.append(oldIterOperands.begin(), oldIterOperands.end()); + dupYieldOperands.append(oldYieldOperands.begin(), oldYieldOperands.end()); + } + // Create a new loop with additional iterOperands, iter_args and yield + // operands. This new loop will take the loop body of the original loop. + bool forOpReplaced = oldForOp == forOp; + scf::ForOp newForOp = + cast<scf::ForOp>(*oldForOp.replaceWithAdditionalYields( + rewriter, dupIterOperands, /*replaceInitOperandUsesInLoop=*/false, + [&](OpBuilder &b, Location loc, ArrayRef<BlockArgument> newBbArgs) { + return dupYieldOperands; + })); + newInnerLoops.push_back(newForOp); + // `forOp` has been replaced with a new loop. + if (forOpReplaced) + forOp = newForOp; + // Update `operandMaps` for `newForOp` iterArgs and results. + ValueRange newIterArgs = newForOp.getRegionIterArgs(); + unsigned oldNumIterArgs = oldIterArgs.size(); + ValueRange newResults = newForOp.getResults(); + unsigned oldNumResults = newResults.size() / unrollJamFactor; + assert(oldNumIterArgs == oldNumResults && + "oldNumIterArgs must be the same as oldNumResults"); + for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { + for (unsigned j = 0; j < oldNumIterArgs; ++j) { + // `newForOp` has `unrollJamFactor` - 1 new sets of iterArgs and + // results. Update `operandMaps[i - 1]` to map old iterArgs and results + // to those in the `i`th new set. + operandMaps[i - 1].map(newIterArgs[j], + newIterArgs[i * oldNumIterArgs + j]); + operandMaps[i - 1].map(newResults[j], + newResults[i * oldNumResults + j]); + } + } + } + + // Scale the step of loop being unroll-jammed by the unroll-jam factor. + rewriter.setInsertionPoint(forOp); + int64_t step = forOp.getConstantStep()->getSExtValue(); + auto newStep = rewriter.createOrFold<arith::MulIOp>( + forOp.getLoc(), forOp.getStep(), + rewriter.createOrFold<arith::ConstantOp>( + forOp.getLoc(), rewriter.getIndexAttr(unrollJamFactor))); + forOp.setStep(newStep); + auto forOpIV = forOp.getInductionVar(); + + // Unroll and jam (appends unrollJamFactor - 1 additional copies). + for (unsigned i = unrollJamFactor - 1; i >= 1; --i) { + for (auto &subBlock : subBlocks) { + // Builder to insert unroll-jammed bodies. Insert right at the end of + // sub-block. + OpBuilder builder(subBlock.first->getBlock(), std::next(subBlock.second)); + + // If the induction variable is used, create a remapping to the value for + // this unrolled instance. + if (!forOpIV.use_empty()) { + // iv' = iv + i * step, i = 1 to unrollJamFactor-1. + auto ivTag = builder.createOrFold<arith::ConstantOp>( + forOp.getLoc(), builder.getIndexAttr(step * i)); + auto ivUnroll = + builder.createOrFold<arith::AddIOp>(forOp.getLoc(), forOpIV, ivTag); + operandMaps[i - 1].map(forOpIV, ivUnroll); + } + // Clone the sub-block being unroll-jammed. + for (auto it = subBlock.first; it != std::next(subBlock.second); ++it) + builder.clone(*it, operandMaps[i - 1]); + } + // Fix iterOperands and yield op operands of newly created loops. + for (auto newForOp : newInnerLoops) { + unsigned oldNumIterOperands = + newForOp.getNumRegionIterArgs() / unrollJamFactor; + unsigned numControlOperands = newForOp.getNumControlOperands(); + auto yieldOp = cast<scf::YieldOp>(newForOp.getBody()->getTerminator()); + unsigned oldNumYieldOperands = yieldOp.getNumOperands() / unrollJamFactor; + assert(oldNumIterOperands == oldNumYieldOperands && + "oldNumIterOperands must be the same as oldNumYieldOperands"); + for (unsigned j = 0; j < oldNumIterOperands; ++j) { + // The `i`th duplication of an old iterOperand or yield op operand + // needs to be replaced with a mapped value from `operandMaps[i - 1]` + // if such mapped value exists. + newForOp.setOperand(numControlOperands + i * oldNumIterOperands + j, + operandMaps[i - 1].lookupOrDefault( + newForOp.getOperand(numControlOperands + j))); + yieldOp.setOperand( + i * oldNumYieldOperands + j, + operandMaps[i - 1].lookupOrDefault(yieldOp.getOperand(j))); + } + } + } + + // Promote the loop body up if this has turned into a single iteration loop. + (void)forOp.promoteIfSingleIteration(rewriter); + return success(); +} + +Range mlir::emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, + OpFoldResult lb, OpFoldResult ub, + OpFoldResult step) { // For non-index types, generate `arith` instructions // Check if the loop is already known to have a constant zero lower bound or // a constant one step. @@ -495,32 +678,38 @@ static LoopParams emitNormalizedLoopBounds(RewriterBase &rewriter, Location loc, if (auto stepCst = getConstantIntValue(step)) isStepOne = stepCst.value() == 1; + Type rangeType = getType(lb); + assert(rangeType == getType(ub) && rangeType == getType(step) && + "expected matching types"); + // Compute the number of iterations the loop executes: ceildiv(ub - lb, step) // assuming the step is strictly positive. Update the bounds and the step // of the loop to go from 0 to the number of iterations, if necessary. if (isZeroBased && isStepOne) return {lb, ub, step}; - Value diff = isZeroBased ? ub : rewriter.create<arith::SubIOp>(loc, ub, lb); - Value newUpperBound = - isStepOne ? diff : rewriter.create<arith::CeilDivSIOp>(loc, diff, step); + OpFoldResult diff = ub; + if (!isZeroBased) { + diff = rewriter.createOrFold<arith::SubIOp>( + loc, getValueOrCreateConstantIntOp(rewriter, loc, ub), + getValueOrCreateConstantIntOp(rewriter, loc, lb)); + } + OpFoldResult newUpperBound = diff; + if (!isStepOne) { + newUpperBound = rewriter.createOrFold<arith::CeilDivSIOp>( + loc, getValueOrCreateConstantIntOp(rewriter, loc, diff), + getValueOrCreateConstantIntOp(rewriter, loc, step)); + } - Value newLowerBound = isZeroBased - ? lb - : rewriter.create<arith::ConstantOp>( - loc, rewriter.getZeroAttr(lb.getType())); - Value newStep = isStepOne - ? step - : rewriter.create<arith::ConstantOp>( - loc, rewriter.getIntegerAttr(step.getType(), 1)); + OpFoldResult newLowerBound = rewriter.getZeroAttr(rangeType); + OpFoldResult newStep = rewriter.getOneAttr(rangeType); return {newLowerBound, newUpperBound, newStep}; } -/// Get back the original induction variable values after loop normalization -static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, - Value normalizedIv, Value origLb, - Value origStep) { +void mlir::denormalizeInductionVariable(RewriterBase &rewriter, Location loc, + Value normalizedIv, OpFoldResult origLb, + OpFoldResult origStep) { Value denormalizedIv; SmallPtrSet<Operation *, 2> preserve; bool isStepOne = isConstantIntValue(origStep, 1); @@ -528,12 +717,15 @@ static void denormalizeInductionVariable(RewriterBase &rewriter, Location loc, Value scaled = normalizedIv; if (!isStepOne) { - scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStep); + Value origStepValue = + getValueOrCreateConstantIntOp(rewriter, loc, origStep); + scaled = rewriter.create<arith::MulIOp>(loc, normalizedIv, origStepValue); preserve.insert(scaled.getDefiningOp()); } denormalizedIv = scaled; if (!isZeroBased) { - denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLb); + Value origLbValue = getValueOrCreateConstantIntOp(rewriter, loc, origLb); + denormalizedIv = rewriter.create<arith::AddIOp>(loc, scaled, origLbValue); preserve.insert(denormalizedIv.getDefiningOp()); } @@ -634,15 +826,17 @@ LogicalResult mlir::coalesceLoops(RewriterBase &rewriter, Value lb = loop.getLowerBound(); Value ub = loop.getUpperBound(); Value step = loop.getStep(); - auto newLoopParams = + auto newLoopRange = emitNormalizedLoopBounds(rewriter, loop.getLoc(), lb, ub, step); rewriter.modifyOpInPlace(loop, [&]() { - loop.setLowerBound(newLoopParams.lowerBound); - loop.setUpperBound(newLoopParams.upperBound); - loop.setStep(newLoopParams.step); + loop.setLowerBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(), + newLoopRange.offset)); + loop.setUpperBound(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(), + newLoopRange.size)); + loop.setStep(getValueOrCreateConstantIntOp(rewriter, loop.getLoc(), + newLoopRange.stride)); }); - rewriter.setInsertionPointToStart(innermost.getBody()); denormalizeInductionVariable(rewriter, loop.getLoc(), loop.getInductionVar(), lb, step); @@ -778,18 +972,16 @@ void mlir::collapseParallelLoops( llvm::sort(dims); // Normalize ParallelOp's iteration pattern. - SmallVector<Value, 3> normalizedLowerBounds, normalizedSteps, - normalizedUpperBounds; + SmallVector<Value, 3> normalizedUpperBounds; for (unsigned i = 0, e = loops.getNumLoops(); i < e; ++i) { OpBuilder::InsertionGuard g2(rewriter); rewriter.setInsertionPoint(loops); Value lb = loops.getLowerBound()[i]; Value ub = loops.getUpperBound()[i]; Value step = loops.getStep()[i]; - auto newLoopParams = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step); - normalizedLowerBounds.push_back(newLoopParams.lowerBound); - normalizedUpperBounds.push_back(newLoopParams.upperBound); - normalizedSteps.push_back(newLoopParams.step); + auto newLoopRange = emitNormalizedLoopBounds(rewriter, loc, lb, ub, step); + normalizedUpperBounds.push_back(getValueOrCreateConstantIntOp( + rewriter, loops.getLoc(), newLoopRange.size)); rewriter.setInsertionPointToStart(loops.getBody()); denormalizeInductionVariable(rewriter, loc, loops.getInductionVars()[i], lb, |
