diff options
Diffstat (limited to 'mlir/test/python')
| -rw-r--r-- | mlir/test/python/dialects/transform_structured_ext.py | 36 | ||||
| -rw-r--r-- | mlir/test/python/execution_engine.py | 2 | ||||
| -rw-r--r-- | mlir/test/python/ir/dialects.py | 36 |
3 files changed, 73 insertions, 1 deletions
diff --git a/mlir/test/python/dialects/transform_structured_ext.py b/mlir/test/python/dialects/transform_structured_ext.py index fb4c75b53379..8785d6d36007 100644 --- a/mlir/test/python/dialects/transform_structured_ext.py +++ b/mlir/test/python/dialects/transform_structured_ext.py @@ -103,6 +103,42 @@ def testFuseIntoContainingOpCompact(target): @run @create_sequence +def testFuseOpCompact(target): + structured.FuseOp( + target, tile_sizes=[4, 8], tile_interchange=[0, 1], apply_cleanup=True + ) + # CHECK-LABEL: TEST: testFuseOpCompact + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] apply_cleanup = true + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence +def testFuseOpNoArg(target): + structured.FuseOp(target) + # CHECK-LABEL: TEST: testFuseOpNoArg + # CHECK: transform.sequence + # CHECK: %{{.+}} = transform.structured.fuse %{{.*}} : + # CHECK-SAME: (!transform.any_op) -> !transform.any_op + + +@run +@create_sequence +def testFuseOpAttributes(target): + attr = DenseI64ArrayAttr.get([4, 8]) + ichange = DenseI64ArrayAttr.get([0, 1]) + structured.FuseOp(target, tile_sizes=attr, tile_interchange=ichange) + # CHECK-LABEL: TEST: testFuseOpAttributes + # CHECK: transform.sequence + # CHECK: %{{.+}}, %{{.+}}:2 = transform.structured.fuse %{{.*}}[4, 8] + # CHECK-SAME: interchange [0, 1] + # CHECK-SAME: (!transform.any_op) -> (!transform.any_op, !transform.any_op, !transform.any_op) + + +@run +@create_sequence def testGeneralize(target): structured.GeneralizeOp(target) # CHECK-LABEL: TEST: testGeneralize diff --git a/mlir/test/python/execution_engine.py b/mlir/test/python/execution_engine.py index 6d3a8db8c24b..0d12c35d96be 100644 --- a/mlir/test/python/execution_engine.py +++ b/mlir/test/python/execution_engine.py @@ -306,7 +306,7 @@ def testUnrankedMemRefWithOffsetCallback(): log(arr) with Context(): - # The module takes a subview of the argument memref, casts it to an unranked memref and + # The module takes a subview of the argument memref, casts it to an unranked memref and # calls the callback with it. module = Module.parse( r""" diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index d59c6a6bc424..5a2ed684d298 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -121,3 +121,39 @@ def testAppendPrefixSearchPath(): sys.path.append(".") _cext.globals.append_dialect_search_prefix("custom_dialect") assert _cext.globals._check_dialect_module_loaded("custom") + + +# CHECK-LABEL: TEST: testDialectLoadOnCreate +@run +def testDialectLoadOnCreate(): + with Context(load_on_create_dialects=[]) as ctx: + ctx.emit_error_diagnostics = True + ctx.allow_unregistered_dialects = True + + def callback(d): + # CHECK: DIAGNOSTIC + # CHECK-SAME: op created with unregistered dialect + print(f"DIAGNOSTIC={d.message}") + return True + + handler = ctx.attach_diagnostic_handler(callback) + loc = Location.unknown(ctx) + try: + op = Operation.create("arith.addi", loc=loc) + ctx.allow_unregistered_dialects = False + op.verify() + except MLIRError as e: + pass + + with Context(load_on_create_dialects=["func"]) as ctx: + loc = Location.unknown(ctx) + fn = Operation.create("func.func", loc=loc) + + # TODO: This may require an update if a site wide policy is set. + # CHECK: Load on create: [] + print(f"Load on create: {get_load_on_create_dialects()}") + append_load_on_create_dialect("func") + # CHECK: Load on create: + # CHECK-SAME: func + print(f"Load on create: {get_load_on_create_dialects()}") + print(get_load_on_create_dialects()) |
