diff options
Diffstat (limited to 'mlir/test/python')
| -rw-r--r-- | mlir/test/python/dialects/math.py | 2 | ||||
| -rw-r--r-- | mlir/test/python/dialects/python_test.py | 38 | ||||
| -rw-r--r-- | mlir/test/python/dialects/shape.py | 5 | ||||
| -rw-r--r-- | mlir/test/python/ir/dialects.py | 4 | ||||
| -rw-r--r-- | mlir/test/python/python_test_ops.td | 24 |
5 files changed, 64 insertions, 9 deletions
diff --git a/mlir/test/python/dialects/math.py b/mlir/test/python/dialects/math.py index 73246e2e60ed..e3f8829b2778 100644 --- a/mlir/test/python/dialects/math.py +++ b/mlir/test/python/dialects/math.py @@ -16,7 +16,7 @@ def testMathOps(): with InsertionPoint(module.body): @builtin.FuncOp.from_py_func(F32Type.get()) def emit_sqrt(arg): - return mlir_math.SqrtOp(F32Type.get(), arg) + return mlir_math.SqrtOp(arg) # CHECK-LABEL: func @emit_sqrt( # CHECK-SAME: %[[ARG:.*]]: f32) { diff --git a/mlir/test/python/dialects/python_test.py b/mlir/test/python/dialects/python_test.py index 3d0600e331a5..2267b59cd4d7 100644 --- a/mlir/test/python/dialects/python_test.py +++ b/mlir/test/python/dialects/python_test.py @@ -137,8 +137,7 @@ def inferReturnTypes(): test.register_python_test_dialect(ctx) module = Module.create() with InsertionPoint(module.body): - op = test.InferResultsOp( - IntegerType.get_signless(32), IntegerType.get_signless(64)) + op = test.InferResultsOp() dummy = test.DummyOp() # CHECK: [Type(i32), Type(i64)] @@ -173,3 +172,38 @@ def inferReturnTypes(): pass else: assert False, "not expected dummy op class to implement the interface" + + +# CHECK-LABEL: TEST: resultTypesDefinedByTraits +@run +def resultTypesDefinedByTraits(): + with Context() as ctx, Location.unknown(ctx): + test.register_python_test_dialect(ctx) + module = Module.create() + with InsertionPoint(module.body): + inferred = test.InferResultsOp() + same = test.SameOperandAndResultTypeOp([inferred.results[0]]) + # CHECK-COUNT-2: i32 + print(same.one.type) + print(same.two.type) + + first_type_attr = test.FirstAttrDeriveTypeAttrOp( + inferred.results[1], TypeAttr.get(IndexType.get())) + # CHECK-COUNT-2: index + print(first_type_attr.one.type) + print(first_type_attr.two.type) + + first_attr = test.FirstAttrDeriveAttrOp( + FloatAttr.get(F32Type.get(), 3.14)) + # CHECK-COUNT-3: f32 + print(first_attr.one.type) + print(first_attr.two.type) + print(first_attr.three.type) + + implied = test.InferResultsImpliedOp() + # CHECK: i32 + print(implied.integer.type) + # CHECK: f64 + print(implied.flt.type) + # CHECK: index + print(implied.index.type) diff --git a/mlir/test/python/dialects/shape.py b/mlir/test/python/dialects/shape.py index 1772026cb9b9..7c1c5d6f1cf9 100644 --- a/mlir/test/python/dialects/shape.py +++ b/mlir/test/python/dialects/shape.py @@ -18,15 +18,12 @@ def testConstShape(): with Context() as ctx, Location.unknown(): module = Module.create() f32 = F32Type.get() - indexT = IndexType.get() with InsertionPoint(module.body): @builtin.FuncOp.from_py_func( RankedTensorType.get((12, -1), f32)) def const_shape_tensor(arg): - return shape.ConstShapeOp(RankedTensorType.get((2,), indexT), - DenseElementsAttr.get(np.array([10, 20]))) + return shape.ConstShapeOp(DenseElementsAttr.get(np.array([10, 20]))) # CHECK-LABEL: func @const_shape_tensor(%arg0: tensor<12x?xf32>) # CHECK: shape.const_shape [10, 20] : tensor<2xindex> print(module) - diff --git a/mlir/test/python/ir/dialects.py b/mlir/test/python/ir/dialects.py index 342d93bca6a4..05e9222c3e31 100644 --- a/mlir/test/python/ir/dialects.py +++ b/mlir/test/python/ir/dialects.py @@ -82,11 +82,11 @@ def testCustomOpView(): # Create via dialects context collection. input1 = createInput() input2 = createInput() - op1 = ctx.dialects.arith.AddFOp(input1.type, input1, input2) + op1 = ctx.dialects.arith.AddFOp(input1, input2) # Create via an import from mlir.dialects.arith import AddFOp - AddFOp(input1.type, input1, op1.result) + AddFOp(input1, op1.result) # CHECK: %[[INPUT0:.*]] = "pytest_dummy.intinput" # CHECK: %[[INPUT1:.*]] = "pytest_dummy.intinput" diff --git a/mlir/test/python/python_test_ops.td b/mlir/test/python/python_test_ops.td index 74c90a311f04..0f947e7e536b 100644 --- a/mlir/test/python/python_test_ops.td +++ b/mlir/test/python/python_test_ops.td @@ -52,4 +52,28 @@ def InferResultsOp : TestOp<"infer_results_op", [InferTypeOpInterface]> { }]; } +// If all result types are buildable, the InferTypeOpInterface is implied and is +// autogenerated by C++ ODS. +def InferResultsImpliedOp : TestOp<"infer_results_implied_op"> { + let results = (outs I32:$integer, F64:$flt, Index:$index); +} + +def SameOperandAndResultTypeOp : TestOp<"same_operand_and_result_type_op", + [SameOperandsAndResultType]> { + let arguments = (ins Variadic<AnyType>); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveTypeAttrOp : TestOp<"first_attr_derive_type_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyType:$input, TypeAttr:$type); + let results = (outs AnyType:$one, AnyType:$two); +} + +def FirstAttrDeriveAttrOp : TestOp<"first_attr_derive_attr_op", + [FirstAttrDerivedResultType]> { + let arguments = (ins AnyAttr:$iattr); + let results = (outs AnyType:$one, AnyType:$two, AnyType:$three); +} + #endif // PYTHON_TEST_OPS |
