diff options
| author | Florian Mayer <fmayer@google.com> | 2025-10-13 13:23:20 -0700 |
|---|---|---|
| committer | Florian Mayer <fmayer@google.com> | 2025-10-13 13:23:20 -0700 |
| commit | ccc6fad8951b12dbe5fde7d8d00b9c959e36fa00 (patch) | |
| tree | b17aad8d549af17c7f8411ecd5ccc6ecfd4b1499 /mlir/test/python/rewrite.py | |
| parent | 82a427702eb2c83ff571670a73071420f0b31546 (diff) | |
| parent | 55d4e92c8821d5543469118a76fe38db866377b7 (diff) | |
[𝘀𝗽𝗿] changes introduced through rebaseusers/fmayer/spr/main.flowsensitive-statusor-2n-add-minimal-model
Created using spr 1.3.4
[skip ci]
Diffstat (limited to 'mlir/test/python/rewrite.py')
| -rw-r--r-- | mlir/test/python/rewrite.py | 70 |
1 files changed, 70 insertions, 0 deletions
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py new file mode 100644 index 000000000000..821e47085a5b --- /dev/null +++ b/mlir/test/python/rewrite.py @@ -0,0 +1,70 @@ +# RUN: %PYTHON %s 2>&1 | FileCheck %s + +from mlir.ir import * +from mlir.passmanager import * +from mlir.dialects.builtin import ModuleOp +from mlir.dialects import arith +from mlir.rewrite import * + + +def run(f): + print("\nTEST:", f.__name__) + f() + + +# CHECK-LABEL: TEST: testRewritePattern +@run +def testRewritePattern(): + def to_muli(op, rewriter): + with rewriter.ip: + assert isinstance(op, arith.AddIOp) + new_op = arith.muli(op.lhs, op.rhs, loc=op.location) + rewriter.replace_op(op, new_op.owner) + + def constant_1_to_2(op, rewriter): + c = op.value.value + if c != 1: + return True # failed to match + with rewriter.ip: + new_op = arith.constant(op.type, 2, loc=op.location) + rewriter.replace_op(op, [new_op]) + + with Context(): + patterns = RewritePatternSet() + patterns.add(arith.AddIOp, to_muli) + patterns.add(arith.ConstantOp, constant_1_to_2) + frozen = patterns.freeze() + + module = ModuleOp.parse( + r""" + module { + func.func @add(%a: i64, %b: i64) -> i64 { + %sum = arith.addi %a, %b : i64 + return %sum : i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %0 = arith.muli %arg0, %arg1 : i64 + # CHECK: return %0 : i64 + print(module) + + module = ModuleOp.parse( + r""" + module { + func.func @const() -> (i64, i64) { + %0 = arith.constant 1 : i64 + %1 = arith.constant 3 : i64 + return %0, %1 : i64, i64 + } + } + """ + ) + + apply_patterns_and_fold_greedily(module, frozen) + # CHECK: %c2_i64 = arith.constant 2 : i64 + # CHECK: %c3_i64 = arith.constant 3 : i64 + # CHECK: return %c2_i64, %c3_i64 : i64, i64 + print(module) |
