summaryrefslogtreecommitdiff
path: root/mlir/test/python/python_pass.py
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python/python_pass.py')
-rw-r--r--mlir/test/python/python_pass.py20
1 files changed, 17 insertions, 3 deletions
diff --git a/mlir/test/python/python_pass.py b/mlir/test/python/python_pass.py
index c94f96e20966..50c42102f66d 100644
--- a/mlir/test/python/python_pass.py
+++ b/mlir/test/python/python_pass.py
@@ -64,12 +64,12 @@ def testCustomPass():
"""
)
- def custom_pass_1(op):
+ def custom_pass_1(op, pass_):
print("hello from pass 1!!!", file=sys.stderr)
class CustomPass2:
- def __call__(self, m):
- apply_patterns_and_fold_greedily(m, frozen)
+ def __call__(self, op, pass_):
+ apply_patterns_and_fold_greedily(op, frozen)
custom_pass_2 = CustomPass2()
@@ -86,3 +86,17 @@ def testCustomPass():
# CHECK: llvm.mul
pm.add("convert-arith-to-llvm")
pm.run(module)
+
+ # test signal_pass_failure
+ def custom_pass_that_fails(op, pass_):
+ print("hello from pass that fails")
+ pass_.signal_pass_failure()
+
+ pm = PassManager("any")
+ pm.add(custom_pass_that_fails, "CustomPassThatFails")
+ # CHECK: hello from pass that fails
+ # CHECK: caught exception: Failure while executing pass pipeline
+ try:
+ pm.run(module)
+ except Exception as e:
+ print(f"caught exception: {e}")