summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Tensor/IR/TensorOps.cpp')
-rw-r--r--mlir/lib/Dialect/Tensor/IR/TensorOps.cpp26
1 files changed, 15 insertions, 11 deletions
diff --git a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
index 659eabd2e938..4d6c5965c4fc 100644
--- a/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
+++ b/mlir/lib/Dialect/Tensor/IR/TensorOps.cpp
@@ -4332,21 +4332,25 @@ LogicalResult PackOp::canonicalize(PackOp packOp, PatternRewriter &rewriter) {
rewriter.create<tensor::CastOp>(loc, newSrcType, packOp.getSource());
}
Value dest = packOp.getDest();
- if (destShape != packOp.getDestType().getShape()) {
+ RankedTensorType originalResultType = packOp.getDestType();
+ bool needUpdateDestType = (destShape != originalResultType.getShape());
+ if (needUpdateDestType) {
auto newDestType = packOp.getDestType().clone(destShape);
dest =
rewriter.create<tensor::CastOp>(loc, newDestType, packOp.getDest());
}
- auto clonedPackOp = cast<PackOp>(rewriter.clone(*packOp));
- Value res = clonedPackOp.getResult();
- rewriter.startOpModification(clonedPackOp);
- clonedPackOp.getSourceMutable().assign(source);
- clonedPackOp.getDestMutable().assign(dest);
- res.setType(dest.getType());
- rewriter.finalizeOpModification(clonedPackOp);
-
- rewriter.replaceOpWithNewOp<tensor::CastOp>(
- packOp, packOp.getResult().getType(), clonedPackOp);
+ rewriter.modifyOpInPlace(packOp, [&] {
+ packOp.getSourceMutable().assign(source);
+ packOp.getDestMutable().assign(dest);
+ packOp.getResult().setType(cast<RankedTensorType>(dest.getType()));
+ });
+ // Insert a cast if needed
+ if (needUpdateDestType) {
+ rewriter.setInsertionPointAfter(packOp);
+ auto castOp =
+ rewriter.create<tensor::CastOp>(loc, originalResultType, packOp);
+ rewriter.replaceAllUsesExcept(packOp, castOp, castOp);
+ }
return success();
}