diff options
| author | Florian Mayer <fmayer@google.com> | 2025-10-22 10:55:10 -0700 |
|---|---|---|
| committer | Florian Mayer <fmayer@google.com> | 2025-10-22 10:55:10 -0700 |
| commit | f5f8398d7fe18a968f5873518e87d5fdd8269359 (patch) | |
| tree | 347dff286c3b48b2336fb7a425adfceebd478116 /mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | |
| parent | 73edaec4a6cd1212f9ae819c413d2cf58216d3b1 (diff) | |
| parent | a0abc0af0a0a90878822f8107d70dad6f7cdfc26 (diff) | |
Created using spr 1.3.7
Diffstat (limited to 'mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp')
| -rw-r--r-- | mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp | 7 |
1 files changed, 6 insertions, 1 deletions
diff --git a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp index 7c019e7d25bf..8b5e950733a2 100644 --- a/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp +++ b/mlir/lib/Dialect/Vector/Transforms/VectorDistribute.cpp @@ -341,13 +341,18 @@ private: /// Return the distributed vector type based on the original type and the /// distribution map. The map is expected to have a dimension equal to the /// original type rank and should be a projection where the results are the -/// distributed dimensions. The number of results should be equal to the number +/// distributed dimensions. If the number of results is zero there is no +/// distribution (i.e. original type is returned). +/// Otherwise, The number of results should be equal to the number /// of warp sizes which is currently limited to 1. /// Example: For a vector<16x32x64> distributed with a map(d0, d1, d2) -> (d1) /// and a warp size of 16 would distribute the second dimension (associated to /// d1) and return vector<16x2x64> static VectorType getDistributedType(VectorType originalType, AffineMap map, int64_t warpSize) { + // If the map has zero results, return the original type. + if (map.getNumResults() == 0) + return originalType; SmallVector<int64_t> targetShape(originalType.getShape()); for (unsigned i = 0, e = map.getNumResults(); i < e; i++) { unsigned position = map.getDimPosition(i); |
