// RUN: mlir-opt -split-input-file -test-tensor-transform-patterns=test-expand-shape-bubbling %s | FileCheck %s

// The collapse groups dims [1, 2] and the expand splits dim 2 into [2, 3];
// the groups do not intersect, so the expand is bubbled above the collapse.
func.func @bubble_parallel_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1], [2, 3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @bubble_parallel_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
// CHECK-SAME:   %[[S0:.+]]: index, %[[S1:.+]]: index, %[[S2:.+]]: index, %[[S3:.+]]: index
//  CHECK-DAG:   %[[C1:.+]] = arith.constant 1 : index
//  CHECK-DAG:   %[[C2:.+]] = arith.constant 2 : index
//  CHECK-DAG:   %[[DIM1:.+]] = tensor.dim %[[ARG0]], %[[C1]] : tensor<?x?x?x?xf32>
//  CHECK-DAG:   %[[DIM2:.+]] = tensor.dim %[[ARG0]], %[[C2]] : tensor<?x?x?x?xf32>
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[ARG0]] {{\[}}[0], [1], [2], [3, 4]]
// CHECK-SAME:       output_shape [%[[S0]], %[[DIM1]], %[[DIM2]], %[[S2]], %[[S3]]] : tensor<?x?x?x?xf32> into tensor<?x?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[EXPAND]] {{\[}}[0], [1, 2], [3], [4]] : tensor<?x?x?x?x?xf32> into tensor<?x?x?x?xf32>
//      CHECK:   return %[[COLLAPSE]]

// -----

// The expand re-splits exactly the group the collapse merged ([1, 2]); no
// bubbling is expected.
func.func @no_bubble_full_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0], [1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?x?xf32>
  %expand = tensor.expand_shape %collapse [[0], [1, 2], [3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @no_bubble_full_intersecting_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0], [1, 2], [3]]
//      CHECK:   return %[[EXPAND]]

// -----

// The collapse group [0, 1, 2] and the expand groups [0, 1] / [2, 3] only
// partially overlap; no bubbling is expected.
func.func @no_bubble_partial_intersecting_reshapes(%arg0: tensor<?x?x?x?xf32>, %s0: index, %s1: index, %s2: index, %s3: index) -> tensor<?x?x?x?xf32> {
  %collapse = tensor.collapse_shape %arg0 [[0, 1, 2], [3]] : tensor<?x?x?x?xf32> into tensor<?x?xf32>
  %expand = tensor.expand_shape %collapse [[0, 1], [2, 3]]
              output_shape [%s0, %s1, %s2, %s3] : tensor<?x?xf32> into tensor<?x?x?x?xf32>
  return %expand : tensor<?x?x?x?xf32>
}
//      CHECK: func @no_bubble_partial_intersecting_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<?x?x?x?xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}[0, 1, 2], [3]]
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}[0, 1], [2, 3]]
//      CHECK:   return %[[EXPAND]]

// -----

// Reshapes through a 0-d tensor use empty reassociations; no bubbling is
// expected.
func.func @no_bubble_0d_tensor_reshapes(%arg0: tensor<1x1xf32>) -> tensor<1x1x1xf32> {
  %collapse = tensor.collapse_shape %arg0 [] : tensor<1x1xf32> into tensor<f32>
  %expand = tensor.expand_shape %collapse []
              output_shape [1, 1, 1] : tensor<f32> into tensor<1x1x1xf32>
  return %expand : tensor<1x1x1xf32>
}
//      CHECK: func @no_bubble_0d_tensor_reshapes
// CHECK-SAME:   %[[ARG0:.+]]: tensor<1x1xf32>
//      CHECK:   %[[COLLAPSE:.+]] = tensor.collapse_shape %[[ARG0]] {{\[}}]
//      CHECK:   %[[EXPAND:.+]] = tensor.expand_shape %[[COLLAPSE]] {{\[}}]
//      CHECK:   return %[[EXPAND]]

// -----

// Test the case where the reassociation indices in the collapse and expand
// are of the same size.
func.func @bubble_expand_match_non_unit_size_reassocation(
    %arg0 : tensor<4x?x4x32x4x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3, 4], [5]]
      : tensor<4x?x4x32x4x?xf16> into tensor<?x128x?xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
  return %expanded : tensor<4x?x4x128x?x32xf16>
}
//      CHECK: func @bubble_expand_match_non_unit_size_reassocation
// CHECK-SAME:     %[[ARG0:.+]]: tensor<4x?x4x32x4x?xf16>
// CHECK-SAME:     %[[ARG1:[a-zA-Z0-9]+]]: index
// CHECK-SAME:     %[[ARG2:[a-zA-Z0-9]+]]: index
//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape %[[ARG0]]
// CHECK-SAME:       {{\[}}[0], [1], [2], [3], [4], [5, 6]{{\]}}
// CHECK-SAME:       [4, %[[ARG1]], 4, 32, 4, %[[ARG2]], 32]
//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape %[[EXPANDED]]
// CHECK-SAME:       {{\[}}[0], [1], [2], [3, 4], [5], [6]{{\]}}
//      CHECK:   return %[[COLLAPSED]]

// -----

// Test the case where the trailing collapse isn't needed.
func.func @no_collapse_generated(
    %arg0 : tensor<4x?x4x128x?xf16>, %arg1 : index, %arg2 : index) -> tensor<4x?x4x128x?x32xf16> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4]]
      : tensor<4x?x4x128x?xf16> into tensor<?x128x?xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4, 5]] output_shape [4, %arg1, 4, 128, %arg2, 32]
      : tensor<?x128x?xf16> into tensor<4x?x4x128x?x32xf16>
  return %expanded : tensor<4x?x4x128x?x32xf16>
}
//      CHECK: func @no_collapse_generated
//      CHECK:   %[[EXPANDED:.+]] = tensor.expand_shape
//      CHECK:   return %[[EXPANDED]]

// -----

// Test the case where the leading expand isn't needed.
func.func @no_expand_generated(
    %arg0 : tensor<4x?x4x128x?x?x?xf16>, %arg1 : index, %arg2 : index, %arg3 : index) -> tensor<4x?x4x128x?x?xf16> {
  %collapsed = tensor.collapse_shape %arg0 [[0, 1, 2], [3], [4], [5, 6]]
      : tensor<4x?x4x128x?x?x?xf16> into tensor<?x128x?x?xf16>
  %expanded = tensor.expand_shape %collapsed [[0, 1, 2], [3], [4], [5]] output_shape [4, %arg1, 4, 128, %arg2, %arg3]
      : tensor<?x128x?x?xf16> into tensor<4x?x4x128x?x?xf16>
  return %expanded : tensor<4x?x4x128x?x?xf16>
}
// The rewritten function should end in a collapse_shape applied to the
// original argument's dims, with no leading expand_shape needed.
//      CHECK: func @no_expand_generated
//      CHECK:   %[[COLLAPSED:.+]] = tensor.collapse_shape
//      CHECK:   return %[[COLLAPSED]]