summaryrefslogtreecommitdiff
path: root/mlir/test/python/python_pass.py
diff options
context:
space:
mode:
authorTwice <twice@apache.org>2025-09-09 09:01:23 +0800
committerGitHub <noreply@github.com>2025-09-08 18:01:23 -0700
commit7d04e3790483f8e2f168ec538682bc31d3011a0b (patch)
tree0c0f255cade6a3f1e36dfe9223032864484d0b90 /mlir/test/python/python_pass.py
parent82ef4ee725a459c137dd1f5419cf26deb15a14c8 (diff)
[MLIR][Python] Support Python-defined passes in MLIR (#156000)
It closes #155996. This PR added a method `add(callable, ..)` to `mlir.passmanager.PassManager` to accept a callable object for defining passes in the Python side. This is a simple example of a Python-defined pass. ```python from mlir.passmanager import PassManager def demo_pass_1(op): # do something with op pass class DemoPass: def __init__(self, ...): pass def __call__(op): # do something pass demo_pass_2 = DemoPass(..) pm = PassManager('any', ctx) pm.add(demo_pass_1) pm.add(demo_pass_2) pm.add("registered-passes") pm.run(..) ``` --------- Co-authored-by: cnb.bsD2OPwAgEA <QejD2DJ2eEahUVy6Zg0aZI+cnb.bsD2OPwAgEA@noreply.cnb.cool> Co-authored-by: Maksim Levental <maksim.levental@gmail.com>
Diffstat (limited to 'mlir/test/python/python_pass.py')
-rw-r--r--mlir/test/python/python_pass.py88
1 files changed, 88 insertions, 0 deletions
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
new file mode 100644
index 000000000000..c94f96e20966
--- /dev/null
+++ b/mlir/test/python/python_pass.py
@@ -0,0 +1,88 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+import gc, sys
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import pdl
+from mlir.rewrite import *
+
+
+def log(*args):
+ print(*args, file=sys.stderr)
+ sys.stderr.flush()
+
+
+def run(f):
+ log("\nTEST:", f.__name__)
+ f()
+ gc.collect()
+ assert Context._get_live_count() == 0
+
+
+def make_pdl_module():
+ with Location.unknown():
+ pdl_module = Module.create()
+ with InsertionPoint(pdl_module.body):
+ # Change all arith.addi with index types to arith.muli.
+ @pdl.pattern(benefit=1, sym_name="addi_to_mul")
+ def pat():
+ # Match arith.addi with index types.
+ i64_type = pdl.TypeOp(IntegerType.get_signless(64))
+ operand0 = pdl.OperandOp(i64_type)
+ operand1 = pdl.OperandOp(i64_type)
+ op0 = pdl.OperationOp(
+ name="arith.addi", args=[operand0, operand1], types=[i64_type]
+ )
+
+ # Replace the matched op with arith.muli.
+ @pdl.rewrite()
+ def rew():
+ newOp = pdl.OperationOp(
+ name="arith.muli", args=[operand0, operand1], types=[i64_type]
+ )
+ pdl.ReplaceOp(op0, with_op=newOp)
+
+ return pdl_module
+
+
+# CHECK-LABEL: TEST: testCustomPass
+@run
+def testCustomPass():
+ with Context():
+ pdl_module = make_pdl_module()
+ frozen = PDLModule(pdl_module).freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ def custom_pass_1(op):
+ print("hello from pass 1!!!", file=sys.stderr)
+
+ class CustomPass2:
+ def __call__(self, m):
+ apply_patterns_and_fold_greedily(m, frozen)
+
+ custom_pass_2 = CustomPass2()
+
+ pm = PassManager("any")
+ pm.enable_ir_printing()
+
+ # CHECK: hello from pass 1!!!
+ # CHECK-LABEL: Dump After custom_pass_1
+ pm.add(custom_pass_1)
+ # CHECK-LABEL: Dump After CustomPass2
+ # CHECK: arith.muli
+ pm.add(custom_pass_2, "CustomPass2")
+ # CHECK-LABEL: Dump After ArithToLLVMConversionPass
+ # CHECK: llvm.mul
+ pm.add("convert-arith-to-llvm")
+ pm.run(module)