diff options
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7c019e7d25bf..8b5e950733a2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -341,13 +341,18 @@ private: /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the -/// distributed dimensions. The number of results should be equal to the number +/// distributed dimensions. If the number of results is zero there is no +/// distribution (i.e. original type is returned). +/// Otherwise, The number of results should be equal to the number /// of warp sizes which is currently limited to 1. /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) /// and a warp size of 16 would distribute the second dimension (associated to /// d1) and return vector<16x2x64> static VectorType getDistributedType(VectorType originalType, AffineMap map, int64_t warpSize) { + // If the map has zero results, return the original type. + if (map.getNumResults() == 0) + return originalType; SmallVector<int64_t> targetShape(originalType.getShape()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); |
