summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp')
-rw-r--r--mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp6
1 files changed, 5 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
index 0c8e431d8c99..c612a52aa8d5 100644
--- a/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
+++ b/mlir/lib/Dialect/SparseTensor/Transforms/Sparsification.cpp
@@ -1071,7 +1071,11 @@ static bool getAllTidLvlsInLatPoints(
}
// If we just need to one loop conditions and the conditions is not imposed on
// non-unique level, the loop can be generated by a for loop.
- return numloopCond == 1 && !hasNonUnique;
+ // Or, if we are generating sparse-iterator-based loops, we always generate
+ // `sparse_tensor.iterate` regardless whether the level is unique or not.
+ return numloopCond == 1 &&
+ (!hasNonUnique || env.options().sparseEmitStrategy ==
+ SparseEmitStrategy::kSparseIterator);
}
/// Starts a loop sequence at given level. Returns true if