summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorHsiangkai Wang <hsiangkai.wang@arm.com>2023-12-15 11:35:48 +0000
committerGitHub <noreply@github.com>2023-12-15 11:35:48 +0000
commitf643eec892954653b1c9bde42407560caf660b8b (patch)
tree5425bd5457615c420fecbd9ebc7ebcea64110aa5 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parentef067f52044042fbe1b6fa21a90bfdbcf1622b02 (diff)
[mlir][vector] Add emulation patterns for vector masked load/store (#74834)
In this patch, it will convert ``` vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru ``` to ``` %ivalue = %pass_thru %m = vector.extract %mask[0] %result0 = scf.if %m { %v = memref.load %base[%idx_0, %idx_1] %combined = vector.insert %v, %ivalue[0] scf.yield %combined } else { scf.yield %ivalue } %m = vector.extract %mask[1] %result1 = scf.if %m { %v = memref.load %base[%idx_0, %idx_1 + 1] %combined = vector.insert %v, %result0[1] scf.yield %combined } else { scf.yield %result0 } ... ``` It will convert ``` vector.maskedstore %base[%idx_0, %idx_1], %mask, %value ``` to ``` %m = vector.extract %mask[0] scf.if %m { %extracted = vector.extract %value[0] memref.store %extracted, %base[%idx_0, %idx_1] } %m = vector.extract %mask[1] scf.if %m { %extracted = vector.extract %value[1] memref.store %extracted, %base[%idx_0, %idx_1 + 1] } ... ```
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp27
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index a643343e9342..03ddebe82344 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -777,6 +777,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns
(void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
}
};
+
+struct TestVectorEmulateMaskedLoadStore final
+ : public PassWrapper<TestVectorEmulateMaskedLoadStore,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore)
+
+ StringRef getArgument() const override {
+ return "test-vector-emulate-masked-load-store";
+ }
+ StringRef getDescription() const override {
+ return "Test patterns that emulate the maskedload/maskedstore op by "
+ " memref.load/store and scf.if";
+ }
+ void getDependentDialects(DialectRegistry &registry) const override {
+ registry
+ .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect,
+ scf::SCFDialect, vector::VectorDialect>();
+ }
+
+ void runOnOperation() override {
+ RewritePatternSet patterns(&getContext());
+ populateVectorMaskedLoadStoreEmulationPatterns(patterns);
+ (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns));
+ }
+};
} // namespace
namespace mlir {
@@ -817,6 +842,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorGatherLowering>();
PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>();
+
+ PassRegistration<TestVectorEmulateMaskedLoadStore>();
}
} // namespace test
} // namespace mlir