summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorIvan Butygin <ivan.butygin@gmail.com>2024-02-13 15:30:58 +0300
committerGitHub <noreply@github.com>2024-02-13 15:30:58 +0300
commit35ef3994bf738318b59ce640910fb1ccd3bb7dcb (patch)
tree1ccb1ff0f15ba0b964a53bbd3c472975eaaefd69 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parentbfc0b7c6891896ee8e9818f22800472510093864 (diff)
[mlir][vector] ND vectors linearization pass (#81159)
Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion handles ND vectors (N >= 2) as `array<array<... vector>>` and SPIR-V conversion doesn't handle them at all at the moment. Sometimes it's preferable to treat multidim vectors as linearized 1D. Add pass to do this. Only constants and simple elementwise ops are supported for now. @krzysz00 I've extracted yours result type conversion code from LegalizeToF32 and moved it to common place. Also, add ConversionPattern class operating on traits.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 126d65b1b848..acd38980514a 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+
+struct TestVectorLinearize final
+ : public PassWrapper<TestVectorLinearize, OperationPass<>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize)
+
+ StringRef getArgument() const override { return "test-vector-linearize"; }
+ StringRef getDescription() const override {
+ return "Linearizes ND vectors for N >= 2 into 1D vectors";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry.insert<vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ auto *context = &getContext();
+
+ TypeConverter typeConverter;
+ RewritePatternSet patterns(context);
+ ConversionTarget target(*context);
+
+ vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter,
+ patterns, target);
+ if (failed(applyPartialConversion(getOperation(), target,
+ std::move(patterns))))
+ return signalPassFailure();
+ }
+};
} // namespace
namespace mlir {
@@ -867,6 +894,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
PassRegistration<TestVectorEmulateMaskedLoadStore>();
+
+ PassRegistration<TestVectorLinearize>();
}
} // namespace test
} // namespace mlir