diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp | 53 |
1 files changed, 30 insertions, 23 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp index abc0635a2cdf..2c4e362101f8 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/EmptyTensorElimination.cpp @@ -93,8 +93,31 @@ findValidInsertionPoint(Operation *emptyTensorOp, Operation *user, return nullptr; } +Value mlir::bufferization::buildSubsetExtraction(RewriterBase &rewriter, + SubsetInsertionOpInterface op, + tensor::EmptyOp emptyTensorOp, + Operation *user) { + + mlir::OpBuilder::InsertionGuard guard(rewriter); + // All values that are needed to create the replacement op. + SmallVector<Value> neededValues = op.getValuesNeededToBuildSubsetExtraction(); + // Find a suitable insertion point. If no suitable insertion point + // for the replacement can be found, return an empty value to skip + // this replacement. + Operation *insertionPoint = + findValidInsertionPoint(emptyTensorOp, user, neededValues); + if (!insertionPoint) + return {}; + + rewriter.setInsertionPoint(insertionPoint); + Value replacement = + op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + return replacement; +} + LogicalResult mlir::bufferization::eliminateEmptyTensors( - RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state) { + RewriterBase &rewriter, Operation *op, OneShotAnalysisState &state, + ControlBuildSubsetExtractionFn subsetsExtractionFn) { OpBuilder::InsertionGuard g(rewriter); llvm::DenseSet<OpOperand *> visitedOpOperands; op->walk([&](SubsetInsertionOpInterface op) { @@ -105,10 +128,6 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( if (!state.isInPlace(source)) return WalkResult::skip(); - // All values that are needed to create the replacement op. - SmallVector<Value> neededValues = - op.getValuesNeededToBuildSubsetExtraction(); - // Find tensor.empty ops on the reverse SSA use-def chain. Only follow // equivalent tensors. I.e., stop when there are ops such as extract_slice // on the path. @@ -124,35 +143,23 @@ LogicalResult mlir::bufferization::eliminateEmptyTensors( // %3 = tensor.insert_slice %2 into ... config.followSameTypeOrCastsOnly = true; SetVector<Value> emptyTensors = state.findValueInReverseUseDefChain( - source.get(), /*condition=*/ + &source, /*condition=*/ [&](Value val) { return val.getDefiningOp<tensor::EmptyOp>(); }, config, &visitedOpOperands); for (Value v : emptyTensors) { - Operation *emptyTensorOp = v.getDefiningOp(); - + auto emptyTensorOp = v.getDefiningOp<tensor::EmptyOp>(); + assert(emptyTensorOp && "expected tensor.empty op"); // Find the use to be replaced from the use-def chain. auto iter = llvm::find_if( visitedOpOperands, [&emptyTensorOp](OpOperand *opOperand) { return llvm::count(emptyTensorOp->getUses(), *opOperand); }); - // This could be achieved when a use of `emptyTensorOp` is being - // consumed by `SubsetInsertionOpInterface`'s source directly. - if (iter == visitedOpOperands.end()) - continue; + + assert(iter != visitedOpOperands.end() && "could not find use"); OpOperand *useToBeReplaced = *iter; Operation *user = useToBeReplaced->getOwner(); - - // Find a suitable insertion point. If no suitable insertion point for - // the replacement can be found, skip this replacement. - Operation *insertionPoint = - findValidInsertionPoint(emptyTensorOp, user, neededValues); - if (!insertionPoint) - continue; - - rewriter.setInsertionPoint(insertionPoint); - Value replacement = - op.buildSubsetExtraction(rewriter, emptyTensorOp->getLoc()); + auto replacement = subsetsExtractionFn(rewriter, op, emptyTensorOp, user); if (!replacement) continue; if (emptyTensorOp == replacement.getDefiningOp()) |
