summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tensor/IR/TensorOps.cpp')
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp123
1 files changed, 99 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index f79c774ceb3e..24a1d5531531 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op,
return newOperands;
}
+// Given the (potentially) updated packed type, `newPackedTy`, generates an
+// updated mixed-tile-sizes attribute. A tile size is updated only
+// when:
+// * a dim from newPackedTy is static, and
+// * the corresponding size from mixedTiles is still dynamic.
+// Otherwise, the original tile size is preserved.
+// Note - packed-type-dim and mixed-tile-size should always match!
+static SmallVector<OpFoldResult>
+getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy,
+ SmallVector<OpFoldResult> mixedTiles) {
+ SmallVector<OpFoldResult> newMixedTileSizes;
+ for (auto it : llvm::zip(cast<ShapedType>(newPackedTy)
+ .getShape()
+ .take_back(mixedTiles.size()),
+ mixedTiles)) {
+ int64_t shape = std::get<0>(it);
+ if (shape == ShapedType::kDynamic) {
+ newMixedTileSizes.push_back(std::get<1>(it));
+ continue;
+ }
+
+ // If the current result dim is static, update the dynamic mixed-size
+ // (provided the original value is dynamic).
+ OpFoldResult tile = std::get<1>(it);
+ if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) {
+ // Already a constant
+ newMixedTileSizes.push_back(tile);
+ } else {
+ assert(getConstantIntValue(tile).value() == shape &&
+ "tile size and dim size don't match!");
+ newMixedTileSizes.push_back(
+ (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
+ }
+ }
+
+ return newMixedTileSizes;
+}
+
/// Folds a tensor.cast op into a consuming tensor::PackOp op if the
/// `tensor.cast` has source that is more static than the consuming op.
///
@@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
// Get the updated mixed-tile-sizes attribute.
- SmallVector<OpFoldResult> newMixedTileSizes;
- for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0])
- .getShape()
- .take_back(op.getMixedTiles().size()),
- op.getMixedTiles())) {
- int64_t shape = std::get<0>(it);
- if (shape == ShapedType::kDynamic) {
- newMixedTileSizes.push_back(std::get<1>(it));
- continue;
- }
-
- if (Attribute attr =
- llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) {
- // Already a constant
- newMixedTileSizes.push_back(std::get<1>(it));
- } else {
- int64_t tileSize = getConstantIntValue(std::get<1>(it)).value();
- assert(tileSize == shape && "tile size and dim size don't match!");
- (void)tileSize;
- newMixedTileSizes.push_back(
- (rewriter.getIntegerAttr(rewriter.getIndexType(), shape)));
- }
- }
+ SmallVector<OpFoldResult> newMixedTileSizes =
+ getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles());
// Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
PackOp newOp = rewriter.create<PackOp>(
op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(),
newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm());
@@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> {
}
};
+/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the
+/// `tensor.cast` has source that is more static than the consuming op.
+///
+/// Example:
+/// ```mlir
+/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32>
+/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32>
+/// ```
+///
+/// folds into:
+///
+/// ```mlir
+/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32>
+/// ```
+struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> {
+ using OpRewritePattern<UnPackOp>::OpRewritePattern;
+
+ LogicalResult matchAndRewrite(UnPackOp op,
+ PatternRewriter &rewriter) const override {
+ if (!foldTensorCastPrecondition(op))
+ return failure();
+
+ SmallVector<Type> newResultTypes(op->getResultTypes());
+ SmallVector<Value> newOperands = getNewOperands(op, newResultTypes);
+ Value sourceTensor = newOperands[0];
+
+ // Get the updated mixed-tile-sizes attribute.
+ SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes(
+ rewriter, sourceTensor.getType(), op.getMixedTiles());
+
+ // Clone op.
+ // TODO: Strictly speaking, discardable attributes should be _discarded_ at
+ // this point. However, in practice, we use them for things that we'd like
+ // to preserve. Implement a better abstraction.
+ UnPackOp newOp = rewriter.create<UnPackOp>(
+ op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(),
+ newMixedTileSizes, op.getOuterDimsPerm());
+ newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary());
+
+ // Replace op.
+ Value oldResult = op.getResult();
+ Value newResult = newOp.getResult();
+ Value replacement = (newResult.getType() != oldResult.getType())
+ ? rewriter.create<tensor::CastOp>(
+ op->getLoc(), oldResult.getType(), newResult)
+ : newResult;
+
+ rewriter.replaceOp(op, {replacement});
+
+ return success();
+ }
+};
+
/// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if
/// the `tensor.cast` has source that is more static than the consuming op.
///
@@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp
PatternRewriter &rewriter) const override {
// Reject tensor::PackOp - there's dedicated pattern for that instead.
- if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op))
+ if (!foldTensorCastPrecondition(op) ||
+ isa<tensor::PackOp, tensor::UnPackOp>(*op))
return failure();
SmallVector<Type> newResultTypes(op->getResultTypes());
@@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp
void TensorDialect::getCanonicalizationPatterns(
RewritePatternSet &results) const {
results.add<FoldTensorCastPackOp>(getContext());
+ results.add<FoldTensorCastUnPackOp>(getContext());
results.add<FoldTensorCastProducerOp>(getContext());
}