diff options
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp | 46 |
1 files changed, 26 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp index d1e6acef324f..fc1b221b4f03 100644 --- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp +++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp @@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) { // If there is no preceding definition, the tensor contents are // undefined. - if (findDefinitionsCached(opResult).empty()) + if (opResult.getUses().empty()) + continue; + // It does not really matter which use to take to search about + // the value's definitions. + OpOperand *opOperand = &(*opResult.getUses().begin()); + if (findDefinitionsCached(opOperand).empty()) for (OpOperand &use : opResult.getUses()) undefinedTensorUses.insert(&use); } @@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite, /// indexing. I.e., the tensor types do not change along the use-def chain, /// apart from static <-> dynamic dim casts. static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, - Value start, Value other) { + OpOperand *start, + Value other) { TraversalConfig config; config.followEquivalentOnly = true; config.alwaysIncludeLeaves = false; @@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state, .empty(); } -/// Return "true" if `value` is originating from a subset that is equivalent to -/// the subset that `subsetOp` inserts into. -static bool matchesInsertDestination(const AnalysisState &state, Value value, +/// Return "true" if the given operand's value is originating from a subset +/// that is equivalent to the subset that `subsetOp` inserts into. +static bool matchesInsertDestination(const AnalysisState &state, + OpOperand *opOperand, SubsetInsertionOpInterface subsetOp) { auto matchingSubset = [&](Value val) { if (auto opResult = dyn_cast<OpResult>(val)) @@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value, // There may be multiple leaves at which the reverse SSA use-def chain lookup // terminates. All of them must be equivalent subsets. SetVector<Value> backwardSlice = - state.findValueInReverseUseDefChain(value, matchingSubset); + state.findValueInReverseUseDefChain(opOperand, matchingSubset); return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset)); } @@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, // {inplace= [true] } if (uRead == &subsetOp.getDestinationOperand() && - matchesInsertDestination(state, uConflictingWrite->get(), subsetOp)) + matchesInsertDestination(state, uConflictingWrite, subsetOp)) // Case 1: The main insight is that InsertSliceOp reads only part of // the destination tensor. The overwritten area is not read. If // uConflictingWrite writes into exactly the memory location that is @@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, if (uRead == &subsetOp.getSourceOperand() && uConflictingWrite == &subsetOp.getDestinationOperand() && - matchesInsertDestination(state, uRead->get(), subsetOp)) + matchesInsertDestination(state, uRead, subsetOp)) // Case 2: The read of the source tensor and the write to the dest // tensor via an InsertSliceOp is not a conflict if the read is // reading exactly that part of an equivalent tensor that the @@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead, if (uConflictingWrite == &subsetOp.getDestinationOperand() && state.areEquivalentBufferizedValues( uRead->get(), subsetOp.getSourceOperand().get()) && - matchesInsertDestination(state, subsetOp.getSourceOperand().get(), - subsetOp)) + matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp)) return true; return false; @@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // even though that op just bufferizes to an allocation but does define // the contents of the buffer. SetVector<Value> definitionsOrLeaves = - state.findValueInReverseUseDefChain( - uConflictingWrite->get(), - [&](Value v) { return state.bufferizesToMemoryWrite(v); }); + state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) { + return state.bufferizesToMemoryWrite(v); + }); assert(!definitionsOrLeaves.empty() && "expected at least one definition or leaf"); @@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, // In the above example, if uRead is the OpOperand of reading_op, the // definition is %0. Note that operations that create an alias but do not // bufferize to a memory write (such as ExtractSliceOp) are skipped. - const SetVector<Value> &definitions = - state.findDefinitionsCached(uRead->get()); + const SetVector<Value> &definitions = state.findDefinitionsCached(uRead); if (definitions.empty()) { // Fast path: No conflict if there are no definitions. LLVM_DEBUG(llvm::dbgs() @@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead, if (bufferizableOp.bufferizesToElementwiseAccess( state, {uRead, uConflictingWrite})) { if (hasEquivalentValueInReverseUseDefChain( - state, uRead->get(), uConflictingWrite->get()) || + state, uRead, uConflictingWrite->get()) || hasEquivalentValueInReverseUseDefChain( - state, uConflictingWrite->get(), uRead->get())) { + state, uConflictingWrite, uRead->get())) { LLVM_DEBUG( llvm::dbgs() << " no conflict: op bufferizes to element-wise access\n"); @@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand, // Bufferization analyses. //===----------------------------------------------------------------------===// -// Find the values that define the contents of the given value. +// Find the values that define the contents of the given operand's value. const llvm::SetVector<Value> & -OneShotAnalysisState::findDefinitionsCached(Value value) { +OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) { + Value value = opOperand->get(); if (!cachedDefinitions.count(value)) - cachedDefinitions[value] = findDefinitions(value); + cachedDefinitions[value] = findDefinitions(opOperand); return cachedDefinitions[value]; } |
