diff options
| author | Benjamin Maxwell <benjamin.maxwell@arm.com> | 2024-08-09 10:51:49 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-08-09 10:51:49 +0100 |
| commit | 9b06e25e73470612d14f0e1e18fde82f62266216 (patch) | |
| tree | 60f6c65270813df53b724a0a91ad79106ee9c564 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | badfb4bd33c27f7bb6351ad3d7250a9b8782b43f (diff) | |
[mlir][vector] Add mask elimination transform (#99314)
This adds a new transform `eliminateVectorMasks()` which aims at
removing scalable `vector.create_masks` that will be all-true at
runtime. It attempts to do this by simply pattern-matching the mask
operands (similar to some canonicalizations), if that does not lead to
an answer (is all-true? yes/no), then value bounds analysis will be used
to find the lower bound of the unknown operands. If the lower bound is
>= to the corresponding mask vector type dim, then that dimension of the
mask is all true.
Note that the pattern matching prevents expensive value-bounds analysis
in cases where the mask won't be all true.
For example:
```mlir
%mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1>
```
From looking at `%c2` we can tell this is not going to be an all-true
mask, so we don't need to run the value-bounds analysis for
`%dynamicValue` (and can exit the transform early).
Note: Eliminating create_masks here means replacing them with all-true
constants (which will then lead to the masks folding away).
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 29 |
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 592e24af94d6..29c763b622e8 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -873,6 +873,33 @@ struct TestVectorLinearize final return signalPassFailure(); } }; + +struct TestEliminateVectorMasks + : public PassWrapper<TestEliminateVectorMasks, + OperationPass<func::FuncOp>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks) + + TestEliminateVectorMasks() = default; + TestEliminateVectorMasks(const TestEliminateVectorMasks &pass) + : PassWrapper(pass) {} + + Option<unsigned> vscaleMin{ + *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."), + llvm::cl::init(1)}; + Option<unsigned> vscaleMax{ + *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."), + llvm::cl::init(16)}; + + StringRef getArgument() const final { return "test-eliminate-vector-masks"; } + StringRef getDescription() const final { + return "Test eliminating vector masks"; + } + void runOnOperation() override { + IRRewriter rewriter(&getContext()); + eliminateVectorMasks(rewriter, getOperation(), + VscaleRange{vscaleMin, vscaleMax}); + } +}; } // namespace namespace mlir { @@ -919,6 +946,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorEmulateMaskedLoadStore>(); PassRegistration<TestVectorLinearize>(); + + PassRegistration<TestEliminateVectorMasks>(); } } // namespace test } // namespace mlir |
