From 9b06e25e73470612d14f0e1e18fde82f62266216 Mon Sep 17 00:00:00 2001 From: Benjamin Maxwell Date: Fri, 9 Aug 2024 10:51:49 +0100 Subject: [mlir][vector] Add mask elimination transform (#99314) This adds a new transform `eliminateVectorMasks()` which aims at removing scalable `vector.create_masks` that will be all-true at runtime. It attempts to do this by simply pattern-matching the mask operands (similar to some canonicalizations), if that does not lead to an answer (is all-true? yes/no), then value bounds analysis will be used to find the lower bound of the unknown operands. If the lower bound is >= to the corresponding mask vector type dim, then that dimension of the mask is all true. Note that the pattern matching prevents expensive value-bounds analysis in cases where the mask won't be all true. For example: ```mlir %mask = vector.create_mask %dynamicValue, %c2 : vector<8x4xi1> ``` From looking at `%c2` we can tell this is not going to be an all-true mask, so we don't need to run the value-bounds analysis for `%dynamicValue` (and can exit the transform early). Note: Eliminating create_masks here means replacing them with all-true constants (which will then lead to the masks folding away). --- .../lib/Dialect/Vector/TestVectorTransforms.cpp | 29 ++++++++++++++++++++++ 1 file changed, 29 insertions(+) (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp') diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 592e24af94d6..29c763b622e8 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -873,6 +873,33 @@ struct TestVectorLinearize final return signalPassFailure(); } }; + +struct TestEliminateVectorMasks + : public PassWrapper> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestEliminateVectorMasks) + + TestEliminateVectorMasks() = default; + TestEliminateVectorMasks(const TestEliminateVectorMasks &pass) + : PassWrapper(pass) {} + + Option vscaleMin{ + *this, "vscale-min", llvm::cl::desc("Minimum possible value of vscale."), + llvm::cl::init(1)}; + Option vscaleMax{ + *this, "vscale-max", llvm::cl::desc("Maximum possible value of vscale."), + llvm::cl::init(16)}; + + StringRef getArgument() const final { return "test-eliminate-vector-masks"; } + StringRef getDescription() const final { + return "Test eliminating vector masks"; + } + void runOnOperation() override { + IRRewriter rewriter(&getContext()); + eliminateVectorMasks(rewriter, getOperation(), + VscaleRange{vscaleMin, vscaleMax}); + } +}; } // namespace namespace mlir { @@ -919,6 +946,8 @@ void registerTestVectorLowerings() { PassRegistration(); PassRegistration(); + + PassRegistration(); } } // namespace test } // namespace mlir -- cgit v1.2.3