summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorJakub Kuderski <jakub@nod-labs.com>2023-12-12 13:15:17 -0500
committerGitHub <noreply@github.com>2023-12-12 13:15:17 -0500
commit8063622721d0b2b70e44f6e747eec54cdaec2e76 (patch)
treec41a28718fed07ff2f66f221e74147c95a2abffa /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent42e4967140e345923a43f809ba69be57200f46ae (diff)
[mlir][vector] Allow vector distribution with multiple written elements (#75122)
Add a configuration option to allow vector distribution with multiple elements written by a single lane. This is so that we can perform vector multi-reduction with multiple results per workgroup.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp8
1 files changed, 7 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 86b8d5f9b099..e593c0defcd2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -568,6 +568,11 @@ struct TestVectorDistribution
llvm::cl::desc("Test distribution of transfer write"),
llvm::cl::init(false)};
+ Option<unsigned> maxTransferWriteElements{
+ *this, "max-transfer-write-elements",
+ llvm::cl::desc("Maximum number of transfer write elements to distribute"),
+ llvm::cl::init(1)};
+
Option<bool> hoistUniform{*this, "hoist-uniform",
llvm::cl::desc("Test hoist uniform"),
llvm::cl::init(false)};
@@ -624,7 +629,8 @@ struct TestVectorDistribution
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (distributeTransferWriteOps) {
RewritePatternSet patterns(ctx);
- populateDistributeTransferWriteOpPatterns(patterns, distributionFn);
+ populateDistributeTransferWriteOpPatterns(patterns, distributionFn,
+ maxTransferWriteElements);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
} else if (propagateDistribution) {
RewritePatternSet patterns(ctx);