diff options
| author | Jakub Kuderski <jakub@nod-labs.com> | 2023-12-12 13:15:17 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-12-12 13:15:17 -0500 |
| commit | 8063622721d0b2b70e44f6e747eec54cdaec2e76 (patch) | |
| tree | c41a28718fed07ff2f66f221e74147c95a2abffa /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 42e4967140e345923a43f809ba69be57200f46ae (diff) | |
[mlir][vector] Allow vector distribution with multiple written elements (#75122)
Add a configuration option to allow vector distribution with multiple
elements written by a single lane.
This is so that we can perform vector multi-reduction with multiple
results per workgroup.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 8 |
1 files changed, 7 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 86b8d5f9b099..e593c0defcd2 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -568,6 +568,11 @@ struct TestVectorDistribution llvm::cl::desc("Test distribution of transfer write"), llvm::cl::init(false)}; + Option<unsigned> maxTransferWriteElements{ + *this, "max-transfer-write-elements", + llvm::cl::desc("Maximum number of transfer write elements to distribute"), + llvm::cl::init(1)}; + Option<bool> hoistUniform{*this, "hoist-uniform", llvm::cl::desc("Test hoist uniform"), llvm::cl::init(false)}; @@ -624,7 +629,8 @@ struct TestVectorDistribution (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } else if (distributeTransferWriteOps) { RewritePatternSet patterns(ctx); - populateDistributeTransferWriteOpPatterns(patterns, distributionFn); + populateDistributeTransferWriteOpPatterns(patterns, distributionFn, + maxTransferWriteElements); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } else if (propagateDistribution) { RewritePatternSet patterns(ctx); |
