summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp')
-rw-r--r--mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp46
1 files changed, 26 insertions, 20 deletions
diff --git a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
index d1e6acef324f..fc1b221b4f03 100644
--- a/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
+++ b/mlir/lib/Dialect/Bufferization/Transforms/OneShotAnalysis.cpp
@@ -196,7 +196,12 @@ void OneShotAnalysisState::gatherUndefinedTensorUses(Operation *op) {
// If there is no preceding definition, the tensor contents are
// undefined.
- if (findDefinitionsCached(opResult).empty())
+ if (opResult.getUses().empty())
+ continue;
+ // It does not really matter which use to take to search about
+ // the value's definitions.
+ OpOperand *opOperand = &(*opResult.getUses().begin());
+ if (findDefinitionsCached(opOperand).empty())
for (OpOperand &use : opResult.getUses())
undefinedTensorUses.insert(&use);
}
@@ -464,7 +469,8 @@ static void annotateConflict(OpOperand *uRead, OpOperand *uConflictingWrite,
/// indexing. I.e., the tensor types do not change along the use-def chain,
/// apart from static <-> dynamic dim casts.
static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
- Value start, Value other) {
+ OpOperand *start,
+ Value other) {
TraversalConfig config;
config.followEquivalentOnly = true;
config.alwaysIncludeLeaves = false;
@@ -475,9 +481,10 @@ static bool hasEquivalentValueInReverseUseDefChain(AnalysisState &state,
.empty();
}
-/// Return "true" if `value` is originating from a subset that is equivalent to
-/// the subset that `subsetOp` inserts into.
-static bool matchesInsertDestination(const AnalysisState &state, Value value,
+/// Return "true" if the given operand's value is originating from a subset
+/// that is equivalent to the subset that `subsetOp` inserts into.
+static bool matchesInsertDestination(const AnalysisState &state,
+ OpOperand *opOperand,
SubsetInsertionOpInterface subsetOp) {
auto matchingSubset = [&](Value val) {
if (auto opResult = dyn_cast<OpResult>(val))
@@ -490,7 +497,7 @@ static bool matchesInsertDestination(const AnalysisState &state, Value value,
// There may be multiple leaves at which the reverse SSA use-def chain lookup
// terminates. All of them must be equivalent subsets.
SetVector<Value> backwardSlice =
- state.findValueInReverseUseDefChain(value, matchingSubset);
+ state.findValueInReverseUseDefChain(opOperand, matchingSubset);
return static_cast<bool>(llvm::all_of(backwardSlice, matchingSubset));
}
@@ -516,7 +523,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
// {inplace= [true] }
if (uRead == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uConflictingWrite->get(), subsetOp))
+ matchesInsertDestination(state, uConflictingWrite, subsetOp))
// Case 1: The main insight is that InsertSliceOp reads only part of
// the destination tensor. The overwritten area is not read. If
// uConflictingWrite writes into exactly the memory location that is
@@ -533,7 +540,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uRead == &subsetOp.getSourceOperand() &&
uConflictingWrite == &subsetOp.getDestinationOperand() &&
- matchesInsertDestination(state, uRead->get(), subsetOp))
+ matchesInsertDestination(state, uRead, subsetOp))
// Case 2: The read of the source tensor and the write to the dest
// tensor via an InsertSliceOp is not a conflict if the read is
// reading exactly that part of an equivalent tensor that the
@@ -567,8 +574,7 @@ static bool areNonConflictingSubsets(OpOperand *uRead,
if (uConflictingWrite == &subsetOp.getDestinationOperand() &&
state.areEquivalentBufferizedValues(
uRead->get(), subsetOp.getSourceOperand().get()) &&
- matchesInsertDestination(state, subsetOp.getSourceOperand().get(),
- subsetOp))
+ matchesInsertDestination(state, &subsetOp.getSourceOperand(), subsetOp))
return true;
return false;
@@ -600,9 +606,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// even though that op just bufferizes to an allocation but does define
// the contents of the buffer.
SetVector<Value> definitionsOrLeaves =
- state.findValueInReverseUseDefChain(
- uConflictingWrite->get(),
- [&](Value v) { return state.bufferizesToMemoryWrite(v); });
+ state.findValueInReverseUseDefChain(uConflictingWrite, [&](Value v) {
+ return state.bufferizesToMemoryWrite(v);
+ });
assert(!definitionsOrLeaves.empty() &&
"expected at least one definition or leaf");
@@ -641,8 +647,7 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
// In the above example, if uRead is the OpOperand of reading_op, the
// definition is %0. Note that operations that create an alias but do not
// bufferize to a memory write (such as ExtractSliceOp) are skipped.
- const SetVector<Value> &definitions =
- state.findDefinitionsCached(uRead->get());
+ const SetVector<Value> &definitions = state.findDefinitionsCached(uRead);
if (definitions.empty()) {
// Fast path: No conflict if there are no definitions.
LLVM_DEBUG(llvm::dbgs()
@@ -714,9 +719,9 @@ hasReadAfterWriteInterference(const DenseSet<OpOperand *> &usesRead,
if (bufferizableOp.bufferizesToElementwiseAccess(
state, {uRead, uConflictingWrite})) {
if (hasEquivalentValueInReverseUseDefChain(
- state, uRead->get(), uConflictingWrite->get()) ||
+ state, uRead, uConflictingWrite->get()) ||
hasEquivalentValueInReverseUseDefChain(
- state, uConflictingWrite->get(), uRead->get())) {
+ state, uConflictingWrite, uRead->get())) {
LLVM_DEBUG(
llvm::dbgs()
<< " no conflict: op bufferizes to element-wise access\n");
@@ -965,11 +970,12 @@ wouldCreateWriteToNonWritableBuffer(OpOperand &operand,
// Bufferization analyses.
//===----------------------------------------------------------------------===//
-// Find the values that define the contents of the given value.
+// Find the values that define the contents of the given operand's value.
const llvm::SetVector<Value> &
-OneShotAnalysisState::findDefinitionsCached(Value value) {
+OneShotAnalysisState::findDefinitionsCached(OpOperand *opOperand) {
+ Value value = opOperand->get();
if (!cachedDefinitions.count(value))
- cachedDefinitions[value] = findDefinitions(value);
+ cachedDefinitions[value] = findDefinitions(opOperand);
return cachedDefinitions[value];
}