summaryrefslogtreecommitdiff
path: root/mlir/test/python/ir/module.py
blob: 33959bea9ffb6ce7dd2b2beb5de658dd711cbe72 (plain)
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# RUN: %PYTHON %s | FileCheck %s

import gc
from tempfile import NamedTemporaryFile
from mlir.ir import *


def run(f):
    """Execute test function ``f`` and assert that no context outlives it.

    A ``TEST: <name>`` banner is printed before ``f`` is called. Afterwards a
    garbage collection is forced and ``Context._get_live_count()`` must be
    zero. ``f`` is returned unchanged, which allows ``run`` to be applied as a
    decorator.
    """
    test_name = f.__name__
    print("\nTEST:", test_name)
    f()
    # Collect before inspecting the live-context count so that garbage left
    # behind by the test has been reclaimed.
    gc.collect()
    live_contexts = Context._get_live_count()
    assert live_contexts == 0
    return f


# Verify successful parse.
# CHECK-LABEL: TEST: testParseSuccess
# CHECK: module @successfulParse
@run
def testParseSuccess():
    """Parse a named empty module and print it once the local context is dropped."""
    context = Context()
    parsed = Module.parse(r"""module @successfulParse {}""", context)
    assert parsed.context is context
    print("CLEAR CONTEXT")
    # Drop the local reference to the context; the module has to keep it alive
    # for the calls below to work.
    context = None
    gc.collect()
    # dump() only writes to stderr, so it is called purely to verify that it
    # functions.
    parsed.dump()
    print(str(parsed))


# Verify successful parse from file.
# CHECK-LABEL: TEST: testParseFromFileSuccess
# CHECK: module @successfulParse
@run
def testParseFromFileSuccess():
    """Write a module to a temporary file, parse it by path and print it."""
    context = Context()
    with NamedTemporaryFile(mode="w") as asm_file:
        asm_file.write(r"""module @successfulParse {}""")
        # Flush so the text has left the Python-side buffer before the file is
        # read back through its name.
        asm_file.flush()
        parsed = Module.parseFile(asm_file.name, context)
        assert parsed.context is context
        print("CLEAR CONTEXT")
        # Drop the local reference; the module must keep the context alive.
        context = None
        gc.collect()
        parsed.operation.verify()
        print(str(parsed))


# Verify parse error.
# CHECK-LABEL: TEST: testParseError
# CHECK: testParseError: <
# CHECK:   Unable to parse module assembly:
# CHECK:   error: "-":1:1: expected operation name in quotes
# CHECK: >
@run
def testParseError():
    """Check that parsing invalid assembly raises ``MLIRError``.

    The caught error is printed wrapped in ``<`` and ``>`` after a
    ``testParseError:`` prefix. If parsing unexpectedly succeeds,
    ``Exception not produced`` is printed instead.
    """
    ctx = Context()
    try:
        # The parse result is deliberately not bound to a name: only the
        # raised error is of interest here.
        Module.parse(r"""}SYNTAX ERROR{""", ctx)
    except MLIRError as e:
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")


# Verify successful creation of an empty module.
# CHECK-LABEL: TEST: testCreateEmpty
# CHECK: module {
@run
def testCreateEmpty():
    """Create an empty module at an unknown location and print it."""
    context = Context()
    unknown_loc = Location.unknown(context)
    empty_module = Module.create(unknown_loc)
    print("CLEAR CONTEXT")
    # Release the local context reference; the module is expected to keep the
    # context alive.
    context = None
    gc.collect()
    print(str(empty_module))


# Verify round-trip of ASM that contains unicode.
# Note that this does not test that the print path converts unicode properly
# because MLIR asm always normalizes it to the hex encoding.
# CHECK-LABEL: TEST: testRoundtripUnicode
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripUnicode():
    """Parse a function whose attribute holds a unicode string and print it."""
    context = Context()
    unicode_asm = r"""
    func.func private @roundtripUnicode() attributes { foo = "😊" }
  """
    parsed = Module.parse(unicode_asm, context)
    print(str(parsed))


# Verify round-trip of ASM that is obtained as a bytes object.
# The module (which contains unicode) is printed with get_asm(binary=True) and
# the resulting bytes are parsed back into a module before printing.
# CHECK-LABEL: TEST: testRoundtripBinary
# CHECK: func private @roundtripUnicode()
# CHECK: foo = "\F0\9F\98\8A"
@run
def testRoundtripBinary():
    """Print a module as bytes, parse those bytes back and print the result."""
    unicode_asm = r"""
      func.func private @roundtripUnicode() attributes { foo = "😊" }
    """
    # No context is passed to Module.parse below; both calls run inside this
    # `with Context()` block instead.
    with Context():
        mod = Module.parse(unicode_asm)
        asm_bytes = mod.operation.get_asm(binary=True)
        assert isinstance(asm_bytes, bytes)
        mod = Module.parse(asm_bytes)
        print(mod)


# Tests that module.operation works and that repeated accesses compare equal.
# CHECK-LABEL: TEST: testModuleOperation
@run
def testModuleOperation():
    """Exercise ``module.operation`` and the lifetime of the objects involved.

    Checks the context's live module count before and after the module and
    the operations obtained from it are released.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    op1 = module.operation
    # CHECK: module @successfulParse
    print(op1)

    # Ensure that repeated accesses return distinct Python objects that
    # nevertheless compare equal.
    op2 = module.operation
    assert op1 is not op2
    assert op1 == op2

    # Test live operation clearing: drop the operation, collect, and fetch it
    # again from the module.
    op1 = module.operation
    op1 = None
    gc.collect()
    op1 = module.operation

    # Ensure that if module is de-referenced, the operations are still valid.
    module = None
    gc.collect()
    print(op1)

    # Collect and verify lifetime: once both operations are released, no live
    # module may remain in the context.
    op1 = None
    op2 = None
    gc.collect()
    assert ctx._get_live_module_count() == 0


# CHECK-LABEL: TEST: testModuleCapsule
@run
def testModuleCapsule():
    """Round-trip a module through its ``_CAPIPtr`` capsule.

    Verifies that ``Module._CAPICreate`` returns the original module object
    and that the context's live module count drops to zero once all
    references are released.
    """
    ctx = Context()
    module = Module.parse(r"""module @successfulParse {}""", ctx)
    assert ctx._get_live_module_count() == 1
    # CHECK: "mlir.ir.Module._CAPIPtr"
    module_capsule = module._CAPIPtr
    print(module_capsule)
    # Re-creating a module from the capsule must yield the very same Python
    # object, still associated with the original context.
    module_dup = Module._CAPICreate(module_capsule)
    assert module is module_dup
    assert module == module_dup
    assert module_dup.context is ctx
    # Drop every reference to the module, collect, and verify that it was
    # destroyed.
    module = None
    module_capsule = None
    module_dup = None
    gc.collect()
    assert ctx._get_live_module_count() == 0