summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Vector/IR/VectorOps.cpp')
-rw-r--r--mlir/lib/Dialect/Vector/IR/VectorOps.cpp70
1 files changed, 69 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
index 45c54c7587c6..ad8255a95cb4 100644
--- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
+++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp
@@ -6835,6 +6835,73 @@ public:
}
};
+/// Folds transpose(from_elements(...)) into a new from_elements with permuted
+/// operands matching the transposed shape.
+///
+/// Example:
+///
+/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 :
+/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to
+/// vector<3x2xi32>
+///
+/// becomes ->
+///
+/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 :
+/// vector<3x2xi32>
+///
+class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> {
+public:
+ using Base::Base;
+ LogicalResult matchAndRewrite(vector::TransposeOp transposeOp,
+ PatternRewriter &rewriter) const override {
+ auto fromElementsOp =
+ transposeOp.getVector().getDefiningOp<vector::FromElementsOp>();
+ if (!fromElementsOp)
+ return failure();
+
+ VectorType srcTy = fromElementsOp.getDest().getType();
+ VectorType dstTy = transposeOp.getType();
+
+ ArrayRef<int64_t> permutation = transposeOp.getPermutation();
+ int64_t rank = srcTy.getRank();
+
+ // Build inverse permutation to map destination indices back to source.
+ SmallVector<int64_t> inversePerm(rank, 0);
+ for (int64_t i = 0; i < rank; ++i)
+ inversePerm[permutation[i]] = i;
+
+ ArrayRef<int64_t> srcShape = srcTy.getShape();
+ ArrayRef<int64_t> dstShape = dstTy.getShape();
+ SmallVector<int64_t> srcIdx(rank, 0);
+ SmallVector<int64_t> dstIdx(rank, 0);
+ SmallVector<int64_t> srcStrides = computeStrides(srcShape);
+ SmallVector<int64_t> dstStrides = computeStrides(dstShape);
+
+ auto elementsOld = fromElementsOp.getElements();
+ SmallVector<Value> elementsNew;
+ int64_t dstNumElements = dstTy.getNumElements();
+ elementsNew.reserve(dstNumElements);
+
+ // For each element in destination row-major order, pick the corresponding
+ // source element.
+ for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) {
+ // Pick the destination element index.
+ dstIdx = delinearize(linearIdx, dstStrides);
+ // Map the destination element index to the source element index.
+ for (int64_t j = 0; j < rank; ++j)
+ srcIdx[j] = dstIdx[inversePerm[j]];
+ // Linearize the source element index.
+ int64_t srcLin = linearize(srcIdx, srcStrides);
+ // Add the source element to the new elements.
+ elementsNew.push_back(elementsOld[srcLin]);
+ }
+
+ rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy,
+ elementsNew);
+ return success();
+ }
+};
+
/// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is
/// 'order preserving', where 'order preserving' means the flattened
/// inputs and outputs of the transpose have identical (numerical) values.
@@ -6935,7 +7002,8 @@ public:
void vector::TransposeOp::getCanonicalizationPatterns(
RewritePatternSet &results, MLIRContext *context) {
results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder,
- FoldTransposeSplat, FoldTransposeBroadcast>(context);
+ FoldTransposeSplat, FoldTransposeFromElements,
+ FoldTransposeBroadcast>(context);
}
//===----------------------------------------------------------------------===//