diff options
| author | Quinn Dawkins <quinn.dawkins@gmail.com> | 2024-02-28 00:11:28 -0500 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2024-02-28 00:11:28 -0500 |
| commit | c2b952926fe8707527cf1b8bab211dc4c7ab9aee (patch) | |
| tree | 65bb008dd9939b10bb4927d8e4d976b24e0123c5 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | |
| parent | 87c0260f45e5a02cb07722d089dae3f0f84c7b3d (diff) | |
[mlir][vector] Fix n-d transfer write distribution (#83215)
Currently n-d transfer write distribution can be inconsistent with
distribution of reductions if a value has multiple users, one of which
is a transfer_write with a non-standard distribution map, and the other
of which is a vector.reduction.
We may want to consider removing the distribution map functionality in
the future for this reason.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
| -rw-r--r-- | mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp | 6 |
1 files changed, 2 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp index 178a58e796b2..915f713f7047 100644 --- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp +++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp @@ -630,15 +630,13 @@ struct TestVectorDistribution }); MLIRContext *ctx = &getContext(); auto distributionFn = [](Value val) { - // Create a map (d0, d1) -> (d1) to distribute along the inner - // dimension. Once we support n-d distribution we can add more - // complex cases. + // Create an identity dim map of the same rank as the vector. VectorType vecType = dyn_cast<VectorType>(val.getType()); int64_t vecRank = vecType ? vecType.getRank() : 0; OpBuilder builder(val.getContext()); if (vecRank == 0) return AffineMap::get(val.getContext()); - return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1)); + return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext()); }; auto shuffleFn = [](Location loc, OpBuilder &builder, Value val, Value srcIdx, int64_t warpSz) { |
