summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
commite2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch)
treeae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
parentfa04eb4af95c1ca7377279728cb004bcd2324d01 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp')
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp53
1 files changed, 30 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
index abc0635a2cdf..2c4e362101f8 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp
@@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user,
return nullptr;
}
+Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter,
+ SubsetInsertionOpInterface op,
+ tensor::EmptyOp emptyTensorOp,
+ Operation *user) {
+
+ mlir::OpBuilder::InsertionGuard guard(rewriter);
+ // All values that are needed to create the replacement op.
+ SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction();
+ // Find a suitable insertion point. If no suitable insertion point
+ // for the replacement can be found, return an empty value to skip
+ // this replacement.
+ Operation *insertionPoint =
+ findValidInsertionPoint(emptyTensorOp, user, neededValues);
+ if (!insertionPoint)
+ return {};
+
+ rewriter.setInsertionPoint(insertionPoint);
+ Value replacement =
+ op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ return replacement;
+}
+
LogicalResult mlir::bufferization::eliminateEmptyTensors(
- RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) {
+ RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state,
+ ControlBuildSubsetExtractionFn subsetsExtractionFn) {
OpBuilder::InsertionGuard g(rewriter);
llvm::DenseSet<OpOperand *> visitedOpOperands;
op->walk([&](SubsetInsertionOpInterface op) {
@@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
if (!state.isInPlace(source))
return WalkResult::skip();
- // All values that are needed to create the replacement op.
- SmallVector<Value> neededValues =
- op.getValuesNeededToBuildSubsetExtraction();
-
// Find tensor.empty ops on the reverse SSA use-def chain. Only follow
// equivalent tensors. I.e., stop when there are ops such as extract_slice
// on the path.
@@ -124,35 +143,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors(
// %3 = tensor.insert_slice %2 into ...
config.followSameTypeOrCastsOnly = true;
SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain(
- source.get(), /*condition=*/
+ &source, /*condition=*/
[&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config,
&visitedOpOperands);
for (Value v : emptyTensors) {
- Operation *emptyTensorOp = v.getDefiningOp();
-
+ auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>();
+ assert(emptyTensorOp && "expected tensor.empty op");
// Find the use to be replaced from the use-def chain.
auto iter = llvm::find_if(
visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) {
return llvm::count(emptyTensorOp->getUses(), *opOperand);
});
- // This could be achieved when a use of `emptyTensorOp` is being
- // consumed by `SubsetInsertionOpInterface`'s source directly.
- if (iter == visitedOpOperands.end())
- continue;
+
+ assert(iter != visitedOpOperands.end() && "could not find use");
OpOperand *useToBeReplaced = *iter;
Operation *user = useToBeReplaced->getOwner();
-
- // Find a suitable insertion point. If no suitable insertion point for
- // the replacement can be found, skip this replacement.
- Operation *insertionPoint =
- findValidInsertionPoint(emptyTensorOp, user, neededValues);
- if (!insertionPoint)
- continue;
-
- rewriter.setInsertionPoint(insertionPoint);
- Value replacement =
- op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc());
+ auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user);
if (!replacement)
continue;
if (emptyTensorOp == replacement.getDefiningOp())