summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorBenjamin Maxwell <benjamin.maxwell@arm.com>2024-08-09 10:51:49 +0100
committerGitHub <noreply@github.com>2024-08-09 10:51:49 +0100
commit9b06e25e73470612d14f0e1e18fde82f62266216 (patch)
tree60f6c65270813df53b724a0a91ad79106ee9c564 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parentbadfb4bd33c27f7bb6351ad3d7250a9b8782b43f (diff)
[mlir][vector] Add mask elimination transform (#99314)
This adds a new transform `eliminateVectorMasks()` which aims at removing scalable `vector.create_masks` that will be all-true at runtime. It attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations), if that does not lead to an answer (is all-true? yes/no), then value bounds analysis will be used to find the lower bound of the unknown operands. If the lower bound is >= to the corresponding mask vector type dim, then that dimension of the mask is all true. Note that the pattern matching prevents expensive value-bounds analysis in cases where the mask won't be all true. For example: ```mlir %mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1> ``` From looking at `%c2` we can tell this is not going to be an all-true mask, so we don't need to run the value-bounds analysis for `%dynamicValue` (and can exit the transform early). Note: Eliminating create_masks here means replacing them with all-true constants (which will then lead to the masks folding away).
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp29
1 files changed, 29 insertions, 0 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 592e24af94d6..29c763b622e8 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -873,6 +873,33 @@ struct TestVectorLinearize final
return signalPassFailure();
}
};
+
+struct TestEliminateVectorMasks
+ : public PassWrapper<TestEliminateVectorMasks,
+ OperationPass<func::FuncOp>> {
+ MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks)
+
+ TestEliminateVectorMasks() = default;
+ TestEliminateVectorMasks(const TestEliminateVectorMasks &pass)
+ : PassWrapper(pass) {}
+
+ Option<unsigned> vscaleMin{
+ *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."),
+ llvm::cl::init(1)};
+ Option<unsigned> vscaleMax{
+ *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."),
+ llvm::cl::init(16)};
+
+ StringRef getArgument() const final { return "test-eliminate-vector-masks"; }
+ StringRef getDescription() const final {
+ return "Test eliminating vector masks";
+ }
+ void runOnOperation() override {
+ IRRewriter rewriter(&getContext());
+ eliminateVectorMasks(rewriter, getOperation(),
+ VscaleRange{vscaleMin, vscaleMax});
+ }
+};
} // namespace
namespace mlir {
@@ -919,6 +946,8 @@ void registerTestVectorLowerings() {
PassRegistration<TestVectorEmulateMaskedLoadStore>();
PassRegistration<TestVectorLinearize>();
+
+ PassRegistration<TestEliminateVectorMasks>();
}
} // namespace test
} // namespace mlir