diff options
| author | Diego Caballero <diegocaballero@google.com> | 2024-02-21 09:22:48 -0800 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-21 09:22:48 -0800 |
| commit | 71441ed1716e6ed3f053dea9c1ceb9cfe2822aea (patch) | |
| tree | bef7a23b1ea244a8cdb30df8cec4fcb2ac92bf15 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 162fa4dd25d631d0ab7816ec6081bcaff951a23c (diff) | |
[mlir][Vector] Add vector bitwidth target to xfer op flattening (#81966)
This PR adds an optional bitwidth parameter to the vector xfer op
flattening transformation so that the flattening doesn't happen if the
trailing dimension of the read/writen vector is larger than this
bitwidth (i.e., we are already able to fill at least one vector register
with that size).
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 16 |
1 files changed, 15 insertions, 1 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index acd38980514a..178a58e796b2 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -466,21 +466,35 @@ struct TestFlattenVectorTransferPatterns MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID( TestFlattenVectorTransferPatterns) + TestFlattenVectorTransferPatterns() = default; + TestFlattenVectorTransferPatterns( + const TestFlattenVectorTransferPatterns &pass) + : PassWrapper(pass) {} + StringRef getArgument() const final { return "test-vector-transfer-flatten-patterns"; } + StringRef getDescription() const final { return "Test patterns to rewrite contiguous row-major N-dimensional " "vector.transfer_{read,write} ops into 1D transfers"; } + void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<memref::MemRefDialect>(); registry.insert<affine::AffineDialect>(); registry.insert<vector::VectorDialect>(); } + + Option<unsigned> targetVectorBitwidth{ + *this, "target-vector-bitwidth", + llvm::cl::desc( + "Minimum vector bitwidth to enable the flattening transformation"), + llvm::cl::init(std::numeric_limits<unsigned>::max())}; + void runOnOperation() override { RewritePatternSet patterns(&getContext()); - populateFlattenVectorTransferPatterns(patterns); + populateFlattenVectorTransferPatterns(patterns, targetVectorBitwidth); (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; |
