diff options
Diffstat (limited to 'mlir/lib/Dialect/Vector/IR/VectorOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Vector/IR/VectorOps.cpp | 70 |
1 files changed, 69 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp index 45c54c7587c6..ad8255a95cb4 100644 --- a/mlir/lib/Dialect/Vector/IR/VectorOps.cpp +++ b/mlir/lib/Dialect/Vector/IR/VectorOps.cpp @@ -6835,6 +6835,73 @@ public: } }; +/// Folds transpose(from_elements(...)) into a new from_elements with permuted +/// operands matching the transposed shape. +/// +/// Example: +/// +/// %v = vector.from_elements %a00, %a01, %a02, %a10, %a11, %a12 : +/// vector<2x3xi32> %t = vector.transpose %v, [1, 0] : vector<2x3xi32> to +/// vector<3x2xi32> +/// +/// becomes -> +/// +/// %r = vector.from_elements %a00, %a10, %a01, %a11, %a02, %a12 : +/// vector<3x2xi32> +/// +class FoldTransposeFromElements final : public OpRewritePattern<TransposeOp> { +public: + using Base::Base; + LogicalResult matchAndRewrite(vector::TransposeOp transposeOp, + PatternRewriter &rewriter) const override { + auto fromElementsOp = + transposeOp.getVector().getDefiningOp<vector::FromElementsOp>(); + if (!fromElementsOp) + return failure(); + + VectorType srcTy = fromElementsOp.getDest().getType(); + VectorType dstTy = transposeOp.getType(); + + ArrayRef<int64_t> permutation = transposeOp.getPermutation(); + int64_t rank = srcTy.getRank(); + + // Build inverse permutation to map destination indices back to source. + SmallVector<int64_t> inversePerm(rank, 0); + for (int64_t i = 0; i < rank; ++i) + inversePerm[permutation[i]] = i; + + ArrayRef<int64_t> srcShape = srcTy.getShape(); + ArrayRef<int64_t> dstShape = dstTy.getShape(); + SmallVector<int64_t> srcIdx(rank, 0); + SmallVector<int64_t> dstIdx(rank, 0); + SmallVector<int64_t> srcStrides = computeStrides(srcShape); + SmallVector<int64_t> dstStrides = computeStrides(dstShape); + + auto elementsOld = fromElementsOp.getElements(); + SmallVector<Value> elementsNew; + int64_t dstNumElements = dstTy.getNumElements(); + elementsNew.reserve(dstNumElements); + + // For each element in destination row-major order, pick the corresponding + // source element. + for (int64_t linearIdx = 0; linearIdx < dstNumElements; ++linearIdx) { + // Pick the destination element index. + dstIdx = delinearize(linearIdx, dstStrides); + // Map the destination element index to the source element index. + for (int64_t j = 0; j < rank; ++j) + srcIdx[j] = dstIdx[inversePerm[j]]; + // Linearize the source element index. + int64_t srcLin = linearize(srcIdx, srcStrides); + // Add the source element to the new elements. + elementsNew.push_back(elementsOld[srcLin]); + } + + rewriter.replaceOpWithNewOp<FromElementsOp>(transposeOp, dstTy, + elementsNew); + return success(); + } +}; + /// Folds transpose(broadcast(x)) to broadcast(x) if the transpose is /// 'order preserving', where 'order preserving' means the flattened /// inputs and outputs of the transpose have identical (numerical) values. @@ -6935,7 +7002,8 @@ public: void vector::TransposeOp::getCanonicalizationPatterns( RewritePatternSet &results, MLIRContext *context) { results.add<FoldTransposeCreateMask, FoldTransposeShapeCast, TransposeFolder, - FoldTransposeSplat, FoldTransposeBroadcast>(context); + FoldTransposeSplat, FoldTransposeFromElements, + FoldTransposeBroadcast>(context); } //===----------------------------------------------------------------------===// |
