summaryrefslogtreecommitdiff
path: root/mlir/test/python/python_pass.py
blob: 50c42102f66d33c585c3c8b21621c10319526fe8 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
# RUN: %PYTHON %s 2>&1 | FileCheck %s

import gc, sys
from mlir.ir import *
from mlir.passmanager import *
from mlir.dialects.builtin import ModuleOp
from mlir.dialects import pdl
from mlir.rewrite import *


def log(*args):
    """Print *args* to stderr, space-separated as ``print`` would, then flush.

    Flushing right away keeps this output interleaved in order with anything
    else written to stderr.
    """
    stream = sys.stderr
    print(*args, file=stream)
    stream.flush()


def run(f):
    """Decorator: run test function *f* immediately under a ``TEST:`` label.

    After *f* returns, force a garbage collection and assert that no MLIR
    ``Context`` objects are still alive. Nothing is returned, so the decorated
    name is rebound to ``None``.
    """
    test_name = f.__name__
    log("\nTEST:", test_name)
    f()
    # Release unreachable objects (e.g. ones kept alive only by reference
    # cycles) before counting live contexts.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0


def make_pdl_module():
    """Build a PDL module holding one rewrite pattern, ``addi_to_mul``.

    The pattern matches ``arith.addi`` ops whose two operands and single
    result are all ``i64``. It replaces each one with an ``arith.muli`` on
    the same operands.

    Assumes an MLIR ``Context`` is already active; the caller in this file
    enters one. ``Location.unknown()`` and ``IntegerType.get_signless`` are
    called without an explicit context.

    Returns:
        The ``Module`` containing the pattern, suitable for wrapping in
        ``PDLModule(...)`` and freezing.
    """
    with Location.unknown():
        pdl_module = Module.create()
        with InsertionPoint(pdl_module.body):
            # Rewrite every i64 arith.addi into arith.muli. The decorator is
            # applied inside the module-body insertion point. The ``pat``
            # name itself is not referenced afterwards.
            @pdl.pattern(benefit=1, sym_name="addi_to_mul")
            def pat():
                # Match arith.addi whose operands and result are i64.
                i64_type = pdl.TypeOp(IntegerType.get_signless(64))
                operand0 = pdl.OperandOp(i64_type)
                operand1 = pdl.OperandOp(i64_type)
                op0 = pdl.OperationOp(
                    name="arith.addi", args=[operand0, operand1], types=[i64_type]
                )

                # Replace the matched op with arith.muli on the same operands.
                @pdl.rewrite()
                def rew():
                    new_op = pdl.OperationOp(
                        name="arith.muli", args=[operand0, operand1], types=[i64_type]
                    )
                    pdl.ReplaceOp(op0, with_op=new_op)

        return pdl_module


# CHECK-LABEL: TEST: testCustomPass
@run
def testCustomPass():
    """Exercise Python-defined passes inside a ``PassManager``.

    With IR printing enabled, the first pipeline runs three passes:
    - a plain-function pass;
    - a callable-object pass that greedily applies the PDL addi->muli rewrite;
    - the built-in ``convert-arith-to-llvm`` pass.

    A second pipeline then checks that a Python pass calling
    ``signal_pass_failure`` makes ``PassManager.run`` raise.

    All messages are emitted through ``log`` (stderr, flushed) so they stay
    in order with the IR dumps when stdout and stderr are merged for checking.
    """
    with Context():
        pdl_module = make_pdl_module()
        frozen = PDLModule(pdl_module).freeze()

        module = ModuleOp.parse(
            r"""
            module {
              func.func @add(%a: i64, %b: i64) -> i64 {
                %sum = arith.addi %a, %b : i64
                return %sum : i64
              }
            }
        """
        )

        # Python passes here are callables taking (op, pass_).
        def custom_pass_1(op, pass_):
            log("hello from pass 1!!!")

        class CustomPass2:
            def __call__(self, op, pass_):
                apply_patterns_and_fold_greedily(op, frozen)

        custom_pass_2 = CustomPass2()

        pm = PassManager("any")
        pm.enable_ir_printing()

        # CHECK: hello from pass 1!!!
        # CHECK-LABEL: Dump After custom_pass_1
        pm.add(custom_pass_1)
        # CHECK-LABEL: Dump After CustomPass2
        # CHECK: arith.muli
        pm.add(custom_pass_2, "CustomPass2")
        # CHECK-LABEL: Dump After ArithToLLVMConversionPass
        # CHECK: llvm.mul
        pm.add("convert-arith-to-llvm")
        pm.run(module)

        # test signal_pass_failure
        def custom_pass_that_fails(op, pass_):
            # Use stderr (via log) rather than buffered stdout. Otherwise this
            # line would only surface at process exit, after any later test's
            # output.
            log("hello from pass that fails")
            pass_.signal_pass_failure()

        pm = PassManager("any")
        pm.add(custom_pass_that_fails, "CustomPassThatFails")
        # CHECK: hello from pass that fails
        # CHECK: caught exception: Failure while executing pass pipeline
        try:
            pm.run(module)
        except Exception as e:
            log(f"caught exception: {e}")