diff options
| author | Ivan Butygin <ivan.butygin@gmail.com> | 2024-02-13 15:30:58 +0300 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-13 15:30:58 +0300 |
| commit | 35ef3994bf738318b59ce640910fb1ccd3bb7dcb (patch) | |
| tree | 1ccb1ff0f15ba0b964a53bbd3c472975eaaefd69 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | bfc0b7c6891896ee8e9818f22800472510093864 (diff) | |
[mlir][vector] ND vectors linearization pass (#81159)
Common backends (LLVM, SPIR-V) only supports 1D vectors, LLVM conversion
handles ND vectors (N >= 2) as `array<array<... vector>>` and SPIR-V
conversion doesn't handle them at all at the moment. Sometimes it's
preferable to treat multidim vectors as linearized 1D. Add pass to do
this. Only constants and simple elementwise ops are supported for now.
@krzysz00 I've extracted yours result type conversion code from
LegalizeToF32 and moved it to common place.
Also, add ConversionPattern class operating on traits.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 126d65b1b848..acd38980514a 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -823,6 +823,33 @@ struct TestVectorEmulateMaskedLoadStore final (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; + +struct TestVectorLinearize final + : public PassWrapper<TestVectorLinearize, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + + StringRef getArgument() const override { return "test-vector-linearize"; } + StringRef getDescription() const override { + return "Linearizes ND vectors for N >= 2 into 1D vectors"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<vector::VectorDialect>(); + } + + void runOnOperation() override { + auto *context = &getContext(); + + TypeConverter typeConverter; + RewritePatternSet patterns(context); + ConversionTarget target(*context); + + vector::populateVectorLinearizeTypeConversionsAndLegality(typeConverter, + patterns, target); + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; } // namespace namespace mlir { @@ -867,6 +894,8 @@ void registerTestVectorLowerings() { PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); PassRegistration<TestVectorEmulateMaskedLoadStore>(); + + PassRegistration<TestVectorLinearize>(); } } // namespace test } // namespace mlir |
