summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
diff options
context:
space:
mode:
author: Kunwar Grover <groverkss@gmail.com> 2024-12-27 16:52:34 +0000
committer: GitHub <noreply@github.com> 2024-12-27 16:52:34 +0000
commit: 91bbebc7e118cceae1fc0e349de08094a3cd2fe7 (patch)
tree: feb95f401fdcd308cae8fd5500f66c65fd1a0f27 /mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
parent: 07ba4575250b692b28d0fd5105e028b9f4c8e07f (diff)
[mlir][scf] Add getPartialResultTilePosition to PartialReductionOpInterface (#120465)
This PR adds a new interface method to PartialReductionOpInterface that allows querying the result tile position for the partial result. Previously, tiling the reduction dimension with SplitReductionOuterReduction when the result has transposed parallel dimensions would produce wrong results.

Other fixes that were needed to make this PR work:
- Instead of using ad-hoc logic to decide where to place the new reduction dimensions in the partial result based on the iteration space, the reduction dimensions are now always appended to the partial result tensor.
- Remove the usage of PartialReductionOpInterface in the Mesh dialect. The implementation only needed a neutral element, but it obtained one through PartialReductionOpInterface, which is not right. It was also passing the wrong sizes to it.
Diffstat (limited to 'mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp')
-rw-r--r-- mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp 28
1 files changed, 18 insertions, 10 deletions
diff --git a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
index 2277989bf841..b548f8ce8b56 100644
--- a/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
+++ b/mlir/lib/Dialect/SCF/Transforms/TileUsingInterface.cpp
@@ -657,21 +657,29 @@ getResultTilePosition(RewriterBase &rewriter, int64_t index, Value tiledResult,
resultOffset, resultSize);
case scf::SCFTilingOptions::ReductionTilingStrategy::
PartialReductionOuterReduction: {
- // TODO: This does not work for non identity accesses to the result tile.
- // The proper fix is to add a getPartialResultTilePosition method to
- // PartialReductionOpInterface.
- resultOffset =
- SmallVector<OpFoldResult>(offsets.size(), rewriter.getIndexAttr(0));
- for (size_t i = 0; i < offsets.size(); i++) {
- resultSize.push_back(
- tensor::getMixedSize(rewriter, op.getLoc(), tiledResult, i));
+ auto redOp = dyn_cast<PartialReductionOpInterface>(op.getOperation());
+ if (!redOp) {
+ return rewriter.notifyMatchFailure(
+ op, "PartialReductionOuterReduction tiling strategy is only supported"
+ "for operations implementing PartialReductionOpInterface");
}
- return success();
+ // Get reduction dimensions.
+ // TODO: PartialReductionOpInterface should really query TilingInterface
+ // itself and find reduction dimensions.
+ SmallVector<int> reductionDims;
+ for (auto [idx, iteratorType] :
+ llvm::enumerate(op.getLoopIteratorTypes())) {
+ if (iteratorType == utils::IteratorType::reduction)
+ reductionDims.push_back(idx);
+ }
+ return redOp.getPartialResultTilePosition(rewriter, index, offsets, sizes,
+ resultOffset, resultSize,
+ reductionDims);
+ }
default:
return rewriter.notifyMatchFailure(op,
"unhandled reduction tiling strategy");
}
- }
}
static FailureOr<MergeResult>