summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorDiego Caballero <diegocaballero@google.com>2024-02-21 09:22:48 -0800
committerGitHub <noreply@github.com>2024-02-21 09:22:48 -0800
commit71441ed1716e6ed3f053dea9c1ceb9cfe2822aea (patch)
treebef7a23b1ea244a8cdb30df8cec4fcb2ac92bf15 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent162fa4dd25d631d0ab7816ec6081bcaff951a23c (diff)
[mlir][Vector] Add vector bitwidth target to xfer op flattening (#81966)
This PR adds an optional bitwidth parameter to the vector xfer op flattening transformation so that the flattening doesn't happen if the trailing dimension of the read/writen vector is larger than this bitwidth (i.e., we are already able to fill at least one vector register with that size).
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp16
1 files changed, 15 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index acd38980514a..178a58e796b2 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -466,21 +466,35 @@ struct TestFlattenVectorTransferPatterns
MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(
TestFlattenVectorTransferPatterns)
+ TestFlattenVectorTransferPatterns() = default;
+ TestFlattenVectorTransferPatterns(
+ const TestFlattenVectorTransferPatterns &pass)
+ : PassWrapper(pass) {}
+
StringRef getArgument() const final {
return "test-vector-transfer-flatten-patterns";
}
+
StringRef getDescription() const final {
return "Test patterns to rewrite contiguous row-major N-dimensional "
"vector.transfer_{read,write} ops into 1D transfers";
}
+
void getDependentDialects(DialectRegistry &registry) const override {
registry.insert<memref::MemRefDialect>();
registry.insert<affine::AffineDialect>();
registry.insert<vector::VectorDialect>();
}
+
+ Option<unsigned> targetVectorBitwidth{
+ *this, "target-vector-bitwidth",
+ llvm::cl::desc(
+ "Minimum vector bitwidth to enable the flattening transformation"),
+ llvm::cl::init(std::numeric_limits<unsigned>::max())};
+
void runOnOperation() override {
RewritePatternSet patterns(&getContext());
- populateFlattenVectorTransferPatterns(patterns);
+ populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth);
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};