summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorQuinn Dawkins <quinn.dawkins@gmail.com>2023-11-10 05:49:33 -0800
committerGitHub <noreply@github.com>2023-11-10 08:49:33 -0500
commitdf49a97ab2e952d1588d9e7784987d982ebf3365 (patch)
treeac69d5220808128e5be7c751af711774584e175e /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parentd96ea279734fd9105a0ed7bad898ed84d79ed308 (diff)
[mlir][vector] Root the transfer write distribution pattern on the warp op (#71868)
Currently when there is a mix of transfer read ops and transfer write ops that need to be distributed, because the pattern for write distribution is rooted on the transfer write, it is hard to guarantee that the write gets distributed after the read when the two aren't directly connected by SSA. This is likely still relatively unsafe when there are undistributable ops, but structurally these patterns are a bit difficult to work with. For now pattern benefits give fairly good guarantees for happy paths.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp13
1 files changed, 10 insertions, 3 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 2fbf1babf437..1a177fa31de3 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -594,12 +594,19 @@ struct TestVectorDistribution
.getResult(0);
return result;
};
- if (distributeTransferWriteOps) {
+ if (distributeTransferWriteOps && propagateDistribution) {
+ RewritePatternSet patterns(ctx);
+ vector::populatePropagateWarpVectorDistributionPatterns(
+ patterns, distributionFn, shuffleFn, /*benefit=*/1,
+ /*readBenefit=*/0);
+ vector::populateDistributeReduction(patterns, warpReduction, 1);
+ populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ } else if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
- }
- if (propagateDistribution) {
+ } else if (propagateDistribution) {
RewritePatternSet patterns(ctx);
vector::populatePropagateWarpVectorDistributionPatterns(
patterns, distributionFn, shuffleFn);