diff options
| author | Matthias Springer <mspringer@nvidia.com> | 2024-11-10 08:29:29 +0100 |
|---|---|---|
| committer | Matthias Springer <mspringer@nvidia.com> | 2024-11-10 08:36:15 +0100 |
| commit | d17454e081e167ab391ae1b9f21115591418e614 (patch) | |
| tree | a7e907b9d52638e2c9bfbf6980a8237d6ff6434d | |
| parent | e4c14190bb097162e15cd5822b3de97ea7bac0d6 (diff) | |
[mlir][Transforms] CSE: Add filter options to control CSE'ingusers/matthias-springer/cse_filter
This commit adds two new pass options that gives users more fine-grained control over which ops are CSE'd / DCE'd.
* `barrier-op-filter` specifies ops that should act as CSE'ing barriers. I.e., ops that are nested inside such ops should not be CSE'd with ops that are outside of such ops. (Until now, the only CSE'ing barrier used to be IsolatedFromAbove ops.)
* `eliminate-op-filter` specifies ops that are subject to elimination. All non-matching ops are ignored by the CSE pass and remain in place. (If the filter is empty, all ops are subject to elimination.)
| -rw-r--r-- | mlir/include/mlir/Transforms/CSE.h | 27 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/Passes.h | 2 | ||||
| -rw-r--r-- | mlir/include/mlir/Transforms/Passes.td | 17 | ||||
| -rw-r--r-- | mlir/lib/Transforms/CSE.cpp | 47 | ||||
| -rw-r--r-- | mlir/test/Transforms/cse.mlir | 77 |
5 files changed, 153 insertions, 17 deletions
diff --git a/mlir/include/mlir/Transforms/CSE.h b/mlir/include/mlir/Transforms/CSE.h index 3d01ece07805..4edca3e3369f 100644 --- a/mlir/include/mlir/Transforms/CSE.h +++ b/mlir/include/mlir/Transforms/CSE.h @@ -13,19 +13,44 @@ #ifndef MLIR_TRANSFORMS_CSE_H_ #define MLIR_TRANSFORMS_CSE_H_ +#include <functional> + namespace mlir { class DominanceInfo; class Operation; class RewriterBase; +/// Configuration for CSE. +struct CSEConfig { + /// If set, matching ops act as a CSE'ing barrier: ops are not CSE'd across + /// matching ops. + /// + /// Note: IsolatedFromAbove ops are always a CSE'ing barrier, regardless of + /// this filter. + /// + /// Example: + /// %0 = arith.constant 0 : index + /// scf.for ... { + /// %1 = arith.constant 0 : index + /// ... + /// } + /// If "scf.for" is marked as a CSE'ing barrier, %0 and %1 are *not* CSE'd. + std::function<bool(Operation *)> barrierOpFilter = nullptr; + + /// If set, matching ops are not eliminated (neither CSE'd nor DCE'd). All + /// non-matching ops are subject to elimination. + std::function<bool(Operation *)> eliminateOpFilter = nullptr; +}; + /// Eliminate common subexpressions within the given operation. This transform /// looks for and deduplicates equivalent operations. /// /// `changed` indicates whether the IR was modified or not. void eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, - bool *changed = nullptr); + bool *changed = nullptr, + CSEConfig config = CSEConfig()); } // namespace mlir diff --git a/mlir/include/mlir/Transforms/Passes.h b/mlir/include/mlir/Transforms/Passes.h index 5c977055e95d..41f208216374 100644 --- a/mlir/include/mlir/Transforms/Passes.h +++ b/mlir/include/mlir/Transforms/Passes.h @@ -33,7 +33,7 @@ class GreedyRewriteConfig; #define GEN_PASS_DECL_CANONICALIZER #define GEN_PASS_DECL_CONTROLFLOWSINK -#define GEN_PASS_DECL_CSEPASS +#define GEN_PASS_DECL_CSE #define GEN_PASS_DECL_INLINER #define GEN_PASS_DECL_LOOPINVARIANTCODEMOTION #define GEN_PASS_DECL_MEM2REG diff --git a/mlir/include/mlir/Transforms/Passes.td b/mlir/include/mlir/Transforms/Passes.td index 000d9f697618..429029f21eb3 100644 --- a/mlir/include/mlir/Transforms/Passes.td +++ b/mlir/include/mlir/Transforms/Passes.td @@ -81,12 +81,25 @@ def CSE : Pass<"cse"> { let summary = "Eliminate common sub-expressions"; let description = [{ This pass implements a generalized algorithm for common sub-expression - elimination. This pass relies on information provided by the - `Memory SideEffect` interface to identify when it is safe to eliminate + elimination. The pass also eliminates dead operation (DCE). The pass + relies on information provided by the `MemoryEffectOpInterface` + interface and on `DominanceInfo` to identify when it is safe to eliminate operations. See [Common subexpression elimination](https://en.wikipedia.org/wiki/Common_subexpression_elimination) for more general details on this optimization. + + The types of ops that are subject to elimination can be configured with + `eliminate-op-filter`. If set, only those ops are CSE'd or DCE'd. + + Ops are never CSE'd across IsolatedFromAbove ops. Additional CSE'ing + barrier ops can be specified with `barrier-op-filter`. }]; let constructor = "mlir::createCSEPass()"; + let options = [ + ListOption<"barrierOpFilter", "barrier-op-filter", "std::string", + "Names of ops that act as CSE'ing barriers">, + ListOption<"eliminateOpFilter", "eliminate-op-filter", "std::string", + "If non-empty, list of ops that are subject to elimination">, + ]; let statistics = [ Statistic<"numCSE", "num-cse'd", "Number of operations CSE'd">, Statistic<"numDCE", "num-dce'd", "Number of operations DCE'd"> diff --git a/mlir/lib/Transforms/CSE.cpp b/mlir/lib/Transforms/CSE.cpp index 3affd88d158d..93ac35db276d 100644 --- a/mlir/lib/Transforms/CSE.cpp +++ b/mlir/lib/Transforms/CSE.cpp @@ -23,8 +23,9 @@ #include "llvm/ADT/ScopedHashTable.h" #include "llvm/Support/Allocator.h" #include "llvm/Support/RecyclingAllocator.h" -#include <deque> +#include <deque> +#include <unordered_set> namespace mlir { #define GEN_PASS_DEF_CSE #include "mlir/Transforms/Passes.h.inc" @@ -60,8 +61,9 @@ namespace { /// Simple common sub-expression elimination. class CSEDriver { public: - CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo) - : rewriter(rewriter), domInfo(domInfo) {} + CSEDriver(RewriterBase &rewriter, DominanceInfo *domInfo, + const CSEConfig &config) + : rewriter(rewriter), domInfo(domInfo), config(config) {} /// Simplify all operations within the given op. void simplify(Operation *op, bool *changed = nullptr); @@ -125,6 +127,9 @@ private: // Various statistics. int64_t numCSE = 0; int64_t numDCE = 0; + + /// CSE configuration. + CSEConfig config; }; } // namespace @@ -226,6 +231,10 @@ bool CSEDriver::hasOtherSideEffectingOpInBetween(Operation *fromOp, LogicalResult CSEDriver::simplifyOperation(ScopedMapTy &knownValues, Operation *op, bool hasSSADominance) { + // Don't simplify operations that are filtered out. + if (config.eliminateOpFilter && !config.eliminateOpFilter(op)) + return failure(); + // Don't simplify terminator operations. if (op->hasTrait<OpTrait::IsTerminator>()) return failure(); @@ -288,8 +297,11 @@ void CSEDriver::simplifyBlock(ScopedMapTy &knownValues, Block *bb, if (op.getNumRegions() != 0) { // If this operation is isolated above, we can't process nested regions // with the given 'knownValues' map. This would cause the insertion of - // implicit captures in explicit capture only regions. - if (op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>()) { + // implicit captures in explicit capture only regions. Additional barrier + // ops can be specified by the user. + bool isBarrier = op.mightHaveTrait<OpTrait::IsIsolatedFromAbove>() || + (config.barrierOpFilter && config.barrierOpFilter(&op)); + if (isBarrier) { ScopedMapTy nestedKnownValues; for (auto ®ion : op.getRegions()) simplifyRegion(nestedKnownValues, region); @@ -381,8 +393,8 @@ void CSEDriver::simplify(Operation *op, bool *changed) { void mlir::eliminateCommonSubExpressions(RewriterBase &rewriter, DominanceInfo &domInfo, Operation *op, - bool *changed) { - CSEDriver driver(rewriter, &domInfo); + bool *changed, CSEConfig config) { + CSEDriver driver(rewriter, &domInfo, config); driver.simplify(op, changed); } @@ -394,9 +406,28 @@ struct CSE : public impl::CSEBase<CSE> { } // namespace void CSE::runOnOperation() { + // Set up CSE configuration from pass options. + CSEConfig config; + std::unordered_set<std::string> barrierOpNames; + for (std::string opName : barrierOpFilter) + barrierOpNames.insert(opName); + std::unordered_set<std::string> eliminateOpNames; + for (std::string opName : eliminateOpFilter) + eliminateOpNames.insert(opName); + if (!barrierOpNames.empty()) { + config.barrierOpFilter = [&](Operation *op) -> bool { + return barrierOpNames.count(op->getName().getStringRef().str()); + }; + } + if (!eliminateOpNames.empty()) { + config.eliminateOpFilter = [&](Operation *op) -> bool { + return eliminateOpNames.count(op->getName().getStringRef().str()); + }; + } + // Simplify the IR. IRRewriter rewriter(&getContext()); - CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>()); + CSEDriver driver(rewriter, &getAnalysis<DominanceInfo>(), config); bool changed = false; driver.simplify(getOperation(), &changed); diff --git a/mlir/test/Transforms/cse.mlir b/mlir/test/Transforms/cse.mlir index 11a331026847..5d2da75db6ce 100644 --- a/mlir/test/Transforms/cse.mlir +++ b/mlir/test/Transforms/cse.mlir @@ -1,32 +1,47 @@ -// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' | FileCheck %s - -// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> -#map0 = affine_map<(d0) -> (d0 mod 2)> +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse))' -split-input-file | FileCheck %s +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="eliminate-op-filter=arith.constant"))' -split-input-file | FileCheck %s --check-prefix=CHECK-ELIMINATE-FILTER +// RUN: mlir-opt -allow-unregistered-dialect %s -pass-pipeline='builtin.module(func.func(cse="barrier-op-filter=affine.for"))' -split-input-file | FileCheck %s --check-prefix=CHECK-BARRIER-FILTER // CHECK-LABEL: @simple_constant +// CHECK-ELIMINATE-FILTER-LABEL: @simple_constant func.func @simple_constant() -> (i32, i32) { // CHECK-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32 + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_c1_i32:.*]] = arith.constant 1 : i32 %0 = arith.constant 1 : i32 // CHECK-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32 + // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_c1_i32]], %[[VAR_c1_i32]] : i32, i32 %1 = arith.constant 1 : i32 return %0, %1 : i32, i32 } +// ----- + +// CHECK-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> +// CHECK-ELIMINATE-FILTER-DAG: #[[$MAP:.*]] = affine_map<(d0) -> (d0 mod 2)> +#map0 = affine_map<(d0) -> (d0 mod 2)> + // CHECK-LABEL: @basic +// CHECK-ELIMINATE-FILTER-LABEL: @basic func.func @basic() -> (index, index) { // CHECK: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index + // CHECK-ELIMINATE-FILTER: %[[VAR_c0:[0-9a-zA-Z_]+]] = arith.constant 0 : index %c0 = arith.constant 0 : index %c1 = arith.constant 0 : index // CHECK-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_0:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) + // CHECK-ELIMINATE-FILTER-NEXT: %[[VAR_1:[0-9a-zA-Z_]+]] = affine.apply #[[$MAP]](%[[VAR_c0]]) %0 = affine.apply #map0(%c0) %1 = affine.apply #map0(%c1) // CHECK-NEXT: return %[[VAR_0]], %[[VAR_0]] : index, index + // CHECK-ELIMINATE-FILTER-NEXT: return %[[VAR_0]], %[[VAR_1]] : index, index return %0, %1 : index, index } +// ----- + // CHECK-LABEL: @many func.func @many(f32, f32) -> (f32) { ^bb0(%a : f32, %b : f32): @@ -52,6 +67,8 @@ func.func @many(f32, f32) -> (f32) { return %l : f32 } +// ----- + /// Check that operations are not eliminated if they have different operands. // CHECK-LABEL: @different_ops func.func @different_ops() -> (i32, i32) { @@ -64,6 +81,8 @@ func.func @different_ops() -> (i32, i32) { return %0, %1 : i32, i32 } +// ----- + /// Check that operations are not eliminated if they have different result /// types. // CHECK-LABEL: @different_results @@ -77,6 +96,8 @@ func.func @different_results(%arg0: tensor<*xf32>) -> (tensor<?x?xf32>, tensor<4 return %0, %1 : tensor<?x?xf32>, tensor<4x?xf32> } +// ----- + /// Check that operations are not eliminated if they have different attributes. // CHECK-LABEL: @different_attributes func.func @different_attributes(index, index) -> (i1, i1, i1) { @@ -93,6 +114,8 @@ func.func @different_attributes(index, index) -> (i1, i1, i1) { return %0, %1, %2 : i1, i1, i1 } +// ----- + /// Check that operations with side effects are not eliminated. // CHECK-LABEL: @side_effect func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) { @@ -106,22 +129,32 @@ func.func @side_effect() -> (memref<2x1xf32>, memref<2x1xf32>) { return %0, %1 : memref<2x1xf32>, memref<2x1xf32> } +// ----- + /// Check that operation definitions are properly propagated down the dominance /// tree. // CHECK-LABEL: @down_propagate_for +// CHECK-BARRIER-FILTER-LABEL: @down_propagate_for func.func @down_propagate_for() { // CHECK: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 + // CHECK-BARRIER-FILTER: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %0 = arith.constant 1 : i32 // CHECK-NEXT: affine.for {{.*}} = 0 to 4 { + // CHECK-BARRIER-FILTER-NEXT: affine.for {{.*}} = 0 to 4 { affine.for %i = 0 to 4 { - // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> () + // CHECK-BARRIER-FILTER-NEXT: %[[VAR2_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 %1 = arith.constant 1 : i32 + + // CHECK-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR_c1_i32]]) : (i32, i32) -> () + // CHECK-BARRIER-FILTER-NEXT: "foo"(%[[VAR_c1_i32]], %[[VAR2_c1_i32]]) : (i32, i32) -> () "foo"(%0, %1) : (i32, i32) -> () } return } +// ----- + // CHECK-LABEL: @down_propagate func.func @down_propagate() -> i32 { // CHECK-NEXT: %[[VAR_c1_i32:[0-9a-zA-Z_]+]] = arith.constant 1 : i32 @@ -142,6 +175,8 @@ func.func @down_propagate() -> i32 { return %arg : i32 } +// ----- + /// Check that operation definitions are NOT propagated up the dominance tree. // CHECK-LABEL: @up_propagate_for func.func @up_propagate_for() -> i32 { @@ -159,6 +194,8 @@ func.func @up_propagate_for() -> i32 { return %1 : i32 } +// ----- + // CHECK-LABEL: func @up_propagate func.func @up_propagate() -> i32 { // CHECK-NEXT: %[[VAR_c0_i32:[0-9a-zA-Z_]+]] = arith.constant 0 : i32 @@ -188,6 +225,8 @@ func.func @up_propagate() -> i32 { return %add : i32 } +// ----- + /// The same test as above except that we are testing on a cfg embedded within /// an operation region. // CHECK-LABEL: func @up_propagate_region @@ -221,6 +260,8 @@ func.func @up_propagate_region() -> i32 { return %0 : i32 } +// ----- + /// This test checks that nested regions that are isolated from above are /// properly handled. // CHECK-LABEL: @nested_isolated @@ -248,6 +289,8 @@ func.func @nested_isolated() -> i32 { return %0 : i32 } +// ----- + /// This test is checking that CSE gracefully handles values in graph regions /// where the use occurs before the def, and one of the defs could be CSE'd with /// the other. @@ -269,6 +312,8 @@ func.func @use_before_def() { return } +// ----- + /// This test is checking that CSE is removing duplicated read op that follow /// other. // CHECK-LABEL: @remove_direct_duplicated_read_op @@ -281,6 +326,8 @@ func.func @remove_direct_duplicated_read_op() -> i32 { return %2 : i32 } +// ----- + /// This test is checking that CSE is removing duplicated read op that follow /// other. // CHECK-LABEL: @remove_multiple_duplicated_read_op @@ -300,6 +347,8 @@ func.func @remove_multiple_duplicated_read_op() -> i64 { return %6 : i64 } +// ----- + /// This test is checking that CSE is not removing duplicated read op that /// have write op in between. // CHECK-LABEL: @dont_remove_duplicated_read_op_with_sideeffecting @@ -314,6 +363,8 @@ func.func @dont_remove_duplicated_read_op_with_sideeffecting() -> i32 { return %2 : i32 } +// ----- + // Check that an operation with a single region can CSE. func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { @@ -332,6 +383,8 @@ func.func @cse_single_block_ops(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) // CHECK-NOT: test.cse_of_single_block_op // CHECK: return %[[OP]], %[[OP]] +// ----- + // Operations with different number of bbArgs dont CSE. func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { @@ -350,6 +403,8 @@ func.func @no_cse_varied_bbargs(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op // CHECK: return %[[OP0]], %[[OP1]] +// ----- + // Operations with different regions dont CSE func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>) -> (tensor<?x?xf32>, tensor<?x?xf32>) { @@ -368,6 +423,8 @@ func.func @no_cse_region_difference_simple(%a : tensor<?x?xf32>, %b : tensor<?x? // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op // CHECK: return %[[OP0]], %[[OP1]] +// ----- + // Operation with identical region with multiple statements CSE. func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1) -> (tensor<?x?xf32>, tensor<?x?xf32>) { @@ -392,6 +449,8 @@ func.func @cse_single_block_ops_identical_bodies(%a : tensor<?x?xf32>, %b : tens // CHECK-NOT: test.cse_of_single_block_op // CHECK: return %[[OP]], %[[OP]] +// ----- + // Operation with non-identical regions dont CSE. func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : tensor<?x?xf32>, %c : f32, %d : i1) -> (tensor<?x?xf32>, tensor<?x?xf32>) { @@ -416,6 +475,8 @@ func.func @no_cse_single_block_ops_different_bodies(%a : tensor<?x?xf32>, %b : t // CHECK: %[[OP1:.+]] = test.cse_of_single_block_op // CHECK: return %[[OP0]], %[[OP1]] +// ----- + func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor<2xi1>) -> (tensor<2xi1>, tensor<2xi1>) { %false_2 = arith.constant false %true_5 = arith.constant true @@ -438,6 +499,8 @@ func.func @failing_issue_59135(%arg0: tensor<2x2xi1>, %arg1: f32, %arg2 : tensor // CHECK: test.region_yield %[[TRUE]] // CHECK: return %[[OP]], %[[OP]] +// ----- + func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, tensor<5xf32>) { %r1 = scf.if %c -> (tensor<5xf32>) { %0 = tensor.empty() : tensor<5xf32> @@ -463,6 +526,8 @@ func.func @cse_multiple_regions(%c: i1, %t: tensor<5xf32>) -> (tensor<5xf32>, te // CHECK-NOT: scf.if // CHECK: return %[[if]], %[[if]] +// ----- + // CHECK-LABEL: @cse_recursive_effects_success func.func @cse_recursive_effects_success() -> (i32, i32, i32) { // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32 @@ -492,6 +557,8 @@ func.func @cse_recursive_effects_success() -> (i32, i32, i32) { return %0, %2, %1 : i32, i32, i32 } +// ----- + // CHECK-LABEL: @cse_recursive_effects_failure func.func @cse_recursive_effects_failure() -> (i32, i32, i32) { // CHECK-NEXT: %[[READ_VALUE:.*]] = "test.op_with_memread"() : () -> i32 |
