diff options
| author | James Newling <james.newling@gmail.com> | 2025-04-30 09:05:40 -0700 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-04-30 09:05:40 -0700 |
| commit | bad8bf56d3e4f107423b307f5f75564296703a76 (patch) | |
| tree | ef27278550fffb1c23222f3bce5e3fc2fbfb9ad3 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 0e9fb5202ce2c1e3cc43436dd4c4b8cd57fa1cef (diff) | |
[mlir][vector] Linearization: push 'bit width' logic out of patterns (#136581)
[NFC]
Vector linearization is a collection of rewrite patterns that reduce the
rank of vector operands and results.
In https://github.com/llvm/llvm-project/pull/83314 an option to ignore
(make 'legal') operations with large inner-most dimensions was added.
This current PR is a step towards making that option live outside of
upstream MLIR. The motivation is to remove non-core functionality (I
would like to use this pass, but would prefer not to deal with
'targetVectorBitWidth` at all).
As a follow-up to this PR, I propose that user(s) of the
`targetVectorBitWidth` move the relevant code (now in
mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp) to their code
bases, and then eventually remove it from upstream. In addition the tests need to
split out (I've intentionally not modified the lit tests here, to make
it easier to confirm that this is a NFC). I'm happy to help make it
easier to do this final step!
The approach I've used is to move the logic pertaining to
`targetVectorBitWidth` out the patterns, and into the conversion target,
which the end user can control outside of core MLIR.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 149 |
1 files changed, 134 insertions, 15 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 03f907e46c2c..eda2594fbc7c 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -7,17 +7,13 @@ //===----------------------------------------------------------------------===// #include <optional> -#include <type_traits> #include "mlir/Analysis/SliceAnalysis.h" #include "mlir/Dialect/Affine/IR/AffineOps.h" #include "mlir/Dialect/Arith/IR/Arith.h" #include "mlir/Dialect/Func/IR/FuncOps.h" #include "mlir/Dialect/GPU/IR/GPUDialect.h" -#include "mlir/Dialect/LLVMIR/LLVMDialect.h" -#include "mlir/Dialect/Linalg/IR/Linalg.h" #include "mlir/Dialect/Linalg/Passes.h" -#include "mlir/Dialect/Linalg/Transforms/Transforms.h" #include "mlir/Dialect/MemRef/IR/MemRef.h" #include "mlir/Dialect/NVGPU/IR/NVGPUDialect.h" #include "mlir/Dialect/SCF/IR/SCF.h" @@ -840,16 +836,98 @@ struct TestVectorEmulateMaskedLoadStore final } }; -struct TestVectorLinearize final - : public PassWrapper<TestVectorLinearize, OperationPass<>> { - MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) +// TODO: move this code into the user project. +namespace vendor { - TestVectorLinearize() = default; - TestVectorLinearize(const TestVectorLinearize &pass) : PassWrapper(pass) {} +/// Get the set of operand/result types to check for sufficiently +/// small inner-most dimension size. +static SmallVector<std::pair<Type, unsigned>> +getTypeBitWidthBoundPairs(Operation *op, unsigned targetBitWidth) { - StringRef getArgument() const override { return "test-vector-linearize"; } + if (auto insertOp = dyn_cast<vector::InsertOp>(op)) { + unsigned w = targetBitWidth < std::numeric_limits<unsigned>::max() + ? targetBitWidth + 1 + : targetBitWidth; + return {{insertOp.getValueToStoreType(), w}}; + } + + auto resultTypes = op->getResultTypes(); + SmallVector<std::pair<Type, unsigned>> resultsWithBitWidth; + resultsWithBitWidth.reserve(resultTypes.size()); + for (Type type : resultTypes) { + resultsWithBitWidth.push_back({type, targetBitWidth}); + } + return resultsWithBitWidth; +} + +/// If `type` is VectorType with trailing dimension of (bit) size greater than +/// or equal to `targetBitWidth`, its defining op is considered legal. +static bool +isNotLinearizableBecauseLargeInnerDimension(Type type, + unsigned targetBitWidth) { + + VectorType vecType = dyn_cast<VectorType>(type); + + // Not linearizable for reasons other than what this function checks. + if (!vecType || vecType.getRank() == 0) + return false; + + // The width of the type 'index' is unbounded (and therefore potentially above + // the target width). + if (vecType.getElementType().isIndex()) + return true; + + unsigned finalDimSize = vecType.getShape().back(); + unsigned nbBitsPerElm = vecType.getElementTypeBitWidth(); + unsigned trailingVecDimBitWidth = finalDimSize * nbBitsPerElm; + return trailingVecDimBitWidth >= targetBitWidth; +} + +static bool +isNotLinearizableBecauseLargeInnerDimension(Operation *op, + unsigned targetBitWidth) { + // Check on bitwidths. + SmallVector<std::pair<Type, unsigned>> toCheck = + getTypeBitWidthBoundPairs(op, targetBitWidth); + return std::any_of(toCheck.begin(), toCheck.end(), + [&](std::pair<Type, unsigned> typeWidth) { + return isNotLinearizableBecauseLargeInnerDimension( + typeWidth.first, typeWidth.second); + }); +} + +void populateWithBitWidthConstraints(TypeConverter &typeConverter, + ConversionTarget &target, + unsigned targetBitWidth) { + + // The general purpose definition of what ops are legal must come first. + populateForVectorLinearize(typeConverter, target); + + // Extend the set of legal ops to include those with large inner-most + // dimensions on selected operands/results. + target.markUnknownOpDynamicallyLegal( + [=](Operation *op) -> std::optional<bool> { + if (isNotLinearizableBecauseLargeInnerDimension(op, targetBitWidth)) { + return true; + } + return {}; + }); +} + +struct TestVectorBitWidthLinearize final + : public PassWrapper<TestVectorBitWidthLinearize, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorBitWidthLinearize) + + TestVectorBitWidthLinearize() = default; + TestVectorBitWidthLinearize(const TestVectorBitWidthLinearize &pass) + : PassWrapper(pass) {} + + StringRef getArgument() const override { + return "test-bit-width-constrained-vector-linearize"; + } StringRef getDescription() const override { - return "Linearizes ND vectors for N >= 2 into 1D vectors"; + return "Linearizes ND vectors for N >= 2 into 1D vectors, with constraints " + "in inner-most dimension's bit width."; } void getDependentDialects(DialectRegistry ®istry) const override { registry.insert<vector::VectorDialect>(); @@ -867,10 +945,49 @@ struct TestVectorLinearize final RewritePatternSet patterns(context); ConversionTarget target(*context); - vector::populateVectorLinearizeTypeConversionsAndLegality( - typeConverter, patterns, target, targetVectorBitwidth); - vector::populateVectorLinearizeShuffleLikeOpsPatterns( - typeConverter, patterns, target, targetVectorBitwidth); + populateWithBitWidthConstraints(typeConverter, target, + targetVectorBitwidth); + + vector::populateVectorLinearizeBasePatterns(typeConverter, target, + patterns); + + vector::populateVectorLinearizeShuffleLikeOpsPatterns(typeConverter, target, + patterns); + + if (failed(applyPartialConversion(getOperation(), target, + std::move(patterns)))) + return signalPassFailure(); + } +}; + +} // namespace vendor + +struct TestVectorLinearize final + : public PassWrapper<TestVectorLinearize, OperationPass<>> { + MLIR_DEFINE_EXPLICIT_INTERNAL_INLINE_TYPE_ID(TestVectorLinearize) + + TestVectorLinearize() = default; + + StringRef getArgument() const override { return "test-vector-linearize"; } + StringRef getDescription() const override { + return "Linearizes ND vectors for N >= 2 into 1D vectors"; + } + void getDependentDialects(DialectRegistry ®istry) const override { + registry.insert<vector::VectorDialect>(); + } + + void runOnOperation() override { + MLIRContext &context = getContext(); + TypeConverter converter; + RewritePatternSet patterns(&context); + ConversionTarget target(context); + + vector::populateForVectorLinearize(converter, target); + + vector::populateVectorLinearizeBasePatterns(converter, target, patterns); + vector::populateVectorLinearizeShuffleLikeOpsPatterns(converter, target, + patterns); + if (failed(applyPartialConversion(getOperation(), target, std::move(patterns)))) return signalPassFailure(); @@ -950,6 +1067,8 @@ void registerTestVectorLowerings() { PassRegistration<TestVectorLinearize>(); + PassRegistration<vendor::TestVectorBitWidthLinearize>(); + PassRegistration<TestEliminateVectorMasks>(); } } // namespace test |
