summaryrefslogtreecommitdiff
path: root/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
diff options
context:
space:
mode:
authorQuinn Dawkins <quinn.dawkins@gmail.com>2024-02-28 00:11:28 -0500
committerGitHub <noreply@github.com>2024-02-28 00:11:28 -0500
commitc2b952926fe8707527cf1b8bab211dc4c7ab9aee (patch)
tree65bb008dd9939b10bb4927d8e4d976b24e0123c5 /mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
parent87c0260f45e5a02cb07722d089dae3f0f84c7b3d (diff)
[mlir][vector] Fix n-d transfer write distribution (#83215)
Currently n-d transfer write distribution can be inconsistent with distribution of reductions if a value has multiple users, one of which is a transfer_write with a non-standard distribution map, and the other of which is a vector.reduction. We may want to consider removing the distribution map functionality in the future for this reason.
Diffstat (limited to 'mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp')
-rw-r--r--mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp6
1 files changed, 2 insertions, 4 deletions
diff --git a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
index 178a58e796b2..915f713f7047 100644
--- a/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
+++ b/mlir/test/lib/Dialect/Vector/TestVectorTransforms.cpp
@@ -630,15 +630,13 @@ struct TestVectorDistribution
});
MLIRContext *ctx = &getContext();
auto distributionFn = [](Value val) {
- // Create a map (d0, d1) -> (d1) to distribute along the inner
- // dimension. Once we support n-d distribution we can add more
- // complex cases.
+ // Create an identity dim map of the same rank as the vector.
VectorType vecType = dyn_cast<VectorType>(val.getType());
int64_t vecRank = vecType ? vecType.getRank() : 0;
OpBuilder builder(val.getContext());
if (vecRank == 0)
return AffineMap::get(val.getContext());
- return AffineMap::get(vecRank, 0, builder.getAffineDimExpr(vecRank - 1));
+ return AffineMap::getMultiDimIdentityMap(vecRank, val.getContext());
};
auto shuffleFn = [](Location loc, OpBuilder &builder, Value val,
Value srcIdx, int64_t warpSz) {