# RUN: %PYTHON %s | FileCheck %s

import gc
from mlir.ir import *
from mlir.dialects import arith, tensor, func, memref
import mlir.extras.types as T


def run(f):
    """Run test function ``f`` immediately and verify no Context leaked.

    Prints a ``TEST: <name>`` header (the anchor for the FileCheck labels in
    this file), calls ``f``, forces a garbage collection and asserts that no
    MLIR ``Context`` is still alive. Returns ``f`` unchanged so the decorated
    name stays bound at module scope.
    """
    print("\nTEST:", f.__name__)
    f()
    gc.collect()
    assert Context._get_live_count() == 0
    return f


# CHECK-LABEL: TEST: testParsePrint
@run
def testParsePrint():
    """Parse a type, drop the local context reference, then print the type."""
    ctx = Context()
    t = Type.parse("i32", ctx)
    assert t.context is ctx
    # Drop our own reference; `t` is expected to remain usable afterwards.
    ctx = None
    gc.collect()
    # CHECK: i32
    print(str(t))
    # CHECK: Type(i32)
    print(repr(t))


# CHECK-LABEL: TEST: testParseError
@run
def testParseError():
    """An unparseable type string raises MLIRError carrying the diagnostic."""
    ctx = Context()
    try:
        t = Type.parse("BAD_TYPE_DOES_NOT_EXIST", ctx)
    except MLIRError as e:
        # CHECK: testParseError: <
        # CHECK: Unable to parse type:
        # CHECK: error: "BAD_TYPE_DOES_NOT_EXIST":1:1: expected non-function type
        # CHECK: >
        print(f"testParseError: <{e}>")
    else:
        print("Exception not produced")


# CHECK-LABEL: TEST: testTypeEq
@run
def testTypeEq():
    """Types parsed from the same string in one context compare equal."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    t3 = Type.parse("i32", ctx)
    # CHECK: t1 == t1: True
    print("t1 == t1:", t1 == t1)
    # CHECK: t1 == t2: False
    print("t1 == t2:", t1 == t2)
    # CHECK: t1 == t3: True
    print("t1 == t3:", t1 == t3)
    # CHECK: t1 is None: False
    print("t1 is None:", t1 is None)


# CHECK-LABEL: TEST: testTypeHash
@run
def testTypeHash():
    """Equal types hash equally and de-duplicate inside a set."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    t3 = Type.parse("i32", ctx)

    # CHECK: hash(t1) == hash(t3): True
    print("hash(t1) == hash(t3):", hash(t1) == hash(t3))

    s = set()
    s.add(t1)
    s.add(t2)
    s.add(t3)
    # t1 and t3 are equal, so the set holds only two entries.
    # CHECK: len(s): 2
    print("len(s): ", len(s))


# CHECK-LABEL: TEST: testTypeCast
@run
def testTypeCast():
    """Constructing Type from an existing Type yields an equal object."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type(t1)
    # CHECK: t1 == t2: True
    print("t1 == t2:", t1 == t2)


# CHECK-LABEL: TEST: testTypeIsInstance
@run
def testTypeIsInstance():
    """The static ``isinstance`` helpers classify i32 and f32 correctly."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    t2 = Type.parse("f32", ctx)
    # CHECK: True
    print(IntegerType.isinstance(t1))
    # CHECK: False
    print(F32Type.isinstance(t1))
    # CHECK: False
    print(FloatType.isinstance(t1))
    # CHECK: True
    print(F32Type.isinstance(t2))
    # CHECK: True
    print(FloatType.isinstance(t2))


# CHECK-LABEL: TEST: testFloatTypeSubclasses
@run
def testFloatTypeSubclasses():
    """Every builtin float spelling parses to an instance of FloatType."""
    ctx = Context()
    # CHECK: True
    print(isinstance(Type.parse("f4E2M1FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f6E2M3FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f6E3M2FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E3M4", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3FN", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E5M2", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E4M3B11FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E5M2FNUZ", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f8E8M0FNU", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f16", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("bf16", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f32", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("tf32", ctx), FloatType))
    # CHECK: True
    print(isinstance(Type.parse("f64", ctx), FloatType))


# CHECK-LABEL: TEST: testTypeEqDoesNotRaise
@run
def testTypeEqDoesNotRaise():
    """Comparing a Type against a non-Type returns False instead of raising."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    not_a_type = "foo"
    # CHECK: False
    print(t1 == not_a_type)
    # CHECK: False
    print(t1 is None)
    # CHECK: True
    print(t1 is not None)


# CHECK-LABEL: TEST: testTypeCapsule
@run
def testTypeCapsule():
    """Round-trip a Type through its C-API capsule."""
    with Context() as ctx:
        t1 = Type.parse("i32", ctx)
    # CHECK: mlir.ir.Type._CAPIPtr
    type_capsule = t1._CAPIPtr
    print(type_capsule)
    t2 = Type._CAPICreate(type_capsule)
    assert t2 == t1
    assert t2.context is ctx


# CHECK-LABEL: TEST: testStandardTypeCasts
@run
def testStandardTypeCasts():
    """Casting to IntegerType succeeds for i32 and raises ValueError for f32."""
    ctx = Context()
    t1 = Type.parse("i32", ctx)
    tint = IntegerType(t1)
    # Also exercises casting an already-concrete IntegerType.
    tself = IntegerType(tint)
    # CHECK: Type(i32)
    print(repr(tint))
    try:
        tillegal = IntegerType(Type.parse("f32", ctx))
    except ValueError as e:
        # CHECK: ValueError: Cannot cast type to IntegerType (from Type(f32))
        print("ValueError:", e)
    else:
        print("Exception not produced")


# CHECK-LABEL: TEST: testIntegerType
@run
def testIntegerType():
    """Width and signedness queries plus the signless/signed/unsigned getters."""
    with Context() as ctx:
        i32 = IntegerType(Type.parse("i32"))
        # CHECK: i32 width: 32
        print("i32 width:", i32.width)
        # CHECK: i32 signless: True
        print("i32 signless:", i32.is_signless)
        # CHECK: i32 signed: False
        print("i32 signed:", i32.is_signed)
        # CHECK: i32 unsigned: False
        print("i32 unsigned:", i32.is_unsigned)

        s32 = IntegerType(Type.parse("si32"))
        # CHECK: s32 signless: False
        print("s32 signless:", s32.is_signless)
        # CHECK: s32 signed: True
        print("s32 signed:", s32.is_signed)
        # CHECK: s32 unsigned: False
        print("s32 unsigned:", s32.is_unsigned)

        u32 = IntegerType(Type.parse("ui32"))
        # CHECK: u32 signless: False
        print("u32 signless:", u32.is_signless)
        # CHECK: u32 signed: False
        print("u32 signed:", u32.is_signed)
        # CHECK: u32 unsigned: True
        print("u32 unsigned:", u32.is_unsigned)

        # CHECK: signless: i16
        print("signless:", IntegerType.get_signless(16))
        # CHECK: signed: si8
        print("signed:", IntegerType.get_signed(8))
        # CHECK: unsigned: ui64
        print("unsigned:", IntegerType.get_unsigned(64))


# CHECK-LABEL: TEST: testIndexType
@run
def testIndexType():
    """IndexType.get() prints as ``index``."""
    with Context() as ctx:
        # CHECK: index type: index
        print("index type:", IndexType.get())


# CHECK-LABEL: TEST: testFloatType
@run
def testFloatType():
    """Each float type getter prints its builtin spelling."""
    with Context():
        # CHECK: float: f4E2M1FN
        print("float:", Float4E2M1FNType.get())
        # CHECK: float: f6E2M3FN
        print("float:", Float6E2M3FNType.get())
        # CHECK: float: f6E3M2FN
        print("float:", Float6E3M2FNType.get())
        # CHECK: float: f8E3M4
        print("float:", Float8E3M4Type.get())
        # CHECK: float: f8E4M3
        print("float:", Float8E4M3Type.get())
        # CHECK: float: f8E4M3FN
        print("float:", Float8E4M3FNType.get())
        # CHECK: float: f8E5M2
        print("float:", Float8E5M2Type.get())
        # CHECK: float: f8E5M2FNUZ
        print("float:", Float8E5M2FNUZType.get())
        # CHECK: float: f8E4M3FNUZ
        print("float:", Float8E4M3FNUZType.get())
        # CHECK: float: f8E4M3B11FNUZ
        print("float:", Float8E4M3B11FNUZType.get())
        # CHECK: float: f8E8M0FNU
        print("float:", Float8E8M0FNUType.get())
        # CHECK: float: bf16
        print("float:", BF16Type.get())
        # CHECK: float: f16
        print("float:", F16Type.get())
        # CHECK: float: tf32
        print("float:", FloatTF32Type.get())
        # CHECK: float: f32
        print("float:", F32Type.get())
        # CHECK: float: f64
        f64 = F64Type.get()
        print("float:", f64)
        # CHECK: f64 width: 64
        print("f64 width:", f64.width)


# CHECK-LABEL: TEST: testNoneType
@run
def testNoneType():
    """NoneType.get() prints as ``none``."""
    with Context():
        # CHECK: none type: none
        print("none type:", NoneType.get())


# CHECK-LABEL: TEST: testComplexType
@run
def testComplexType():
    """ComplexType parsing, element access, construction and rejection of index."""
    with Context() as ctx:
        # A complex type must spell out its element type to be parseable.
        complex_i32 = ComplexType(Type.parse("complex<i32>"))
        # CHECK: complex type element: i32
        print("complex type element:", complex_i32.element_type)

        f32 = F32Type.get()
        # CHECK: complex type: complex<f32>
        print("complex type:", ComplexType.get(f32))

        index = IndexType.get()
        try:
            complex_invalid = ComplexType.get(index)
        except ValueError as e:
            # CHECK: invalid 'Type(index)' and expected floating point or integer type.
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testConcreteShapedType
# ShapedType is not a concrete builtin type of its own; it is the base class for
# vectors, memrefs and tensors, so this test uses a vector instance to exercise
# the shaped-type API. The class hierarchy is preserved on the Python side.
@run
def testConcreteShapedType():
    """Query the ShapedType API on a concrete ``vector<2x3xf32>``."""
    with Context():
        parsed = Type.parse("vector<2x3xf32>")
        shaped = VectorType(parsed)

        # CHECK: element type: f32
        print("element type:", shaped.element_type)
        # CHECK: whether the given shaped type is ranked: True
        print("whether the given shaped type is ranked:", shaped.has_rank)
        # CHECK: rank: 2
        print("rank:", shaped.rank)
        # CHECK: whether the shaped type has a static shape: True
        print("whether the shaped type has a static shape:", shaped.has_static_shape)
        # CHECK: whether the dim-th dimension is dynamic: False
        first_dim_dynamic = shaped.is_dynamic_dim(0)
        print("whether the dim-th dimension is dynamic:", first_dim_dynamic)
        # CHECK: dim size: 3
        second_dim = shaped.get_dim_size(1)
        print("dim size:", second_dim)
        # CHECK: is_dynamic_size: False
        print("is_dynamic_size:", shaped.is_dynamic_size(3))
        # CHECK: is_static_size: True
        print("is_static_size:", shaped.is_static_size(3))
        # CHECK: is_dynamic_stride_or_offset: False
        print("is_dynamic_stride_or_offset:", shaped.is_dynamic_stride_or_offset(1))
        # CHECK: is_static_stride_or_offset: True
        print("is_static_stride_or_offset:", shaped.is_static_stride_or_offset(1))

        # The values the bindings hand out as the "dynamic" size / stride marker;
        # the checks below confirm they are classified as dynamic, not static.
        dyn_size = shaped.get_dynamic_size()
        dyn_stride = shaped.get_dynamic_stride_or_offset()
        # CHECK: is_dynamic_size_with_dynamic: True
        print("is_dynamic_size_with_dynamic:", shaped.is_dynamic_size(dyn_size))
        # CHECK: is_static_size_with_dynamic: False
        print("is_static_size_with_dynamic:", shaped.is_static_size(dyn_size))
        # CHECK: is_dynamic_stride_or_offset_with_dynamic: True
        stride_is_dynamic = shaped.is_dynamic_stride_or_offset(dyn_stride)
        print("is_dynamic_stride_or_offset_with_dynamic:", stride_is_dynamic)
        # CHECK: is_static_stride_or_offset_with_dynamic: False
        stride_is_static = shaped.is_static_stride_or_offset(dyn_stride)
        print("is_static_stride_or_offset_with_dynamic:", stride_is_static)
        # CHECK: isinstance(ShapedType): True
        print("isinstance(ShapedType):", isinstance(shaped, ShapedType))


# CHECK-LABEL: TEST: testAbstractShapedType
# Tests that ShapedType operates as an abstract base class of a concrete
# shaped type
# (using vector as an example).
@run
def testAbstractShapedType():
    """Wrap a vector type in the abstract ShapedType base class."""
    ctx = Context()
    vector = ShapedType(Type.parse("vector<2x3xf32>", ctx))
    # CHECK: element type: f32
    print("element type:", vector.element_type)


# CHECK-LABEL: TEST: testVectorType
@run
def testVectorType():
    """VectorType getters, scalable dimensions and argument validation."""
    shape = [2, 3]
    with Context():
        f32 = F32Type.get()
        # CHECK: unchecked vector type: vector<2x3xf32>
        print("unchecked vector type:", VectorType.get_unchecked(shape, f32))

    with Context(), Location.unknown():
        f32 = F32Type.get()
        # CHECK: checked vector type: vector<2x3xf32>
        print("checked vector type:", VectorType.get(shape, f32))

        none = NoneType.get()
        try:
            VectorType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: failed to verify 'elementType': VectorElementTypeInterface instance
            print(e)
        else:
            print("Exception not produced")

        scalable_1 = VectorType.get(shape, f32, scalable=[False, True])
        scalable_2 = VectorType.get([2, 3, 4], f32, scalable=[True, False, True])
        assert scalable_1.scalable
        assert scalable_2.scalable
        assert scalable_1.scalable_dims == [False, True]
        assert scalable_2.scalable_dims == [True, False, True]
        # CHECK: scalable 1: vector<2x[3]xf32>
        print("scalable 1: ", scalable_1)
        # CHECK: scalable 2: vector<[2]x3x[4]xf32>
        print("scalable 2: ", scalable_2)

        # `scalable_dims` (a list of indices) must build the same types as the
        # boolean `scalable` mask above.
        scalable_3 = VectorType.get(shape, f32, scalable_dims=[1])
        scalable_4 = VectorType.get([2, 3, 4], f32, scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        try:
            VectorType.get(shape, f32, scalable=[False, True, True])
        except ValueError as e:
            # CHECK: Expected len(scalable) == len(shape).
            print(e)
        else:
            print("Exception not produced")

        try:
            VectorType.get(shape, f32, scalable=[False, True], scalable_dims=[1])
        except ValueError as e:
            # CHECK: kwargs are mutually exclusive.
            print(e)
        else:
            print("Exception not produced")

        try:
            VectorType.get(shape, f32, scalable_dims=[42])
        except ValueError as e:
            # CHECK: Scalable dimension index out of bounds.
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testRankedTensorType
@run
def testRankedTensorType():
    """RankedTensorType construction, element-type validation and encoding."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        # CHECK: ranked tensor type: tensor<2x3xf32>
        print("ranked tensor type:", RankedTensorType.get(shape, f32))

        none = NoneType.get()
        try:
            RankedTensorType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(e)
        else:
            print("Exception not produced")

        # Named `tensor_type` so it does not shadow the imported `tensor` dialect.
        tensor_type = RankedTensorType.get(shape, f32, StringAttr.get("encoding"))
        assert tensor_type.shape == shape
        assert tensor_type.encoding.value == "encoding"

        # Encoding should be None.
        assert RankedTensorType.get(shape, f32).encoding is None


# CHECK-LABEL: TEST: testUnrankedTensorType
@run
def testUnrankedTensorType():
    """Rank queries on an unranked tensor raise; invalid element types raise."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        unranked_tensor = UnrankedTensorType.get(f32)
        # CHECK: unranked tensor type: tensor<*xf32>
        print("unranked tensor type:", unranked_tensor)
        try:
            invalid_rank = unranked_tensor.rank
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_is_dynamic_dim = unranked_tensor.is_dynamic_dim(0)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_get_dim_size = unranked_tensor.get_dim_size(1)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedTensorType.get(none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid tensor element type: 'none'
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testMemRefType
@run
def testMemRefType():
    """MemRefType layout, affine map and memory-space accessors."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        shape = [2, 3]
        memref_f32 = MemRefType.get(shape, f32, memory_space=Attribute.parse("2"))
        # CHECK: memref type: memref<2x3xf32, 2>
        print("memref type:", memref_f32)
        # CHECK: memref layout: AffineMapAttr(affine_map<(d0, d1) -> (d0, d1)>)
        print("memref layout:", repr(memref_f32.layout))
        # CHECK: memref affine map: (d0, d1) -> (d0, d1)
        print("memref affine map:", memref_f32.affine_map)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(memref_f32.memory_space))

        layout = AffineMapAttr.get(AffineMap.get_permutation([1, 0]))
        memref_layout = MemRefType.get(shape, f32, layout=layout)
        # CHECK: memref type: memref<2x3xf32, affine_map<(d0, d1) -> (d1, d0)>>
        print("memref type:", memref_layout)
        # CHECK: memref layout: affine_map<(d0, d1) -> (d1, d0)>
        print("memref layout:", memref_layout.layout)
        # CHECK: memref affine map: (d0, d1) -> (d1, d0)
        print("memref affine map:", memref_layout.affine_map)
        # CHECK: memory space: None
        print("memory space:", memref_layout.memory_space)

        none = NoneType.get()
        try:
            MemRefType.get(shape, none)
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(e)
        else:
            print("Exception not produced")

        assert memref_f32.shape == shape


# CHECK-LABEL: TEST: testUnrankedMemRefType
@run
def testUnrankedMemRefType():
    """Unranked memref memory space, rank-query errors and element validation."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        unranked_memref = UnrankedMemRefType.get(f32, Attribute.parse("2"))
        # CHECK: unranked memref type: memref<*xf32, 2>
        print("unranked memref type:", unranked_memref)
        # CHECK: memory space: IntegerAttr(2 : i64)
        print("memory space:", repr(unranked_memref.memory_space))
        try:
            invalid_rank = unranked_memref.rank
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_is_dynamic_dim = unranked_memref.is_dynamic_dim(0)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")
        try:
            invalid_get_dim_size = unranked_memref.get_dim_size(1)
        except ValueError as e:
            # CHECK: calling this method requires that the type has a rank.
            print(e)
        else:
            print("Exception not produced")

        none = NoneType.get()
        try:
            UnrankedMemRefType.get(none, Attribute.parse("2"))
        except MLIRError as e:
            # CHECK: Invalid type:
            # CHECK: error: unknown: invalid memref element type
            print(e)
        else:
            print("Exception not produced")


# CHECK-LABEL: TEST: testTupleType
@run
def testTupleType():
    """Build a tuple of i32, f32 and a vector, then inspect its members."""
    with Context() as ctx:
        i32 = IntegerType(Type.parse("i32"))
        f32 = F32Type.get()
        vector = VectorType(Type.parse("vector<2x3xf32>"))
        element_types = [i32, f32, vector]
        tuple_type = TupleType.get_tuple(element_types)
        # CHECK: tuple type: tuple<i32, f32, vector<2x3xf32>>
        print("tuple type:", tuple_type)
        # CHECK: number of types: 3
        print("number of types:", tuple_type.num_types)
        # CHECK: pos-th type in the tuple type: f32
        print("pos-th type in the tuple type:", tuple_type.get_type(1))


# CHECK-LABEL: TEST: testFunctionType
@run
def testFunctionType():
    """FunctionType exposes its input and result types as downcast lists."""
    with Context() as ctx:
        input_types = [IntegerType.get_signless(32), IntegerType.get_signless(16)]
        result_types = [IndexType.get()]
        # Named `func_type` so it does not shadow the imported `func` dialect.
        func_type = FunctionType.get(input_types, result_types)
        # CHECK: INPUTS: [IntegerType(i32), IntegerType(i16)]
        print("INPUTS:", func_type.inputs)
        # CHECK: RESULTS: [IndexType(index)]
        print("RESULTS:", func_type.results)


# CHECK-LABEL: TEST: testOpaqueType
@run
def testOpaqueType():
    """OpaqueType exposes its dialect namespace and raw data."""
    with Context() as ctx:
        ctx.allow_unregistered_dialects = True
        opaque = OpaqueType.get("dialect", "type")
        # CHECK: opaque type: !dialect.type
        print("opaque type:", opaque)
        # CHECK: dialect namespace: dialect
        print("dialect namespace:", opaque.dialect_namespace)
        # CHECK: data: type
        print("data:", opaque.data)


# CHECK-LABEL: TEST: testShapedTypeConstants
# Tests that ShapedType exposes magic value constants.
@run
def testShapedTypeConstants():
    """The dynamic size/stride markers are plain Python ints."""
    # CHECK: <class 'int'>
    print(type(ShapedType.get_dynamic_size()))
    # CHECK: <class 'int'>
    print(type(ShapedType.get_dynamic_stride_or_offset()))


# CHECK-LABEL: TEST: testTypeIDs
@run
def testTypeIDs():
    """Static and per-instance TypeIDs agree and behave as dict keys."""
    with Context(), Location.unknown():
        f32 = F32Type.get()
        types = [
            (IntegerType, IntegerType.get_signless(16)),
            (IndexType, IndexType.get()),
            (Float4E2M1FNType, Float4E2M1FNType.get()),
            (Float6E2M3FNType, Float6E2M3FNType.get()),
            (Float6E3M2FNType, Float6E3M2FNType.get()),
            (Float8E3M4Type, Float8E3M4Type.get()),
            (Float8E4M3Type, Float8E4M3Type.get()),
            (Float8E4M3FNType, Float8E4M3FNType.get()),
            (Float8E5M2Type, Float8E5M2Type.get()),
            (Float8E4M3FNUZType, Float8E4M3FNUZType.get()),
            (Float8E4M3B11FNUZType, Float8E4M3B11FNUZType.get()),
            (Float8E5M2FNUZType, Float8E5M2FNUZType.get()),
            (Float8E8M0FNUType, Float8E8M0FNUType.get()),
            (BF16Type, BF16Type.get()),
            (F16Type, F16Type.get()),
            (F32Type, F32Type.get()),
            (FloatTF32Type, FloatTF32Type.get()),
            (F64Type, F64Type.get()),
            (NoneType, NoneType.get()),
            (ComplexType, ComplexType.get(f32)),
            (VectorType, VectorType.get([2, 3], f32)),
            (RankedTensorType, RankedTensorType.get([2, 3], f32)),
            (UnrankedTensorType, UnrankedTensorType.get(f32)),
            (MemRefType, MemRefType.get([2, 3], f32)),
            (UnrankedMemRefType, UnrankedMemRefType.get(f32, Attribute.parse("2"))),
            (TupleType, TupleType.get_tuple([f32])),
            (FunctionType, FunctionType.get([], [])),
            (OpaqueType, OpaqueType.get("tensor", "bob")),
        ]

        # CHECK: IntegerType(i16)
        # CHECK: IndexType(index)
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        # CHECK: Float8E3M4Type(f8E3M4)
        # CHECK: Float8E4M3Type(f8E4M3)
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        # CHECK: Float8E5M2Type(f8E5M2)
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        # CHECK: BF16Type(bf16)
        # CHECK: F16Type(f16)
        # CHECK: F32Type(f32)
        # CHECK: FloatTF32Type(tf32)
        # CHECK: F64Type(f64)
        # CHECK: NoneType(none)
        # CHECK: ComplexType(complex<f32>)
        # CHECK: VectorType(vector<2x3xf32>)
        # CHECK: RankedTensorType(tensor<2x3xf32>)
        # CHECK: UnrankedTensorType(tensor<*xf32>)
        # CHECK: MemRefType(memref<2x3xf32>)
        # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
        # CHECK: TupleType(tuple<f32>)
        # CHECK: FunctionType(() -> ())
        # CHECK: OpaqueType(!tensor.bob)
        for _, t in types:
            print(repr(t))

        # Test getTypeIdFunction agrees with
        # mlirTypeGetTypeID(self) for an instance.
        # CHECK: all equal
        for t1, t2 in types:
            tid1, tid2 = t1.static_typeid, Type(t2).typeid
            assert tid1 == tid2 and hash(tid1) == hash(
                tid2
            ), f"expected hash and value equality {t1} {t2}"
        # The loop never breaks, so reaching here means every assert passed.
        print("all equal")

        # Test that PyTypeID works as a Python dict key: the entry stored under
        # a class's static_typeid must be found via the instance's typeid.
        typeid_dict = {cls.static_typeid: instance for cls, instance in types}
        assert len(typeid_dict) == len(types)
        # CHECK: all equal
        for cls, instance in types:
            assert (
                typeid_dict[instance.typeid] == instance
            ), f"expected dict lookup by typeid to find {instance} ({cls})"
        print("all equal")

        # CHECK: ShapedType has no typeid.
        try:
            print(ShapedType.static_typeid)
        except AttributeError as e:
            print(e)

        vector_type = Type.parse("vector<2x3xf32>")
        # CHECK: True
        print(ShapedType(vector_type).typeid == vector_type.typeid)


# CHECK-LABEL: TEST: testConcreteTypesRoundTrip
@run
def testConcreteTypesRoundTrip():
    """Generic Types and op result types come back as their concrete subclass."""
    with Context() as ctx, Location.unknown():
        ctx.allow_unregistered_dialects = True

        def print_downcasted(typ):
            # Erase to the generic Type, then ask the bindings to downcast it.
            downcasted = Type(typ).maybe_downcast()
            print(type(downcasted).__name__)
            print(repr(downcasted))

        # CHECK: F16Type
        # CHECK: F16Type(f16)
        print_downcasted(F16Type.get())
        # CHECK: F32Type
        # CHECK: F32Type(f32)
        print_downcasted(F32Type.get())
        # CHECK: FloatTF32Type
        # CHECK: FloatTF32Type(tf32)
        print_downcasted(FloatTF32Type.get())
        # CHECK: F64Type
        # CHECK: F64Type(f64)
        print_downcasted(F64Type.get())
        # CHECK: Float4E2M1FNType
        # CHECK: Float4E2M1FNType(f4E2M1FN)
        print_downcasted(Float4E2M1FNType.get())
        # CHECK: Float6E2M3FNType
        # CHECK: Float6E2M3FNType(f6E2M3FN)
        print_downcasted(Float6E2M3FNType.get())
        # CHECK: Float6E3M2FNType
        # CHECK: Float6E3M2FNType(f6E3M2FN)
        print_downcasted(Float6E3M2FNType.get())
        # CHECK: Float8E3M4Type
        # CHECK: Float8E3M4Type(f8E3M4)
        print_downcasted(Float8E3M4Type.get())
        # CHECK: Float8E4M3B11FNUZType
        # CHECK: Float8E4M3B11FNUZType(f8E4M3B11FNUZ)
        print_downcasted(Float8E4M3B11FNUZType.get())
        # CHECK: Float8E4M3Type
        # CHECK: Float8E4M3Type(f8E4M3)
        print_downcasted(Float8E4M3Type.get())
        # CHECK: Float8E4M3FNType
        # CHECK: Float8E4M3FNType(f8E4M3FN)
        print_downcasted(Float8E4M3FNType.get())
        # CHECK: Float8E4M3FNUZType
        # CHECK: Float8E4M3FNUZType(f8E4M3FNUZ)
        print_downcasted(Float8E4M3FNUZType.get())
        # CHECK: Float8E5M2Type
        # CHECK: Float8E5M2Type(f8E5M2)
        print_downcasted(Float8E5M2Type.get())
        # CHECK: Float8E5M2FNUZType
        # CHECK: Float8E5M2FNUZType(f8E5M2FNUZ)
        print_downcasted(Float8E5M2FNUZType.get())
        # CHECK: Float8E8M0FNUType
        # CHECK: Float8E8M0FNUType(f8E8M0FNU)
        print_downcasted(Float8E8M0FNUType.get())
        # CHECK: BF16Type
        # CHECK: BF16Type(bf16)
        print_downcasted(BF16Type.get())
        # CHECK: IndexType
        # CHECK: IndexType(index)
        print_downcasted(IndexType.get())
        # CHECK: IntegerType
        # CHECK: IntegerType(i32)
        print_downcasted(IntegerType.get_signless(32))

        f32 = F32Type.get()
        ranked_tensor = tensor.EmptyOp([10, 10], f32).result
        # CHECK: RankedTensorType
        print(type(ranked_tensor.type).__name__)
        # CHECK: RankedTensorType(tensor<10x10xf32>)
        print(repr(ranked_tensor.type))

        cf32 = ComplexType.get(f32)
        # CHECK: ComplexType
        print(type(cf32).__name__)
        # CHECK: ComplexType(complex<f32>)
        print(repr(cf32))

        vector = VectorType.get([10, 10], f32)
        tuple_type = TupleType.get_tuple([f32, vector])
        # CHECK: TupleType
        print(type(tuple_type).__name__)
        # CHECK: TupleType(tuple<f32, vector<10x10xf32>>)
        print(repr(tuple_type))
        # CHECK: F32Type(f32)
        print(repr(tuple_type.get_type(0)))
        # CHECK: VectorType(vector<10x10xf32>)
        print(repr(tuple_type.get_type(1)))

        index_type = IndexType.get()

        @func.FuncOp.from_py_func()
        def default_builder():
            c0 = arith.ConstantOp(f32, 0.0)
            unranked_tensor_type = UnrankedTensorType.get(f32)
            unranked_tensor = tensor.FromElementsOp(unranked_tensor_type, [c0]).result
            # CHECK: UnrankedTensorType
            print(type(unranked_tensor.type).__name__)
            # CHECK: UnrankedTensorType(tensor<*xf32>)
            print(repr(unranked_tensor.type))

            c10 = arith.ConstantOp(index_type, 10)
            memref_f32_t = MemRefType.get([10, 10], f32)
            memref_f32 = memref.AllocOp(memref_f32_t, [c10, c10], []).result
            # CHECK: MemRefType
            print(type(memref_f32.type).__name__)
            # CHECK: MemRefType(memref<10x10xf32>)
            print(repr(memref_f32.type))

            unranked_memref_t = UnrankedMemRefType.get(f32, Attribute.parse("2"))
            memref_f32 = memref.AllocOp(unranked_memref_t, [c10, c10], []).result
            # CHECK: UnrankedMemRefType
            print(type(memref_f32.type).__name__)
            # CHECK: UnrankedMemRefType(memref<*xf32, 2>)
            print(repr(memref_f32.type))

            # An unregistered op's result type must also come back downcast.
            tuple_value = Operation.parse(
                '"test.make_tuple"() : () -> tuple<i32, f32>'
            ).result
            # CHECK: TupleType
            print(type(tuple_value.type).__name__)
            # CHECK: TupleType(tuple<i32, f32>)
            print(repr(tuple_value.type))

            return c0, c10


# CHECK-LABEL: TEST: testCustomTypeTypeCaster
# This tests being able to materialize a type from a dialect *and* have
# the implemented type caster called without explicitly importing the dialect.
# I.e., we get a transform.OperationType without explicitly importing the transform dialect.
@run
def testCustomTypeTypeCaster():
    """A transform-dialect type is downcast without importing that dialect."""
    with Context() as ctx, Location.unknown():
        t = Type.parse('!transform.op<"foo.bar">', Context())
        # CHECK: !transform.op<"foo.bar">
        print(t)
        # CHECK: OperationType(!transform.op<"foo.bar">)
        print(repr(t))


# CHECK-LABEL: TEST: testTypeWrappers
@run
def testTypeWrappers():
    """Check the reprs produced by the mlir.extras.types (``T``) helpers."""

    def stride(strides, offset=0):
        # Build a strided layout attribute with the given strides and offset.
        return StridedLayoutAttr.get(offset, strides)

    with Context(), Location.unknown():
        ia = T.i(5)
        sia = T.si(6)
        uia = T.ui(7)
        assert repr(ia) == "IntegerType(i5)"
        assert repr(sia) == "IntegerType(si6)"
        assert repr(uia) == "IntegerType(ui7)"

        assert T.i(16) == T.i16()
        assert T.si(16) == T.si16()
        assert T.ui(16) == T.ui16()

        c1 = T.complex(T.f16())
        c2 = T.complex(T.i32())
        assert repr(c1) == "ComplexType(complex<f16>)"
        assert repr(c2) == "ComplexType(complex<i32>)"

        vec_1 = T.vector(2, 3, T.f32())
        vec_2 = T.vector(2, 3, 4, T.f32())
        assert repr(vec_1) == "VectorType(vector<2x3xf32>)"
        assert repr(vec_2) == "VectorType(vector<2x3x4xf32>)"

        m1 = T.memref(2, 3, 4, T.f64())
        assert repr(m1) == "MemRefType(memref<2x3x4xf64>)"

        m2 = T.memref(2, 3, 4, T.f64(), memory_space=1)
        assert repr(m2) == "MemRefType(memref<2x3x4xf64, 1>)"

        m3 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13]))
        assert repr(m3) == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13]>, 1>)"

        m4 = T.memref(2, 3, 4, T.f64(), memory_space=1, layout=stride([5, 7, 13], 42))
        assert (
            repr(m4)
            == "MemRefType(memref<2x3x4xf64, strided<[5, 7, 13], offset: 42>, 1>)"
        )

        S = ShapedType.get_dynamic_size()

        t1 = T.tensor(S, 3, S, T.f64())
        assert repr(t1) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut1 = T.tensor(T.f64())
        assert repr(ut1) == "UnrankedTensorType(tensor<*xf64>)"
        t2 = T.tensor(S, 3, S, element_type=T.f64())
        assert repr(t2) == "RankedTensorType(tensor<?x3x?xf64>)"
        ut2 = T.tensor(element_type=T.f64())
        assert repr(ut2) == "UnrankedTensorType(tensor<*xf64>)"

        t3 = T.tensor(S, 3, S, T.f64(), encoding="encoding")
        assert repr(t3) == 'RankedTensorType(tensor<?x3x?xf64, "encoding">)'

        v = T.vector(3, 3, 3, T.f64())
        assert repr(v) == "VectorType(vector<3x3x3xf64>)"

        m5 = T.memref(S, 3, S, T.f64())
        assert repr(m5) == "MemRefType(memref<?x3x?xf64>)"
        um1 = T.memref(T.f64())
        assert repr(um1) == "UnrankedMemRefType(memref<*xf64>)"
        m6 = T.memref(S, 3, S, element_type=T.f64())
        assert repr(m6) == "MemRefType(memref<?x3x?xf64>)"
        um2 = T.memref(element_type=T.f64())
        assert repr(um2) == "UnrankedMemRefType(memref<*xf64>)"

        m7 = T.memref(S, 3, S, T.f64())
        assert repr(m7) == "MemRefType(memref<?x3x?xf64>)"
        um3 = T.memref(T.f64())
        assert repr(um3) == "UnrankedMemRefType(memref<*xf64>)"

        scalable_1 = T.vector(2, 3, T.f32(), scalable=[False, True])
        scalable_2 = T.vector(2, 3, 4, T.f32(), scalable=[True, False, True])
        assert repr(scalable_1) == "VectorType(vector<2x[3]xf32>)"
        assert repr(scalable_2) == "VectorType(vector<[2]x3x[4]xf32>)"

        scalable_3 = T.vector(2, 3, T.f32(), scalable_dims=[1])
        scalable_4 = T.vector(2, 3, 4, T.f32(), scalable_dims=[0, 2])
        assert scalable_3 == scalable_1
        assert scalable_4 == scalable_2

        opaq = T.opaque("scf", "placeholder")
        assert repr(opaq) == "OpaqueType(!scf.placeholder)"

        tup1 = T.tuple(T.i16(), T.i32(), T.i64())
        tup2 = T.tuple(T.f16(), T.f32(), T.f64())
        assert repr(tup1) == "TupleType(tuple<i16, i32, i64>)"
        assert repr(tup2) == "TupleType(tuple<f16, f32, f64>)"

        # Named `fn_type` so it does not shadow the imported `func` dialect.
        fn_type = T.function(
            inputs=(T.i16(), T.i32(), T.i64()), results=(T.f16(), T.f32(), T.f64())
        )
        assert repr(fn_type) == "FunctionType((i16, i32, i64) -> (f16, f32, f64))"