diff options
Diffstat (limited to 'mlir/lib/Dialect/Tensor/IR/TensorOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 123 |
1 files changed, 99 insertions, 24 deletions
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index f79c774ceb3e..24a1d5531531 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4795,6 +4795,44 @@ static SmallVector<Value> getNewOperands(DestinationStyleOpInterface op, return newOperands; } +// Given the (potentially) updated packed type, `newPackedTy`, generates an +// updated mixed-tile-sizes attribute. A tile size is updated only +// when: +// * a dim from newPackedTy is static, and +// * the corresponding size from mixedTiles is still dynamic. +// Otherwise, the original tile size is preserved. +// Note - packed-type-dim and mixed-tile-size should always match! +static SmallVector<OpFoldResult> +getNewMixedTileSizes(PatternRewriter &rewriter, Type newPackedTy, + SmallVector<OpFoldResult> mixedTiles) { + SmallVector<OpFoldResult> newMixedTileSizes; + for (auto it : llvm::zip(cast<ShapedType>(newPackedTy) + .getShape() + .take_back(mixedTiles.size()), + mixedTiles)) { + int64_t shape = std::get<0>(it); + if (shape == ShapedType::kDynamic) { + newMixedTileSizes.push_back(std::get<1>(it)); + continue; + } + + // If the current result dim is static, update the dynamic mixed-size + // (provided the original value is dynamic). + OpFoldResult tile = std::get<1>(it); + if (Attribute attr = llvm::dyn_cast_if_present<Attribute>(tile)) { + // Already a constant + newMixedTileSizes.push_back(tile); + } else { + assert(getConstantIntValue(tile).value() == shape && + "tile size and dim size don't match!"); + newMixedTileSizes.push_back( + (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); + } + } + + return newMixedTileSizes; +} + /// Folds a tensor.cast op into a consuming tensor::PackOp op if the /// `tensor.cast` has source that is more static than the consuming op. /// @@ -4821,31 +4859,13 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> { SmallVector<Value> newOperands = getNewOperands(op, newResultTypes); // Get the updated mixed-tile-sizes attribute. - SmallVector<OpFoldResult> newMixedTileSizes; - for (auto it : llvm::zip(cast<ShapedType>(newResultTypes[0]) - .getShape() - .take_back(op.getMixedTiles().size()), - op.getMixedTiles())) { - int64_t shape = std::get<0>(it); - if (shape == ShapedType::kDynamic) { - newMixedTileSizes.push_back(std::get<1>(it)); - continue; - } - - if (Attribute attr = - llvm::dyn_cast_if_present<Attribute>(std::get<1>(it))) { - // Already a constant - newMixedTileSizes.push_back(std::get<1>(it)); - } else { - int64_t tileSize = getConstantIntValue(std::get<1>(it)).value(); - assert(tileSize == shape && "tile size and dim size don't match!"); - (void)tileSize; - newMixedTileSizes.push_back( - (rewriter.getIntegerAttr(rewriter.getIndexType(), shape))); - } - } + SmallVector<OpFoldResult> newMixedTileSizes = + getNewMixedTileSizes(rewriter, newResultTypes[0], op.getMixedTiles()); // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. PackOp newOp = rewriter.create<PackOp>( op.getLoc(), newOperands[0], newOperands[1], op.getInnerDimsPos(), newMixedTileSizes, op.getPaddingValue(), op.getOuterDimsPerm()); @@ -4865,6 +4885,59 @@ struct FoldTensorCastPackOp : public OpRewritePattern<PackOp> { } }; +/// Folds a tensor.cast op into a consuming tensor::UnPackOp op if the +/// `tensor.cast` has source that is more static than the consuming op. +/// +/// Example: +/// ```mlir +/// %1 = tensor.cast %0 : tensor<1x1x8x1xi32> to tensor<1x1x?x1xi32> +/// %2 = tensor.unpack %1 ... : tensor<1x1x?x1xi32> -> tensor<7x?xi32> +/// ``` +/// +/// folds into: +/// +/// ```mlir +/// %2 = tensor.unpack %0 ... tensor<1x1x8x1xi32> -> tensor<7x?xi32> +/// ``` +struct FoldTensorCastUnPackOp : public OpRewritePattern<UnPackOp> { + using OpRewritePattern<UnPackOp>::OpRewritePattern; + + LogicalResult matchAndRewrite(UnPackOp op, + PatternRewriter &rewriter) const override { + if (!foldTensorCastPrecondition(op)) + return failure(); + + SmallVector<Type> newResultTypes(op->getResultTypes()); + SmallVector<Value> newOperands = getNewOperands(op, newResultTypes); + Value sourceTensor = newOperands[0]; + + // Get the updated mixed-tile-sizes attribute. + SmallVector<OpFoldResult> newMixedTileSizes = getNewMixedTileSizes( + rewriter, sourceTensor.getType(), op.getMixedTiles()); + + // Clone op. + // TODO: Strictly speaking, discardable attributes should be _discarded_ at + // this point. However, in practice, we use them for things that we'd like + // to preserve. Implement a better abstraction. + UnPackOp newOp = rewriter.create<UnPackOp>( + op.getLoc(), sourceTensor, newOperands[1], op.getInnerDimsPos(), + newMixedTileSizes, op.getOuterDimsPerm()); + newOp->setDiscardableAttrs(op->getDiscardableAttrDictionary()); + + // Replace op. + Value oldResult = op.getResult(); + Value newResult = newOp.getResult(); + Value replacement = (newResult.getType() != oldResult.getType()) + ? rewriter.create<tensor::CastOp>( + op->getLoc(), oldResult.getType(), newResult) + : newResult; + + rewriter.replaceOp(op, {replacement}); + + return success(); + } +}; + /// Folds a tensor.cast op into a consuming DestinationStyleOpInterface op if /// the `tensor.cast` has source that is more static than the consuming op. /// @@ -4890,7 +4963,8 @@ struct FoldTensorCastProducerOp PatternRewriter &rewriter) const override { // Reject tensor::PackOp - there's dedicated pattern for that instead. - if (!foldTensorCastPrecondition(op) || dyn_cast<tensor::PackOp>(*op)) + if (!foldTensorCastPrecondition(op) || + isa<tensor::PackOp, tensor::UnPackOp>(*op)) return failure(); SmallVector<Type> newResultTypes(op->getResultTypes()); @@ -4923,6 +4997,7 @@ struct FoldTensorCastProducerOp void TensorDialect::getCanonicalizationPatterns( RewritePatternSet &results) const { results.add<FoldTensorCastPackOp>(getContext()); + results.add<FoldTensorCastUnPackOp>(getContext()); results.add<FoldTensorCastProducerOp>(getContext()); } |
