summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp')
-rw-r--r--mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp133
1 files changed, 132 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
index 2bea083ac2d7..6984bc2dff49 100644
--- a/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
+++ b/mlir/lib/Dialect/Linalg/Transforms/DataLayoutPropagation.cpp
@@ -17,6 +17,8 @@
#include "mlir/Dialect/Utils/IndexingUtils.h"
#include "mlir/IR/Dominance.h"
#include "mlir/Transforms/GreedyPatternRewriteDriver.h"
+#include "llvm/ADT/SetOperations.h"
+#include "llvm/ADT/SetVector.h"
#include "llvm/ADT/TypeSwitch.h"
#include "llvm/Support/Debug.h"
#include <optional>
@@ -603,7 +605,8 @@ static bool isDimsDivisibleByTileSizes(ArrayRef<int64_t> dimsPos,
static int64_t applyPermutationAndReindexReassoc(
SmallVector<ReassociationIndices> &reassocIndices,
ArrayRef<int64_t> permutation) {
- applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
+ if (!permutation.empty())
+ applyPermutationToVector<ReassociationIndices>(reassocIndices, permutation);
int64_t nextPos = 0;
for (ReassociationIndices &indices : reassocIndices) {
for (auto &index : indices) {
@@ -694,6 +697,131 @@ bubbleUpPackOpThroughCollapseShape(tensor::CollapseShapeOp collapseOp,
return success();
}
+/// Project dimsPos to their collapsed positions in the reassocIndices.
+///
+/// For example, given dimsPos [0, 1, 2, 4], and matching reassocIndices
+/// [[0], [1, 2], [3], [4]], it returns [0, 1, 1, 3]. Because for pos 0,
+/// the reassoc dim [0] is 0. For pos 1 and 2, the reassoc dim in pos
+/// [1, 2] is 1. And for pos 4, the reassoc dim [4] is 3.
+static SmallVector<int64_t>
+projectDimsPosIntoReassocPos(ArrayRef<int64_t> dimsPos,
+ ArrayRef<ReassociationIndices> reassocIndices) {
+ SmallVector<int64_t> projectedPos;
+
+ // Map each dimension to the position of corresponding reassociation index.
+ for (auto pos : dimsPos) {
+ for (auto [idx, indices] : llvm::enumerate(reassocIndices)) {
+ // If the dimension is present in the current indices group, the group
+ // position within the reassociation map is the desired projected
+ // dimension position.
+ if (llvm::any_of(indices,
+ [&](int64_t expandDim) { return expandDim == pos; })) {
+ projectedPos.push_back(idx);
+ break;
+ }
+ }
+ }
+ assert(projectedPos.size() == dimsPos.size() && "Invalid dim pos projection");
+
+ return projectedPos;
+}
+
+/// Bubble up pack op through expand shape op.
+///
+/// For example:
+///
+/// %expand = tensor.expand_shape %in [[0], [1, 2]]
+/// : tensor<?x64xf32> into tensor<?x4x16xf32>
+/// %pack = tensor.pack %expand outer_dims_perm = [0, 1]
+/// inner_dims_pos = [2] inner_tiles = [8] into %empty
+/// : tensor<?x4x16xf32> -> tensor<?x4x2x8xf32>
+///
+/// can be transformed into:
+///
+/// %pack = tensor.pack %in outer_dims_perm = [1, 2]
+/// inner_dims_pos = [1] inner_tiles = [8] into %empty
+/// : tensor<?x64xf32> -> tensor<?x8x8xf32>
+/// %expand = tensor.expand_shape %pack [[0], [1, 2], [3]]
+/// : tensor<?x8x8xf32> into tensor<?x4x2x8xf32>
+static LogicalResult
+bubbleUpPackOpThroughExpandShape(tensor::ExpandShapeOp expandOp,
+ tensor::PackOp packOp,
+ PatternRewriter &rewriter) {
+ // Outer dimensions permutation is not supported currently.
+ // TODO: Handle outer_dims_perm variants.
+ ArrayRef<int64_t> outerDimsPerm = packOp.getOuterDimsPerm();
+ if (!outerDimsPerm.empty() && !isIdentityPermutation(outerDimsPerm)) {
+ return rewriter.notifyMatchFailure(packOp,
+ "non-identity outer dims perm NYI");
+ }
+
+ // Validate dimensions' relations between shape expansion and packing.
+ SmallVector<ReassociationIndices, 4> reassoc =
+ expandOp.getReassociationIndices();
+ ArrayRef<int64_t> packInnerDims = packOp.getInnerDimsPos();
+ llvm::SetVector<int64_t> packDimsPos(packInnerDims.begin(),
+ packInnerDims.end());
+
+ for (auto [idx, indices] : llvm::enumerate(reassoc)) {
+ // For each expand_shape reassociation, figure out which dimensions get
+ // packed if any.
+ llvm::SetVector<int64_t> expandDimPos(indices.begin(), indices.end());
+ llvm::SetVector<int64_t> packedDims =
+ llvm::set_intersection(packDimsPos, expandDimPos);
+
+ // The expanded dimension is not packed so, it does not affect moving pack
+ // before shape expansion - simply continue.
+ if (packedDims.empty())
+ continue;
+ // Shape expansion cannot be propagated when multiple expanded dimension are
+ // packed - in this case operation reordering would affect final element
+ // positions and/or shapes can no longer be projected.
+ if (packedDims.size() != 1)
+ return rewriter.notifyMatchFailure(
+ packOp, "only one of the expanded dimensions can be packed");
+ // Only the inner-most expanded dimension should be packed. Otherwise,
+ // elements order will be affected after operation reordering.
+ if (packedDims.front() != indices.back())
+ return rewriter.notifyMatchFailure(
+ packOp, "can only pack the inner-most expanded dimension");
+ }
+
+ // Project pack.inner_dims_pos to positions before shape expansion.
+ SmallVector<int64_t> projectedInnerDimsPos =
+ projectDimsPosIntoReassocPos(packInnerDims, reassoc);
+
+ // Project the shape expansion to new packed shape.
+ // The pack.outer_dims_perm is restricted to identity so, the permutation can
+ // be omitted for simplicity.
+ // TODO: Account for outer dimensions permutation.
+ //
+ // If reassociation is not possible, then reordering cannot happen.
+ // This can be caused by pack padding affecting previously expanded
+ // dimensions or packing extending dimensions.
+ RankedTensorType newPackType = tensor::PackOp::inferPackedType(
+ expandOp.getSrcType(), packOp.getStaticInnerTiles(),
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ auto reassocExpand =
+ getReassociationIndicesForReshape(newPackType, packOp.getDestType());
+ if (!reassocExpand)
+ return rewriter.notifyMatchFailure(
+ packOp, "could not reassociate dims after bubbling up");
+
+ Value destTensor = tensor::PackOp::createDestinationTensor(
+ rewriter, packOp.getLoc(), expandOp.getSrc(), packOp.getMixedTiles(),
+ projectedInnerDimsPos, /*outerDimsPerm=*/SmallVector<int64_t>{});
+ Value packedVal = rewriter.create<tensor::PackOp>(
+ packOp.getLoc(), expandOp.getSrc(), destTensor, projectedInnerDimsPos,
+ packOp.getMixedTiles(), packOp.getPaddingValue(),
+ /*outerDimsPerm=*/SmallVector<int64_t>{});
+
+ Value newExpandOp = rewriter.create<tensor::ExpandShapeOp>(
+ packOp.getLoc(), packOp.getDestType(), packedVal, *reassocExpand);
+ rewriter.replaceOp(packOp, newExpandOp);
+
+ return success();
+}
+
class BubbleUpPackOpThroughReshapeOp final
: public OpRewritePattern<tensor::PackOp> {
public:
@@ -723,6 +851,9 @@ public:
.Case([&](tensor::CollapseShapeOp op) {
return bubbleUpPackOpThroughCollapseShape(op, packOp, rewriter);
})
+ .Case([&](tensor::ExpandShapeOp op) {
+ return bubbleUpPackOpThroughExpandShape(op, packOp, rewriter);
+ })
.Default([](Operation *) { return failure(); });
}