summaryrefslogtreecommitdiff
path: root/mlir/test/python/ir/dialects.py
diff options
context:
space:
mode:
authorNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
committerNAKAMURA Takumi <geek4civic@gmail.com>2025-01-09 18:49:54 +0900
commite2810c9a248f4c7fbfae84bb32b6f7e01027458b (patch)
treeae0b02a8491b969a1cee94ea16ffe42c559143c5 /mlir/test/python/ir/dialects.py
parentfa04eb4af95c1ca7377279728cb004bcd2324d01 (diff)
parentbdcf47e4bcb92889665825654bb80a8bbe30379e (diff)
Merge branch 'users/chapuni/cov/single/base' into users/chapuni/cov/single/switchusers/chapuni/cov/single/switch
Diffstat (limited to 'mlir/test/python/ir/dialects.py')
-rw-r--r--mlir/test/python/ir/dialects.py36
1 files changed, 36 insertions, 0 deletions
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index d59c6a6bc424..5a2ed684d298 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -121,3 +121,39 @@ def testAppendPrefixSearchPath():
sys.path.append(".")
_cext.globals.append_dialect_search_prefix("custom_dialect")
assert _cext.globals._check_dialect_module_loaded("custom")
+
+
+# CHECK-LABEL: TEST: testDialectLoadOnCreate
+@run
+def testDialectLoadOnCreate():
+ with Context(load_on_create_dialects=[]) as ctx:
+ ctx.emit_error_diagnostics = True
+ ctx.allow_unregistered_dialects = True
+
+ def callback(d):
+ # CHECK: DIAGNOSTIC
+ # CHECK-SAME: op created with unregistered dialect
+ print(f"DIAGNOSTIC={d.message}")
+ return True
+
+ handler = ctx.attach_diagnostic_handler(callback)
+ loc = Location.unknown(ctx)
+ try:
+ op = Operation.create("arith.addi", loc=loc)
+ ctx.allow_unregistered_dialects = False
+ op.verify()
+ except MLIRError as e:
+ pass
+
+ with Context(load_on_create_dialects=["func"]) as ctx:
+ loc = Location.unknown(ctx)
+ fn = Operation.create("func.func", loc=loc)
+
+ # TODO: This may require an update if a site wide policy is set.
+ # CHECK: Load on create: []
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ append_load_on_create_dialect("func")
+ # CHECK: Load on create:
+ # CHECK-SAME: func
+ print(f"Load on create: {get_load_on_create_dialects()}")
+ print(get_load_on_create_dialects())