diff options
| author | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:49:54 +0900 |
|---|---|---|
| committer | NAKAMURA Takumi <geek4civic@gmail.com> | 2025-01-09 18:49:54 +0900 |
| commit | e2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch) | |
| tree | ae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | |
| parent | fa04eb4af95c1ca7377279728cb004bcd2324d01 (diff) | |
| parent | bdcf47e4bcb92889665825654bb80a8bbe30379e (diff) | |
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | 53 |
1 files changed, 30 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index abc0635a2cdf..2c4e362101f8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, return nullptr; } +Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, + Operation *user) { + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // All values that are needed to create the replacement op. + SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction(); + // Find a suitable insertion point. If no suitable insertion point + // for the replacement can be found, return an empty value to skip + // this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, user, neededValues); + if (!insertionPoint) + return {}; + + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + return replacement; +} + LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn) { OpBuilder::InsertionGuard g(rewriter); llvm::DenseSet<OpOperand *> visitedOpOperands; op->walk([&](SubsetInsertionOpInterface op) { @@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( if (!state.isInPlace(source)) return WalkResult::skip(); - // All values that are needed to create the replacement op. - SmallVector<Value> neededValues = - op.getValuesNeededToBuildSubsetExtraction(); - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. @@ -124,35 +143,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( // %3 = tensor.insert_slice %2 into ... config.followSameTypeOrCastsOnly = true; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( - source.get(), /*condition=*/ + &source, /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config, &visitedOpOperands); for (Value v : emptyTensors) { - Operation *emptyTensorOp = v.getDefiningOp(); - + auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>(); + assert(emptyTensorOp && "expected tensor.empty op"); // Find the use to be replaced from the use-def chain. auto iter = llvm::find_if( visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { return llvm::count(emptyTensorOp->getUses(), *opOperand); }); - // This could be achieved when a use of `emptyTensorOp` is being - // consumed by `SubsetInsertionOpInterface`'s source directly. - if (iter == visitedOpOperands.end()) - continue; + + assert(iter != visitedOpOperands.end() && "could not find use"); OpOperand *useToBeReplaced = *iter; Operation *user = useToBeReplaced->getOwner(); - - // Find a suitable insertion point. If no suitable insertion point for - // the replacement can be found, skip this replacement. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - continue; - - rewriter.setInsertionPoint(insertionPoint); - Value replacement = - op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user); if (!replacement) continue; if (emptyTensorOp == replacement.getDefiningOp()) |
