diff options
| author | Quinn Dawkins <quinn.dawkins@gmail.com> | 2023-11-10 05:49:33 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-11-10 08:49:33 -0500 |
| commit | df49a97ab2e952d1588d9e7784987d982ebf3365 (patch) | |
| tree | ac69d5220808128e5be7c751af711774584e175e /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | d96ea279734fd9105a0ed7bad898ed84d79ed308 (diff) | |
[mlir][vector] Root the transfer write distribution pattern on the warp op (#71868)
Currently when there is a mix of transfer read ops and transfer write
ops that need to be distributed, because the pattern for write
distribution is rooted on the transfer write, it is hard to guarantee
that the write gets distributed after the read when the two aren't
directly connected by SSA. This is likely still relatively unsafe when
there are undistributable ops, but structurally these patterns are a bit
difficult to work with. For now pattern benefits give fairly good
guarantees for happy paths.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 13 |
1 files changed, 10 insertions, 3 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 2fbf1babf437..1a177fa31de3 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -594,12 +594,19 @@ struct TestVectorDistribution .getResult(0); return result; }; - if (distributeTransferWriteOps) { + if (distributeTransferWriteOps && propagateDistribution) { + RewritePatternSet patterns(ctx); + vector::populatePropagateWarpVectorDistributionPatterns( + patterns, distributionFn, shuffleFn, /*benefit=*/1, + /*readBenefit=*/0); + vector::populateDistributeReduction(patterns, warpReduction, 1); + populateDistributeTransferWriteOpPatterns(patterns, distributionFn, 2); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } else if (distributeTransferWriteOps) { RewritePatternSet patterns(ctx); populateDistributeTransferWriteOpPatterns(patterns, distributionFn); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); - } - if (propagateDistribution) { + } else if (propagateDistribution) { RewritePatternSet patterns(ctx); vector::populatePropagateWarpVectorDistributionPatterns( patterns, distributionFn, shuffleFn); |
