diff options
| author | Hsiangkai Wang <hsiangkai.wang@arm.com> | 2023-12-15 11:35:48 +0000 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2023-12-15 11:35:48 +0000 |
| commit | f643eec892954653b1c9bde42407560caf660b8b (patch) | |
| tree | 5425bd5457615c420fecbd9ebc7ebcea64110aa5 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | ef067f52044042fbe1b6fa21a90bfdbcf1622b02 (diff) | |
[mlir][vector] Add emulation patterns for vector masked load/store (#74834)
In this patch, it will convert
```
vector.maskedload %base[%idx_0, %idx_1], %mask, %pass_thru
```
to
```
%ivalue = %pass_thru
%m = vector.extract %mask[0]
%result0 = scf.if %m {
%v = memref.load %base[%idx_0, %idx_1]
%combined = vector.insert %v, %ivalue[0]
scf.yield %combined
} else {
scf.yield %ivalue
}
%m = vector.extract %mask[1]
%result1 = scf.if %m {
%v = memref.load %base[%idx_0, %idx_1 + 1]
%combined = vector.insert %v, %result0[1]
scf.yield %combined
} else {
scf.yield %result0
}
...
```
It will convert
```
vector.maskedstore %base[%idx_0, %idx_1], %mask, %value
```
to
```
%m = vector.extract %mask[0]
scf.if %m {
%extracted = vector.extract %value[0]
memref.store %extracted, %base[%idx_0, %idx_1]
}
%m = vector.extract %mask[1]
scf.if %m {
%extracted = vector.extract %value[1]
memref.store %extracted, %base[%idx_0, %idx_1 + 1]
}
...
```
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 27 |
1 files changed, 27 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index a643343e9342..03ddebe82344 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -777,6 +777,31 @@ struct TestFoldArithExtensionIntoVectorContractPatterns (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); } }; + +struct TestVectorEmulateMaskedLoadStore final + : public PassWrapper<TestVectorEmulateMaskedLoadStore, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorEmulateMaskedLoadStore) + + StringRef getArgument() const override { + return "test-vector-emulate-masked-load-store"; + } + StringRef getDescription() const override { + return "Test patterns that emulate the maskedload/maskedstore op by " + " memref.load/store and scf.if"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry + .insert<arith::ArithDialect, func::FuncDialect, memref::MemRefDialect, + scf::SCFDialect, vector::VectorDialect>(); + } + + void runOnOperation() override { + RewritePatternSet patterns(&getContext()); + populateVectorMaskedLoadStoreEmulationPatterns(patterns); + (void)applyPatternsAndFoldGreedily(getOperation(), std::move(patterns)); + } +}; } // namespace namespace mlir { @@ -817,6 +842,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorGatherLowering>(); PassRegistration<TestFoldArithExtensionIntoVectorContractPatterns>(); + + PassRegistration<TestVectorEmulateMaskedLoadStore>(); } } // namespace test } // namespace mlir |
