summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Transforms
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms')
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp4
-rw-r--r--mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp55
2 files changed, 36 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
index 82ec95d31f52..cc1a22d0d48a 100644
--- a/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/LoopPipelining.cpp
@@ -17,10 +17,10 @@
#include "mlir/Dialect/SCF/Utils/Utils.h"
#include "mlir/IR/IRMapping.h"
#include "mlir/IR/PatternMatch.h"
-#include "mlir/Support/MathExtras.h"
#include "mlir/Transforms/RegionUtils.h"
#include "llvm/ADT/MapVector.h"
#include "llvm/Support/Debug.h"
+#include "llvm/Support/MathExtras.h"
#define DEBUG_TYPE "scf-loop-pipelining"
#define DBGS() (llvm::dbgs() << "[" DEBUG_TYPE "]: ")
@@ -119,7 +119,7 @@ bool LoopPipelinerInternal::initializeLoopInfo(
int64_t ubImm = upperBoundCst.value();
int64_t lbImm = lowerBoundCst.value();
int64_t stepImm = stepCst.value();
- int64_t numIteration = ceilDiv(ubImm - lbImm, stepImm);
+ int64_t numIteration = llvm::divideCeilSigned(ubImm - lbImm, stepImm);
if (numIteration > maxStage) {
dynamicLoop = false;
} else if (!options.supportDynamicLoops) {
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index f3d6b7a53011..35edd490f72e 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
SmallVector<Value> &initTensors = maybeInitTensors.value();
// 3. Define the callback to use for generating the inner most tile loop body.
- Operation *parallelOp = nullptr;
+ SmallVector<Operation *> parallelTiledOps;
auto innerYieldTiledValuesFn =
[&](RewriterBase &rewriter, Location loc, ValueRange ivs,
ValueRange regionIterArgs, SmallVector<Value> &tiledResult,
@@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
}
// 4a. Clone the operation.
- auto clonedOp = cast<PartialReductionOpInterface>(
- cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+ {
+ auto clonedOp = cast<PartialReductionOpInterface>(
+ cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs));
+
+ // 4b. Tile the cloned operation.
+ FailureOr<TilingResult> partialTilingResult =
+ clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets,
+ sizes, reductionDims);
+ if (failed(partialTilingResult)) {
+ return failure();
+ }
+ std::swap(parallelTiledOps, partialTilingResult->tiledOps);
+ std::swap(tiledResult, partialTilingResult->tiledValues);
- // 4b. Tile the cloned operation.
- parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs,
- offsets, sizes, reductionDims);
- // 4c. Delete the cloned operation.
- b.eraseOp(clonedOp);
+ // 4c. Delete the cloned operation.
+ b.eraseOp(clonedOp);
+ }
- tiledResult.append(parallelOp->result_begin(), parallelOp->result_end());
// 4d. Compute the offsets and sizes needed to insert the result of the
// tiled value back into destination before yielding the destination.
- for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) {
+ for (auto result : tiledResult) {
SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0));
resultOffsets.emplace_back(std::move(outOffsets));
SmallVector<OpFoldResult> outSizes;
for (size_t i = 0; i < offsets.size(); i++) {
- outSizes.push_back(
- tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i));
+ outSizes.push_back(tensor::getMixedSize(b, loc, result, i));
}
resultSizes.emplace_back(std::move(outSizes));
}
@@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b,
// 5. Apply the merge reduction to combine all the partial values.
b.setInsertionPointAfter(*loops.begin());
- Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims);
- b.replaceOp(op, mergeOp->getResults());
-
- SCFReductionTilingResult results;
- results.initialValues = initTensors;
- results.loops = loops;
- results.parallelTiledOp = parallelOp;
- results.mergeOp = mergeOp;
- return results;
+ FailureOr<MergeResult> mergeResult =
+ op.mergeReductions(b, loc, replacements, reductionDims);
+ if (failed(mergeResult)) {
+ return failure();
+ }
+ b.replaceOp(op, mergeResult->replacements);
+
+ SCFReductionTilingResult reductionTilingResult;
+ std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps);
+ std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps);
+ std::swap(reductionTilingResult.initialValues, initTensors);
+ std::swap(reductionTilingResult.loops, loops);
+ std::swap(reductionTilingResult.replacements, mergeResult->replacements);
+
+ return reductionTilingResult;
}
//===----------------------------------------------------------------------===//