summaryrefslogtreecommitdiff
path: root/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp')
-rw-r--r--mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp7
1 files changed, 6 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
index 7c019e7d25bf..8b5e950733a2 100644
--- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
+++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp
@@ -341,13 +341,18 @@ private:
/// Return the distributed vector type based on the original type and the
/// distribution map. The map is expected to have a dimension equal to the
/// original type rank and should be a projection where the results are the
-/// distributed dimensions. The number of results should be equal to the number
+/// distributed dimensions. If the number of results is zero there is no
+/// distribution (i.e. original type is returned).
+/// Otherwise, The number of results should be equal to the number
/// of warp sizes which is currently limited to 1.
/// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1)
/// and a warp size of 16 would distribute the second dimension (associated to
/// d1) and return vector<16x2x64>
static VectorType getDistributedType(VectorType originalType, AffineMap map,
int64_t warpSize) {
+ // If the map has zero results, return the original type.
+ if (map.getNumResults() == 0)
+ return originalType;
SmallVector<int64_t> targetShape(originalType.getShape());
for (unsigned i = 0, e = map.getNumResults(); i < e; i++) {
unsigned position = map.getDimPosition(i);