diff options
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp')
| -rw-r--r-- | mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp | 55 |
1 files changed, 34 insertions, 21 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp index f3d6b7a53011..35edd490f72e 100644 --- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp +++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp @@ -718,7 +718,7 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, SmallVector<Value> &initTensors = maybeInitTensors.value(); // 3. Define the callback to use for generating the inner most tile loop body. - Operation *parallelOp = nullptr; + SmallVector<Operation *> parallelTiledOps; auto innerYieldTiledValuesFn = [&](RewriterBase &rewriter, Location loc, ValueRange ivs, ValueRange regionIterArgs, SmallVector<Value> &tiledResult, @@ -743,26 +743,33 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, } // 4a. Clone the operation. - auto clonedOp = cast<PartialReductionOpInterface>( - cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + { + auto clonedOp = cast<PartialReductionOpInterface>( + cloneOpAndUpdateDestinationArgs(b, op, regionIterArgs)); + + // 4b. Tile the cloned operation. + FailureOr<TilingResult> partialTilingResult = + clonedOp.tileToPartialReduction(b, loc, regionIterArgs, offsets, + sizes, reductionDims); + if (failed(partialTilingResult)) { + return failure(); + } + std::swap(parallelTiledOps, partialTilingResult->tiledOps); + std::swap(tiledResult, partialTilingResult->tiledValues); - // 4b. Tile the cloned operation. - parallelOp = clonedOp.tileToPartialReduction(b, loc, regionIterArgs, - offsets, sizes, reductionDims); - // 4c. Delete the cloned operation. - b.eraseOp(clonedOp); + // 4c. Delete the cloned operation. + b.eraseOp(clonedOp); + } - tiledResult.append(parallelOp->result_begin(), parallelOp->result_end()); // 4d. Compute the offsets and sizes needed to insert the result of the // tiled value back into destination before yielding the destination. - for (int resultIdx : llvm::seq<int>(0, parallelOp->getNumResults())) { + for (auto result : tiledResult) { SmallVector<OpFoldResult> outOffsets(offsets.size(), b.getIndexAttr(0)); resultOffsets.emplace_back(std::move(outOffsets)); SmallVector<OpFoldResult> outSizes; for (size_t i = 0; i < offsets.size(); i++) { - outSizes.push_back( - tensor::getMixedSize(b, loc, parallelOp->getResult(resultIdx), i)); + outSizes.push_back(tensor::getMixedSize(b, loc, result, i)); } resultSizes.emplace_back(std::move(outSizes)); } @@ -782,15 +789,21 @@ mlir::scf::tileReductionUsingScf(RewriterBase &b, // 5. Apply the merge reduction to combine all the partial values. b.setInsertionPointAfter(*loops.begin()); - Operation *mergeOp = op.mergeReductions(b, loc, replacements, reductionDims); - b.replaceOp(op, mergeOp->getResults()); - - SCFReductionTilingResult results; - results.initialValues = initTensors; - results.loops = loops; - results.parallelTiledOp = parallelOp; - results.mergeOp = mergeOp; - return results; + FailureOr<MergeResult> mergeResult = + op.mergeReductions(b, loc, replacements, reductionDims); + if (failed(mergeResult)) { + return failure(); + } + b.replaceOp(op, mergeResult->replacements); + + SCFReductionTilingResult reductionTilingResult; + std::swap(reductionTilingResult.parallelTiledOps, parallelTiledOps); + std::swap(reductionTilingResult.mergeOps, mergeResult->mergeOps); + std::swap(reductionTilingResult.initialValues, initTensors); + std::swap(reductionTilingResult.loops, loops); + std::swap(reductionTilingResult.replacements, mergeResult->replacements); + + return reductionTilingResult; } //===----------------------------------------------------------------------===// |
