diff options
Diffstat (limited to 'mlir/lib/Dialect/Tensor/IR/TensorOps.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Tensor/IR/TensorOps.cpp | 26 |
1 files changed, 15 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp index 659eabd2e938..4d6c5965c4fc 100644 --- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp +++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp @@ -4332,21 +4332,25 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) { rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource()); } Value dest = packOp.getDest(); - if (destShape != packOp.getDestType().getShape()) { + RankedTensorType originalResultType = packOp.getDestType(); + bool needUpdateDestType = (destShape != originalResultType.getShape()); + if (needUpdateDestType) { auto newDestType = packOp.getDestType().clone(destShape); dest = rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest()); } - auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp)); - Value res = clonedPackOp.getResult(); - rewriter.startOpModification(clonedPackOp); - clonedPackOp.getSourceMutable().assign(source); - clonedPackOp.getDestMutable().assign(dest); - res.setType(dest.getType()); - rewriter.finalizeOpModification(clonedPackOp); - - rewriter.replaceOpWithNewOp<tensor::CastOp>( - packOp, packOp.getResult().getType(), clonedPackOp); + rewriter.modifyOpInPlace(packOp, [&] { + packOp.getSourceMutable().assign(source); + packOp.getDestMutable().assign(dest); + packOp.getResult().setType(cast<RankedTensorType>(dest.getType())); + }); + // Insert a cast if needed + if (needUpdateDestType) { + rewriter.setInsertionPointAfter(packOp); + auto castOp = + rewriter.create<tensor::CastOp>(loc, originalResultType, packOp); + rewriter.replaceAllUsesExcept(packOp, castOp, castOp); + } return success(); } |
