diff options
| author | Simone Pellegrini <simone.pellegrini@arm.com> | 2025-11-19 15:52:27 +0100 |
|---|---|---|
| committer | GitHub <noreply@github.com> | 2025-11-19 14:52:27 +0000 |
| commit | 71e3de8a7f1c0fc71302ac84c826f34fa324ee1c (patch) | |
| tree | 78983e69525eadf323b29edf01c2932f21bbbd20 /mlir | |
| parent | c62fc065b4c10370c1aa68cad6f5fa980b640136 (diff) | |
[mlir][vector] Missing indices on vectorization of 1-d reduction to 1-ranked memref (#166959)
Vectorization of a 1-d reduction where the output variable is a 1-ranked
memref can generate an invalid `vector.transfer_write` with no indices
for the memref, e.g.:
vector.transfer_write"(%vec, %buff) <{...}> : (vector<f32>,
memref<1xf32>) -> ()
This patch solves the problem by providing the expected amount of
indices (i.e. matching the rank of the memref).
Diffstat (limited to 'mlir')
| -rw-r--r-- | mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp | 8 | ||||
| -rw-r--r-- | mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir | 68 |
2 files changed, 64 insertions, 12 deletions
diff --git a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp index dcf84c46949f..bb3bccdae0e1 100644 --- a/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp +++ b/mlir/lib/Dialect/Linalg/Transforms/Vectorization.cpp @@ -746,12 +746,12 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, auto vectorType = state.getCanonicalVecType( getElementTypeOrSelf(outputOperand->get().getType()), vectorTypeMap); + SmallVector<Value> indices(linalgOp.getRank(outputOperand), + arith::ConstantIndexOp::create(rewriter, loc, 0)); + Operation *write; if (vectorType.getRank() > 0) { AffineMap writeMap = inversePermutation(reindexIndexingMap(opOperandMap)); - SmallVector<Value> indices( - linalgOp.getRank(outputOperand), - arith::ConstantIndexOp::create(rewriter, loc, 0)); value = broadcastIfNeeded(rewriter, value, vectorType); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create( @@ -762,7 +762,7 @@ static Value buildVectorWrite(RewriterBase &rewriter, Value value, value = vector::BroadcastOp::create(rewriter, loc, vectorType, value); assert(value.getType() == vectorType && "Incorrect type"); write = vector::TransferWriteOp::create(rewriter, loc, value, - outputOperand->get(), ValueRange{}); + outputOperand->get(), indices); } write = state.maskOperation(rewriter, write, linalgOp, opOperandMap); diff --git a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir index 9a14ab7d38d3..95959fcf085f 100644 --- a/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir +++ b/mlir/test/Dialect/Linalg/vectorization/linalg-ops-with-patterns.mlir @@ -1481,23 +1481,23 @@ module attributes {transform.with_named_sequence} { // ----- -// CHECK-LABEL: func @reduce_1d( -// CHECK-SAME: %[[A:.*]]: tensor<32xf32> -func.func @reduce_1d(%arg0: tensor<32xf32>) -> tensor<f32> { +// CHECK-LABEL: func @reduce_to_rank_0( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_0(%arg0: tensor<32xf32>) -> tensor<f32> { // CHECK-DAG: %[[F0:.*]] = arith.constant 0.000000e+00 : f32 // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index %f0 = arith.constant 0.000000e+00 : f32 - // CHECK: %[[init:.*]] = tensor.empty() : tensor<f32> + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<f32> %0 = tensor.empty() : tensor<f32> %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<f32>) -> tensor<f32> - // CHECK: %[[r:.*]] = vector.transfer_read %[[A]][%[[C0]]] + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] // CHECK-SAME: : tensor<32xf32>, vector<32xf32> - // CHECK: %[[red:.*]] = vector.multi_reduction <add>, %[[r]], %[[F0]] [0] + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[F0]] [0] // CHECK-SAME: : vector<32xf32> to f32 - // CHECK: %[[red_v1:.*]] = vector.broadcast %[[red]] : f32 to vector<f32> - // CHECK: %[[res:.*]] = vector.transfer_write %[[red_v1]], %[[init]][] + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: %[[RES:.*]] = vector.transfer_write %[[RED_V1]], %[[INIT]][] // CHECK-SAME: : vector<f32>, tensor<f32> %2 = linalg.generic { indexing_maps = [affine_map<(d0) -> (d0)>, @@ -1525,6 +1525,58 @@ module attributes {transform.with_named_sequence} { // ----- +// CHECK-LABEL: func @reduce_to_rank_1( +// CHECK-SAME: %[[SRC:.*]]: tensor<32xf32> +func.func @reduce_to_rank_1(%arg0: tensor<32xf32>) -> tensor<1xf32> { + // CHECK-DAG: %[[C0:.*]] = arith.constant 0 : index + // CHECK-DAG: %[[F0:.*]] = arith.constant dense<0.000000e+00> : vector<1xf32> + %f0 = arith.constant 0.000000e+00 : f32 + + // CHECK: %[[INIT:.*]] = tensor.empty() : tensor<1xf32> + %0 = tensor.empty() : tensor<1xf32> + + // CHECK: %[[INIT_ZERO:.*]] = vector.transfer_write %[[F0]], %[[INIT]][%[[C0]]] + // CHECK-SAME: : vector<1xf32>, tensor<1xf32> + %1 = linalg.fill ins(%f0 : f32) outs(%0 : tensor<1xf32>) -> tensor<1xf32> + + // CHECK: %[[R:.*]] = vector.transfer_read %[[SRC]][%[[C0]]] + // CHECK-SAME: : tensor<32xf32>, vector<32xf32> + // CHECK: %[[INIT_ZERO_VEC:.*]] = vector.transfer_read %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : tensor<1xf32>, vector<f32> + // CHECK: %[[INIT_ZERO_SCL:.*]] = vector.extract %[[INIT_ZERO_VEC]][] + // CHECK-SAME: : f32 from vector<f32> + // CHECK: %[[RED:.*]] = vector.multi_reduction <add>, %[[R]], %[[INIT_ZERO_SCL]] [0] + // CHECK-SAME: : vector<32xf32> to f32 + // CHECK: %[[RED_V1:.*]] = vector.broadcast %[[RED]] : f32 to vector<f32> + // CHECK: vector.transfer_write %[[RED_V1]], %[[INIT_ZERO]][%[[C0]]] + // CHECK-SAME: : vector<f32>, tensor<1xf32> + + %2 = linalg.generic { + indexing_maps = [affine_map<(d0) -> (d0)>, + affine_map<(d0) -> (0)>], + iterator_types = ["reduction"]} + ins(%arg0 : tensor<32xf32>) + outs(%1 : tensor<1xf32>) { + ^bb0(%a: f32, %b: f32): + %3 = arith.addf %a, %b : f32 + linalg.yield %3 : f32 + } -> tensor<1xf32> + + return %2 : tensor<1xf32> +} + +module attributes {transform.with_named_sequence} { + transform.named_sequence @__transform_main(%arg1: !transform.any_op {transform.readonly}) { + %0 = transform.structured.match ops{["linalg.generic"]} in %arg1 : (!transform.any_op) -> !transform.any_op + %1 = transform.get_parent_op %0 {isolated_from_above} : (!transform.any_op) -> !transform.any_op + %2 = transform.structured.vectorize_children_and_apply_patterns %1 : (!transform.any_op) -> !transform.any_op + transform.yield + } +} + + +// ----- + // This test checks that vectorization does not occur when an input indexing map // is not a projected permutation. In the future, this can be converted to a // positive test when support is added. |
