summaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py60
-rw-r--r--mlir/test/python/rewrite.py70
2 files changed, 127 insertions, 3 deletions
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index 8785d6d36007..d6b70dc9d197 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -109,13 +109,29 @@ def testFuseOpCompact(target):
)
# CHECK-LABEL: TEST: testFuseOpCompact
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
- # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: interchange [0, 1] {apply_cleanup}
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
@run
@create_sequence
+def testFuseOpCompactForall(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[4, 8],
+ apply_cleanup=True,
+ use_forall=True,
+ )
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}} = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
+ # CHECK-SAME: {apply_cleanup, use_forall}
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpNoArg(target):
structured.FuseOp(target)
# CHECK-LABEL: TEST: testFuseOpNoArg
@@ -126,13 +142,51 @@ def testFuseOpNoArg(target):
@run
@create_sequence
+def testFuseOpParams(target):
+ structured.FuseOp(
+ target,
+ tile_sizes=[constant_param(4), Attribute.parse("8")],
+ tile_interchange=[constant_param(0), Attribute.parse("1")],
+ )
+ # CHECK-LABEL: TEST: testFuseOpParams
+ # CHECK: transform.sequence
+ # CHECK-DAG: %[[P:.*]] = transform.param.constant 4
+ # CHECK-DAG: %[[I:.*]] = transform.param.constant 0
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[P]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.param<i64>, !transform.param<i64>) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpHandles(target):
+ size1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ ichange1 = structured.MatchOp.match_op_names(target, ["arith.constant"])
+ structured.FuseOp(
+ target,
+ tile_sizes=[size1, 8],
+ tile_interchange=[ichange1, 1],
+ )
+ # CHECK-LABEL: TEST: testFuseOpHandles
+ # CHECK: transform.sequence
+ # CHECK: %[[H:.*]] = transform.structured.match
+ # CHECK: %[[I:.*]] = transform.structured.match
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse
+ # CHECK-SAME: tile_sizes [%[[H]], 8]
+ # CHECK-SAME: interchange [%[[I]], 1]
+ # CHECK-SAME: (!transform.any_op, !transform.any_op, !transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testFuseOpAttributes(target):
attr = DenseI64ArrayAttr.get([4, 8])
ichange = DenseI64ArrayAttr.get([0, 1])
structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
# CHECK-LABEL: TEST: testFuseOpAttributes
# CHECK: transform.sequence
- # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}} tile_sizes [4, 8]
# CHECK-SAME: interchange [0, 1]
# CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
diff --git a/mlir/test/python/rewrite.py b/mlir/test/python/rewrite.py
new file mode 100644
index 000000000000..821e47085a5b
--- /dev/null
+++ b/mlir/test/python/rewrite.py
@@ -0,0 +1,70 @@
+# RUN: %PYTHON %s 2>&1 | FileCheck %s
+
+from mlir.ir import *
+from mlir.passmanager import *
+from mlir.dialects.builtin import ModuleOp
+from mlir.dialects import arith
+from mlir.rewrite import *
+
+
+def run(f):
+ print("\nTEST:", f.__name__)
+ f()
+
+
+# CHECK-LABEL: TEST: testRewritePattern
+@run
+def testRewritePattern():
+ def to_muli(op, rewriter):
+ with rewriter.ip:
+ assert isinstance(op, arith.AddIOp)
+ new_op = arith.muli(op.lhs, op.rhs, loc=op.location)
+ rewriter.replace_op(op, new_op.owner)
+
+ def constant_1_to_2(op, rewriter):
+ c = op.value.value
+ if c != 1:
+ return True # failed to match
+ with rewriter.ip:
+ new_op = arith.constant(op.type, 2, loc=op.location)
+ rewriter.replace_op(op, [new_op])
+
+ with Context():
+ patterns = RewritePatternSet()
+ patterns.add(arith.AddIOp, to_muli)
+ patterns.add(arith.ConstantOp, constant_1_to_2)
+ frozen = patterns.freeze()
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @add(%a: i64, %b: i64) -> i64 {
+ %sum = arith.addi %a, %b : i64
+ return %sum : i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %0 = arith.muli %arg0, %arg1 : i64
+ # CHECK: return %0 : i64
+ print(module)
+
+ module = ModuleOp.parse(
+ r"""
+ module {
+ func.func @const() -> (i64, i64) {
+ %0 = arith.constant 1 : i64
+ %1 = arith.constant 3 : i64
+ return %0, %1 : i64, i64
+ }
+ }
+ """
+ )
+
+ apply_patterns_and_fold_greedily(module, frozen)
+ # CHECK: %c2_i64 = arith.constant 2 : i64
+ # CHECK: %c3_i64 = arith.constant 3 : i64
+ # CHECK: return %c2_i64, %c3_i64 : i64, i64
+ print(module)