summaryrefslogtreecommitdiff
diff options
context:
space:
mode:
-rw-r--r--mlir/include/mlir/Dialect/Vector/IR/VectorOps.td1
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp4
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp193
-rw-r--r--mlir/test/Dialect/Vector/vector-unroll-options.mlir79
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp22
5 files changed, 297 insertions, 2 deletions
diff --git a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
index f91d2b6404c9..43ebcaa03a47 100644
--- a/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
+++ b/mlir/include/mlir/Dialect/Vector/IR/VectorOps.td
@@ -2424,6 +2424,7 @@ def Vector_CompressStoreOp :
def Vector_ShapeCastOp :
Vector_Op<"shape_cast", [Pure,
+ DeclareOpInterfaceMethods<VectorUnrollOpInterface, ["getShapeForUnroll"]>,
DeclareOpInterfaceMethods<InferIntRangeInterface, ["inferResultRanges"]>
]>,
Arguments<(ins AnyVectorOfAnyRank:$source)>,
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index a97d0cd7f755..2789f6355552 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6243,6 +6243,10 @@ void ShapeCastOp::inferResultRanges(ArrayRef<ConstantIntRanges> argRanges,
setResultRanges(getResult(), argRanges.front());
}
+std::optional<SmallVector<int64_t, 4>> ShapeCastOp::getShapeForUnroll() {
+ return llvm::to_vector<4>(getResultVectorType().getShape());
+}
+
LogicalResult ShapeCastOp::verify() {
VectorType sourceType = getSourceVectorType();
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
index fbae0989bed2..b60f80534bfb 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorUnroll.cpp
@@ -1003,6 +1003,195 @@ private:
vector::UnrollVectorOptions options;
};
+/// Checks whether extractShape is a contiguous slice of shape.
+/// For extractShape to be contiguous in shape:
+/// 1) All but the leading dimension of extractShape and shape must match
+/// exactly. 2) The total number of elements in shape must be evenly divisible
+/// by
+/// the total number of elements in extractShape.
+/// Examples:
+/// isContiguous([4, 4], [8, 4]) == true
+/// isContiguous([2, 4], [8, 4]) == true
+/// isContiguous([2, 2], [8, 4]) == false
+/// Removes leading unit dimensions to handle cases like:
+/// isContiguous([1, 16], [1, 32]) == true
+static bool isContiguous(ArrayRef<int64_t> extractShape,
+ ArrayRef<int64_t> shape) {
+
+ if (extractShape.size() > shape.size())
+ return false;
+
+ while (!extractShape.empty() && extractShape.front() == 1) {
+ extractShape = extractShape.drop_front();
+ }
+
+ while (!shape.empty() && shape.front() == 1) {
+ shape = shape.drop_front();
+ }
+
+ size_t rankDiff = shape.size() - extractShape.size();
+ if (!llvm::equal(extractShape.drop_front(), shape.drop_front(rankDiff + 1)))
+ return false;
+
+ int64_t extractElements = ShapedType::getNumElements(extractShape);
+ int64_t shapeElements = ShapedType::getNumElements(shape);
+ return shapeElements % extractElements == 0;
+}
+
+/// Determines what shape to use with `vector.extract_strided_slice` to extract
+/// a contiguous memory region from a source vector. The extraction must be
+/// contiguous and contain exactly the specified number of elements. If such an
+/// extraction shape cannot be determined, returns std::nullopt.
+/// EXAMPLE 1:
+/// sourceShape = [16], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 16) = 8 from only dim → extractShape = [8],
+/// remaining = 8/8 = 1
+/// Result: [8]
+///
+/// EXAMPLE 2:
+/// sourceShape = [4, 4], targetElements = 8
+/// Working right-to-left:
+/// - Take min(8, 4) = 4 from last dim → extractShape = [4],
+/// remaining = 8/4 = 2
+/// - Take min(2, 4) = 2 from first dim → extractShape = [2, 4],
+/// remaining = 2/2 = 1
+/// Result: [2, 4]
+static std::optional<SmallVector<int64_t>>
+calculateSourceExtractShape(ArrayRef<int64_t> sourceShape,
+ int64_t targetElements) {
+ SmallVector<int64_t> extractShape;
+ int64_t remainingElements = targetElements;
+
+ // Build extract shape from innermost dimension outward to ensure contiguity.
+ for (int i = sourceShape.size() - 1; i >= 0 && remainingElements > 1; --i) {
+ int64_t takeFromDim = std::min(remainingElements, sourceShape[i]);
+ extractShape.insert(extractShape.begin(), takeFromDim);
+
+ if (remainingElements % takeFromDim != 0)
+ return std::nullopt; // Not evenly divisible.
+ remainingElements /= takeFromDim;
+ }
+
+ // Fill remaining dimensions with 1.
+ while (extractShape.size() < sourceShape.size())
+ extractShape.insert(extractShape.begin(), 1);
+
+ if (ShapedType::getNumElements(extractShape) != targetElements)
+ return std::nullopt;
+
+ return extractShape;
+}
+
+// Convert result offsets to source offsets via linear position.
+static SmallVector<int64_t>
+calculateSourceOffsets(ArrayRef<int64_t> resultOffsets,
+ ArrayRef<int64_t> sourceShape,
+ ArrayRef<int64_t> resultShape) {
+ // Convert result offsets to linear position.
+ int64_t linearIndex = linearize(resultOffsets, computeStrides(resultShape));
+ // Convert linear position to source offsets.
+ return delinearize(linearIndex, computeStrides(sourceShape));
+}
+
+/// This pattern unrolls `vector.shape_cast` operations according to the
+/// provided target unroll shape. It unrolls a large shape cast into smaller
+/// shape casts by extracting contiguous slices from the source vector, casting
+/// each slice to the target shape, and assembling the result by inserting each
+/// computed segment into the appropriate offset of the result vector.
+///
+/// This pattern only applies when contiguous slices can be extracted from the
+/// source vector and inserted into the result vector such that each slice
+/// remains a valid vector (and not decompose to scalars). In these cases, the
+/// unrolling proceeds as:
+/// vector.extract_strided_slice -> vector.shape_cast (on the slice) ->
+/// vector.insert_strided_slice.
+///
+/// Example:
+/// Given a shape cast operation:
+/// %0 = vector.shape_cast %src : vector<8x2xf32> to vector<4x4xf32>
+///
+/// and a target unroll shape of <2x4>, the pattern produces:
+///
+/// %zero = arith.constant dense<0.0> : vector<4x4xf32>
+/// %s0 = vector.extract_strided_slice %src [0, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc0 = vector.shape_cast %s0 : vector<4x2xf32> to vector<2x4xf32>
+/// %i0 = vector.insert_strided_slice %sc0, %zero [0, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+/// %s1 = vector.extract_strided_slice %src [4, 0], [4, 2], [1, 1]
+/// : vector<8x2xf32> to vector<4x2xf32>
+/// %sc1 = vector.shape_cast %s1 : vector<4x2xf32> to vector<2x4xf32>
+/// %i1 = vector.insert_strided_slice %sc1, %i0 [2, 0], [1, 1]
+/// : vector<2x4xf32> into vector<4x4xf32>
+///
+struct UnrollShapeCastPattern : public OpRewritePattern<vector::ShapeCastOp> {
+ UnrollShapeCastPattern(MLIRContext *context,
+ const vector::UnrollVectorOptions &options,
+ PatternBenefit benefit = 1)
+ : OpRewritePattern<vector::ShapeCastOp>(context, benefit),
+ options(options) {}
+
+ LogicalResult matchAndRewrite(vector::ShapeCastOp shapeCastOp,
+ PatternRewriter &rewriter) const override {
+ std::optional<SmallVector<int64_t>> targetShape =
+ getTargetShape(options, shapeCastOp);
+ if (!targetShape)
+ return failure();
+
+ VectorType sourceType = shapeCastOp.getSourceVectorType();
+ VectorType resultType = shapeCastOp.getResultVectorType();
+ ArrayRef<int64_t> sourceShape = sourceType.getShape();
+ ArrayRef<int64_t> resultShape = resultType.getShape();
+
+ if (!isContiguous(*targetShape, resultShape))
+ return rewriter.notifyMatchFailure(
+ shapeCastOp, "Only supports cases where target shape is "
+ "contiguous in result vector shape");
+
+ int64_t targetElements = ShapedType::getNumElements(*targetShape);
+
+ // Calculate the shape to extract from source.
+ std::optional<SmallVector<int64_t>> extractShape =
+ calculateSourceExtractShape(sourceShape, targetElements);
+ if (!extractShape)
+ return rewriter.notifyMatchFailure(
+ shapeCastOp,
+ "cannot extract target number of elements contiguously from source");
+
+ Location loc = shapeCastOp.getLoc();
+
+ // Create result vector initialized to zero.
+ Value result = arith::ConstantOp::create(rewriter, loc, resultType,
+ rewriter.getZeroAttr(resultType));
+
+ VectorType targetType =
+ VectorType::get(*targetShape, sourceType.getElementType());
+
+ SmallVector<int64_t> extractStrides(extractShape->size(), 1);
+ SmallVector<int64_t> insertStrides(targetShape->size(), 1);
+
+ for (SmallVector<int64_t> resultOffsets :
+ StaticTileOffsetRange(resultShape, *targetShape)) {
+ SmallVector<int64_t> sourceOffsets =
+ calculateSourceOffsets(resultOffsets, sourceShape, resultShape);
+ Value sourceChunk = rewriter.createOrFold<vector::ExtractStridedSliceOp>(
+ loc, shapeCastOp.getSource(), sourceOffsets, *extractShape,
+ extractStrides);
+ Value targetChunk = rewriter.createOrFold<vector::ShapeCastOp>(
+ loc, targetType, sourceChunk);
+ result = rewriter.createOrFold<vector::InsertStridedSliceOp>(
+ loc, targetChunk, result, resultOffsets, insertStrides);
+ }
+
+ rewriter.replaceOp(shapeCastOp, result);
+ return success();
+ }
+
+private:
+ vector::UnrollVectorOptions options;
+};
+
} // namespace
void mlir::vector::populateVectorUnrollPatterns(
@@ -1013,8 +1202,8 @@ void mlir::vector::populateVectorUnrollPatterns(
UnrollReductionPattern, UnrollMultiReductionPattern,
UnrollTransposePattern, UnrollGatherPattern, UnrollLoadPattern,
UnrollStorePattern, UnrollBroadcastPattern, UnrollFromElements,
- UnrollToElements, UnrollStepPattern>(patterns.getContext(),
- options, benefit);
+ UnrollToElements, UnrollStepPattern, UnrollShapeCastPattern>(
+ patterns.getContext(), options, benefit);
}
void mlir::vector::populateVectorToElementsUnrollPatterns(
diff --git a/mlir/test/Dialect/Vector/vector-unroll-options.mlir b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
index e5a98b5c67f3..dec32e1c61a9 100644
--- a/mlir/test/Dialect/Vector/vector-unroll-options.mlir
+++ b/mlir/test/Dialect/Vector/vector-unroll-options.mlir
@@ -496,3 +496,82 @@ func.func @elementwise_4D_to_2D(%v1: vector<2x2x2x2xf32>, %v2: vector<2x2x2x2xf3
// CHECK-COUNT-4: arith.addf %{{.*}}, %{{.*}} : vector<2x2xf32>
// CHECK-NOT: arith.addf
// CHECK: return
+
+
+func.func @shape_cast_1D(%v: vector<16xf32>) -> vector<2x2x4xf32> {
+ %0 = vector.shape_cast %v : vector<16xf32> to vector<2x2x4xf32>
+ return %0 : vector<2x2x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_1D
+// CHECK-SAME: (%[[V:.*]]: vector<16xf32>) -> vector<2x2x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<2x2x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [8], sizes = [8], strides = [1]} : vector<16xf32> to vector<8xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<8xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [1, 0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<2x2x4xf32>
+// CHECK: return %[[I1]] : vector<2x2x4xf32>
+
+
+func.func @shape_cast_2D(%v: vector<8x2xf32>) -> vector<4x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x2xf32> to vector<4x4xf32>
+ return %0 : vector<4x4xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_2D
+// CHECK-SAME: (%[[V:.*]]: vector<8x2xf32>) -> vector<4x4xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<4x4xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [4, 0], sizes = [4, 2], strides = [1, 1]} : vector<8x2xf32> to vector<4x2xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<4x2xf32> to vector<2x4xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [2, 0], strides = [1, 1]} : vector<2x4xf32> into vector<4x4xf32>
+// CHECK: return %[[I1]] : vector<4x4xf32>
+
+
+// This is a negative test case to ensure that such shape casts are not unrolled
+// because the targetShape (2x4) is not contiguous in result vector
+func.func @negative_shape_cast_target_shape_not_contiguous(%v: vector<64xf32>) -> vector<8x8xf32> {
+ %0 = vector.shape_cast %v : vector<64xf32> to vector<8x8xf32>
+ return %0 : vector<8x8xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_target_shape_not_contiguous
+// CHECK-SAME: (%[[V:.*]]: vector<64xf32>) -> vector<8x8xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<64xf32> to vector<8x8xf32>
+// CHECK: return %[[SC]] : vector<8x8xf32>
+
+
+// This is negative test case to ensure that such shape casts are not unrolled
+// because it cannot determine the extractShape from source vector (8x3)
+// to extract conitguous targetShape (2x4)
+func.func @negative_shape_cast_source_shape_not_determinable(%v: vector<8x3xf32>) -> vector<6x4xf32> {
+ %0 = vector.shape_cast %v : vector<8x3xf32> to vector<6x4xf32>
+ return %0 : vector<6x4xf32>
+}
+
+// CHECK-LABEL: func @negative_shape_cast_source_shape_not_determinable
+// CHECK-SAME: (%[[V:.*]]: vector<8x3xf32>) -> vector<6x4xf32> {
+// CHECK: %[[SC:.*]] = vector.shape_cast %[[V]] : vector<8x3xf32> to vector<6x4xf32>
+// CHECK: return %[[SC]] : vector<6x4xf32>
+
+
+// TargetShape is [1x16]
+func.func @shape_cast_leading_unit_dim(%v: vector<32xf32>) -> vector<1x32xf32> {
+ %0 = vector.shape_cast %v : vector<32xf32> to vector<1x32xf32>
+ return %0 : vector<1x32xf32>
+}
+
+// CHECK-LABEL: func @shape_cast_leading_unit_dim
+// CHECK-SAME: (%[[V:.*]]: vector<32xf32>) -> vector<1x32xf32> {
+// CHECK: %[[CST:.*]] = arith.constant dense<0.000000e+00> : vector<1x32xf32>
+// CHECK: %[[S0:.*]] = vector.extract_strided_slice %[[V]] {offsets = [0], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+// CHECK: %[[SC0:.*]] = vector.shape_cast %[[S0]] : vector<16xf32> to vector<1x16xf32>
+// CHECK: %[[I0:.*]] = vector.insert_strided_slice %[[SC0]], %[[CST]] {offsets = [0, 0], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+// CHECK: %[[S1:.*]] = vector.extract_strided_slice %[[V]] {offsets = [16], sizes = [16], strides = [1]} : vector<32xf32> to vector<16xf32>
+// CHECK: %[[SC1:.*]] = vector.shape_cast %[[S1]] : vector<16xf32> to vector<1x16xf32>
+// CHECK: %[[I1:.*]] = vector.insert_strided_slice %[[SC1]], %[[I0]] {offsets = [0, 16], strides = [1, 1]} : vector<1x16xf32> into vector<1x32xf32>
+// CHECK: return %[[I1]] : vector<1x32xf32>
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 79bfc9bbcda7..e8ea0cc02d7f 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -179,6 +179,28 @@ struct TestVectorUnrollingPatterns
return success(isa<vector::StepOp>(op));
}));
populateVectorUnrollPatterns(
+ patterns,
+ UnrollVectorOptions()
+ .setNativeShapeFn(
+ [](Operation *op) -> std::optional<SmallVector<int64_t>> {
+ auto shapeCast = dyn_cast<vector::ShapeCastOp>(op);
+ if (!shapeCast)
+ return std::nullopt;
+
+ auto resultShape = shapeCast.getResultVectorType().getShape();
+ // Special case with leading unit dims and different inner dim
+ // for result and target shape.
+ if (resultShape.size() == 2 && resultShape[0] == 1 &&
+ resultShape[1] == 32) {
+ return SmallVector<int64_t>{1, 16};
+ }
+ // Default case: [2,4] for all tests.
+ return SmallVector<int64_t>{2, 4};
+ })
+ .setFilterConstraint([](Operation *op) {
+ return success(isa<vector::ShapeCastOp>(op));
+ }));
+ populateVectorUnrollPatterns(
patterns, UnrollVectorOptions()
.setNativeShape(ArrayRef<int64_t>{1, 3, 4, 2})
.setFilterConstraint([](Operation *op) {