summaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/dialects/transform_structured_ext.py36
-rw-r--r--mlir/test/python/execution_engine.py2
-rw-r--r--mlir/test/python/ir/dialects.py36
3 files changed, 73 insertions, 1 deletions
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py
index fb4c75b53379..8785d6d36007 100644
--- a/mlir/test/python/dialects/transform_structured_ext.py
+++ b/mlir/test/python/dialects/transform_structured_ext.py
@@ -103,6 +103,42 @@ def testFuseIntoContainingOpCompact(target):
@run
@create_sequence
+def testFuseOpCompact(target):
+ structured.FuseOp(
+ target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True
+ )
+ # CHECK-LABEL: TEST: testFuseOpCompact
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK-SAME: interchange [0, 1] apply_cleanup = true
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
+def testFuseOpNoArg(target):
+ structured.FuseOp(target)
+ # CHECK-LABEL: TEST: testFuseOpNoArg
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} :
+ # CHECK-SAME: (!transform.any_op) -> !transform.any_op
+
+
+@run
+@create_sequence
+def testFuseOpAttributes(target):
+ attr = DenseI64ArrayAttr.get([4, 8])
+ ichange = DenseI64ArrayAttr.get([0, 1])
+ structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange)
+ # CHECK-LABEL: TEST: testFuseOpAttributes
+ # CHECK: transform.sequence
+ # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8]
+ # CHECK-SAME: interchange [0, 1]
+ # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op)
+
+
+@run
+@create_sequence
def testGeneralize(target):
structured.GeneralizeOp(target)
# CHECK-LABEL: TEST: testGeneralize
diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py
index 6d3a8db8c24b..0d12c35d96be 100644
--- a/mlir/test/python/execution_engine.py
+++ b/mlir/test/python/execution_engine.py
@@ -306,7 +306,7 @@ def testUnrankedMemRefWithOffsetCallback():
log(arr)
with Context():
- # The module takes a subview of the argument memref, casts it to an unranked memref and
+ # The module takes a subview of the argument memref, casts it to an unranked memref and
# calls the callback with it.
module = Module.parse(
r"""
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424..5a2ed684d298 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+@run
+def testDialectLoadOnCreate():
+ with Context(load_on_create_dialects=[]) as ctx:
+ ctx.emit_error_diagnostics = True
+ ctx.allow_unregistered_dialects = True
+
+ def callback(d):
+ # CHECK: DIAGNOSTIC
+ # CHECK-SAME: op created with unregistered dialect
+ print(f"DIAGNOSTIC={d.message}")
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ try:
+ op = Operation.create("arith.addi", loc=loc)
+ ctx.allow_unregistered_dialects = False
+ op.verify()
+ except MLIRError as e:
+ pass
+
+ with Context(load_on_create_dialects=["func"]) as ctx:
+ loc = Location.unknown(ctx)
+ fn = Operation.create("func.func", loc=loc)
+
+ # TODO: This may require an update if a site wide policy is set.
+ # CHECK: Load on create: []
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ append_load_on_create_dialect("func")
+ # CHECK: Load on create:
+ # CHECK-SAME: func
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ print(get_load_on_create_dialects())