summaryrefslogtreecommitdiff
path: root/mlir/test/python
diff options
context:
space:
mode:
Diffstat (limited to 'mlir/test/python')
-rw-r--r--mlir/test/python/dialects/math.py2
-rw-r--r--mlir/test/python/dialects/python_test.py38
-rw-r--r--mlir/test/python/dialects/shape.py5
-rw-r--r--mlir/test/python/ir/dialects.py4
-rw-r--r--mlir/test/python/python_test_ops.td24
5 files changed, 64 insertions, 9 deletions
diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py
index 73246e2e60ed..e3f8829b2778 100644
--- a/mlir/test/python/dialects/math.py
+++ b/mlir/test/python/dialects/math.py
@@ -16,7 +16,7 @@ def testMathOps():
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(F32Type.get())
def emit_sqrt(arg):
- return mlir_math.SqrtOp(F32Type.get(), arg)
+ return mlir_math.SqrtOp(arg)
# CHECK-LABEL: func @emit_sqrt(
# CHECK-SAME: %[[ARG:.*]]: f32) {
diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py
index 3d0600e331a5..2267b59cd4d7 100644
--- a/mlir/test/python/dialects/python_test.py
+++ b/mlir/test/python/dialects/python_test.py
@@ -137,8 +137,7 @@ def inferReturnTypes():
test.register_python_test_dialect(ctx)
module = Module.create()
with InsertionPoint(module.body):
- op = test.InferResultsOp(
- IntegerType.get_signless(32), IntegerType.get_signless(64))
+ op = test.InferResultsOp()
dummy = test.DummyOp()
# CHECK: [Type(i32), Type(i64)]
@@ -173,3 +172,38 @@ def inferReturnTypes():
pass
else:
assert False, "not expected dummy op class to implement the interface"
+
+
+# CHECK-LABEL: TEST: resultTypesDefinedByTraits
+@run
+def resultTypesDefinedByTraits():
+ with Context() as ctx, Location.unknown(ctx):
+ test.register_python_test_dialect(ctx)
+ module = Module.create()
+ with InsertionPoint(module.body):
+ inferred = test.InferResultsOp()
+ same = test.SameOperandAndResultTypeOp([inferred.results[0]])
+ # CHECK-COUNT-2: i32
+ print(same.one.type)
+ print(same.two.type)
+
+ first_type_attr = test.FirstAttrDeriveTypeAttrOp(
+ inferred.results[1], TypeAttr.get(IndexType.get()))
+ # CHECK-COUNT-2: index
+ print(first_type_attr.one.type)
+ print(first_type_attr.two.type)
+
+ first_attr = test.FirstAttrDeriveAttrOp(
+ FloatAttr.get(F32Type.get(), 3.14))
+ # CHECK-COUNT-3: f32
+ print(first_attr.one.type)
+ print(first_attr.two.type)
+ print(first_attr.three.type)
+
+ implied = test.InferResultsImpliedOp()
+ # CHECK: i32
+ print(implied.integer.type)
+ # CHECK: f64
+ print(implied.flt.type)
+ # CHECK: index
+ print(implied.index.type)
diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py
index 1772026cb9b9..7c1c5d6f1cf9 100644
--- a/mlir/test/python/dialects/shape.py
+++ b/mlir/test/python/dialects/shape.py
@@ -18,15 +18,12 @@ def testConstShape():
with Context() as ctx, Location.unknown():
module = Module.create()
f32 = F32Type.get()
- indexT = IndexType.get()
with InsertionPoint(module.body):
@builtin.FuncOp.from_py_func(
RankedTensorType.get((12, -1), f32))
def const_shape_tensor(arg):
- return shape.ConstShapeOp(RankedTensorType.get((2,), indexT),
- DenseElementsAttr.get(np.array([10, 20])))
+ return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20])))
# CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>)
# CHECK: shape.const_shape [10, 20] : tensor<2xindex>
print(module)
-
diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py
index 342d93bca6a4..05e9222c3e31 100644
--- a/mlir/test/python/ir/dialects.py
+++ b/mlir/test/python/ir/dialects.py
@@ -82,11 +82,11 @@ def testCustomOpView():
# Create via dialects context collection.
input1 = createInput()
input2 = createInput()
- op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2)
+ op1 = ctx.dialects.arith.AddFOp(input1, input2)
# Create via an import
from mlir.dialects.arith import AddFOp
- AddFOp(input1.type, input1, op1.result)
+ AddFOp(input1, op1.result)
# CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput"
# CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput"
diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td
index 74c90a311f04..0f947e7e536b 100644
--- a/mlir/test/python/python_test_ops.td
+++ b/mlir/test/python/python_test_ops.td
@@ -52,4 +52,28 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> {
}];
}
+// If all result types are buildable, the InferTypeOpInterface is implied and is
+// autogenerated by C++ ODS.
+def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> {
+ let results = (outs I32:$integer, F64:$flt, Index:$index);
+}
+
+def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op",
+ [SameOperandsAndResultType]> {
+ let arguments = (ins Variadic<AnyType>);
+ let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op",
+ [FirstAttrDerivedResultType]> {
+ let arguments = (ins AnyType:$input, TypeAttr:$type);
+ let results = (outs AnyType:$one, AnyType:$two);
+}
+
+def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op",
+ [FirstAttrDerivedResultType]> {
+ let arguments = (ins AnyAttr:$iattr);
+ let results = (outs AnyType:$one, AnyType:$two, AnyType:$three);
+}
+
#endif // PYTHON_TEST_OPS